本文功能实现较为简陋,demo内容仅供参考,有不足之处还请指正。
背景
这是一个小项目,作为微信小程序的服务端,需要支持小程序用户之间的一对一聊天。
实现功能
WebSocket 通信、心跳检测、消息持久化、离线消息存储
Netty配置类
/*** @author Aseubel*/
@Component
@Slf4j
@EnableConfigurationProperties(NettyServerConfigProperties.class)
public class NettyServerConfig {private ChannelFuture serverChannelFuture;// 心跳间隔(秒)private static final int HEARTBEAT_INTERVAL = 15;// 读超时时间private static final int READ_TIMEOUT = HEARTBEAT_INTERVAL * 2;// 使用线程池管理private final EventLoopGroup bossGroup = new NioEventLoopGroup(1);private final EventLoopGroup workerGroup = new NioEventLoopGroup();private final NettyServerConfigProperties properties;// 由于在后面的handler中有依赖注入类,所以要通过springboot的ApplicationContext来获取Bean实例@Autowiredprivate ApplicationContext applicationContext;public NettyServerConfig(NettyServerConfigProperties properties) {this.properties = properties;}@PostConstructpublic void startNettyServer() {// 使用独立线程启动Netty服务new Thread(() -> {try {ServerBootstrap bootstrap = new ServerBootstrap();bootstrap.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class).childHandler(new ChannelInitializer<Channel>() {@Overrideprotected void initChannel(Channel ch) throws Exception {ChannelPipeline pipeline = ch.pipeline();SSLContext sslContext = SslUtil.createSSLContext("PKCS12",properties.getSslPath(), properties.getSslPassword());// SSLEngine 此类允许使用ssl安全套接层协议进行安全通信SSLEngine engine = sslContext.createSSLEngine();engine.setUseClientMode(false);pipeline.addLast(new SslHandler(engine)); // 设置SSLpipeline.addLast(new HttpServerCodec());pipeline.addLast(new HttpObjectAggregator(10 * 1024 * 1024));// 最大10MBpipeline.addLast(new ChunkedWriteHandler());pipeline.addLast(new HttpHandler());// 只有text和binarytext的帧能经过WebSocketServerProtocolHandler,所以心跳检测这两个都得放前面pipeline.addLast(new IdleStateHandler(READ_TIMEOUT, 0, 0, TimeUnit.SECONDS));pipeline.addLast(new HeartbeatHandler());pipeline.addLast(new WebSocketServerProtocolHandler("/ws", null, true, 10 * 1024 * 1024));pipeline.addLast(applicationContext.getBean(MessageHandler.class));pipeline.addLast(new ChannelInboundHandlerAdapter() {@Overridepublic void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {// 统一处理所有未被前面handler捕获的异常log.error("全局异常捕获: {}", 
cause.getMessage());ctx.channel().close();}});}});serverChannelFuture = bootstrap.bind(properties.getPort()).sync();// 保持通道开放serverChannelFuture.channel().closeFuture().sync();} catch (InterruptedException e) {Thread.currentThread().interrupt();}}).start();}@PreDestroypublic void stopNettyServer() {// 优雅关闭if (serverChannelFuture != null) {serverChannelFuture.channel().close();}bossGroup.shutdownGracefully();workerGroup.shutdownGracefully();}}
Handler
心跳检测
/*** @author Aseubel*/
public class HeartbeatHandler extends ChannelInboundHandlerAdapter {private static final int HEARTBEAT_INTERVAL = 15; // 心跳间隔(秒)private static final int MAX_MISSED_HEARTBEATS = 2; // 允许丢失的心跳次数// 记录每个连接的丢失心跳次数private final Map<ChannelId, Integer> missedHeartbeats = new ConcurrentHashMap<>();@Overridepublic void channelActive(ChannelHandlerContext ctx) {// 添加 IdleStateHandler 触发读空闲事件ctx.pipeline().addLast(new IdleStateHandler(HEARTBEAT_INTERVAL * MAX_MISSED_HEARTBEATS, 0, 0));scheduleHeartbeat(ctx);}private void scheduleHeartbeat(ChannelHandlerContext ctx) {ctx.executor().scheduleAtFixedRate(() -> {if (ctx.channel().isActive()) {ctx.writeAndFlush(new PingWebSocketFrame(Unpooled.copiedBuffer("HEARTBEAT", CharsetUtil.UTF_8)));// 记录丢失的心跳次数missedHeartbeats.compute(ctx.channel().id(), (k, v) -> v == null ? 1 : v + 1);}}, HEARTBEAT_INTERVAL, HEARTBEAT_INTERVAL, TimeUnit.SECONDS);}@Overridepublic void channelRead(ChannelHandlerContext ctx, Object msg) {if (msg instanceof PongWebSocketFrame) {// 收到 Pong 后重置丢失计数missedHeartbeats.remove(ctx.channel().id());ctx.fireChannelRead(msg); // 传递消息给后续处理器} else {ctx.fireChannelRead(msg);}}@Overridepublic void userEventTriggered(ChannelHandlerContext ctx, Object evt) {if (evt instanceof IdleStateEvent) {int missed = missedHeartbeats.getOrDefault(ctx.channel().id(), 0);if (missed >= MAX_MISSED_HEARTBEATS) {// 超过最大丢失次数,关闭连接System.out.println("连接超时,关闭连接" + ctx.channel().id().asLongText());ctx.close();cleanOfflineResources(ctx.channel());}}}private void cleanOfflineResources(Channel channel) {MessageHandler.removeUserChannel(channel);missedHeartbeats.remove(channel.id());}
}
处理 HTTP 握手请求,建立 WebSocket 连接
/*** @author Aseubel* @description 处理websocket连接请求,将code参数存入channel的attribute中* @date 2025-02-21 15:34*/
public class HttpHandler extends ChannelInboundHandlerAdapter {public static final AttributeKey<String> WS_TOKEN_KEY = AttributeKey.valueOf("code");@Overridepublic void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {// 判断是否是连接请求if (msg instanceof FullHttpRequest) {FullHttpRequest request = (FullHttpRequest) msg;try {QueryStringDecoder decoder = new QueryStringDecoder(request.uri());ctx.channel().attr(WS_TOKEN_KEY).set(decoder.parameters().get("code").get(0));} catch (Exception e) {throw new AppException("非法的websocket连接请求");}// 将 FullHttpRequest 转发到 MessageHandlerctx.fireChannelRead(request);// 重新设置 uri,将请求转发到 websocket handler,否则无法成功建立连接request.setUri("/ws");}// 消息直接交给下一个 handlersuper.channelRead(ctx, msg);}}
消息处理
/*** @author Aseubel* @description 处理 WebSocket 消息* @date 2025-02-21 15:33*/
@Component
@Slf4j
@Sharable
public class MessageHandler extends SimpleChannelInboundHandler<WebSocketFrame> {public static final AttributeKey<String> WS_TOKEN_KEY = AttributeKey.valueOf("code");public static final AttributeKey<String> WS_USER_ID_KEY = AttributeKey.valueOf("userId");private static final Map<String, Queue<WebSocketFrame>> OFFLINE_MSGS = new ConcurrentHashMap<>();private static final Map<String, Channel> userChannels = new ConcurrentHashMap<>();@Autowiredprivate ThreadPoolTaskExecutor threadPoolExecutor;@Resourceprivate IMessageRepository messageRepository;// 提供受控的访问方法public static void removeUserChannel(Channel channel) {userChannels.values().remove(channel);}public static boolean containsUser(String userId) {return userChannels.containsKey(userId);}@Overridepublic void channelActive(ChannelHandlerContext ctx) throws Exception {super.channelActive(ctx);}@Overridepublic void channelRead(ChannelHandlerContext ctx, Object req) throws Exception {if (req instanceof FullHttpRequest) {String code = getCodeFromRequest(ctx); // 从请求中提取 codeString userId = getOpenid(APPID, SECRET, code); // 验证 code 获取 openiduserChannels.put(userId, ctx.channel());ctx.channel().attr(WS_USER_ID_KEY).set(userId);System.out.println("客户端连接成功,用户id:" + userId);// 由于这里还在处理握手请求也就是建立连接,所以需要延迟发送离线消息new Thread(() -> {try {Thread.sleep(50);OFFLINE_MSGS.getOrDefault(userId, new LinkedList<>()).forEach(ctx::writeAndFlush);OFFLINE_MSGS.remove(userId);} catch (InterruptedException e) {Thread.currentThread().interrupt();}}).start();} else if (req instanceof TextWebSocketFrame ) {this.channelRead0(ctx, (TextWebSocketFrame) req);} else {ctx.fireChannelRead(req);}}@Overrideprotected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {if (frame instanceof TextWebSocketFrame) {MessageEntity message = validateMessage(ctx.channel().attr(WS_USER_ID_KEY).get(), (TextWebSocketFrame) frame);saveMessage(message);sendOrStoreMessage(message.getToUserId(), frame);} else {ctx.close();}}// 
处理连接断开@Overridepublic void channelInactive(ChannelHandlerContext ctx) {System.out.println("客户端断开连接,用户id:" + ctx.channel().attr(WS_USER_ID_KEY).get());Channel channel = ctx.channel();for (Map.Entry<String, Channel> entry : userChannels.entrySet()) {if (entry.getValue() == channel) {userChannels.remove(entry.getKey());break;}}}private MessageEntity validateMessage(String userId, TextWebSocketFrame textFrame) {String message = textFrame.text();try {JsonObject json = JsonParser.parseString(message).getAsJsonObject();String toUserId = json.get("toUserId").getAsString();String content = json.get("content").getAsString();String type = json.get("type").getAsString();if (type.equals("text") || type.equals("image")) {return new MessageEntity(userId, toUserId, content, type);} else {throw new AppException("非法的消息类型!");}} catch (Exception e) {throw new AppException("非法的消息格式!");}}private void sendOrStoreMessage(String toUserId, WebSocketFrame message) {if (isUserOnline(toUserId)) {Channel targetChannel = userChannels.get(toUserId);if (targetChannel != null && targetChannel.isActive()) {targetChannel.writeAndFlush(message.retain());}} else {// 存储原始WebSocketFrame(需保留引用)OFFLINE_MSGS.computeIfAbsent(toUserId, k -> new LinkedList<>()).add(message.retain());}}private void saveMessage(MessageEntity message) {threadPoolExecutor.execute(() -> {messageRepository.saveMessage(message);});}private boolean isUserOnline(String userId) {return userChannels.containsKey(userId);}private String getCodeFromRequest(ChannelHandlerContext ctx) {String code = ctx.channel().attr(WS_TOKEN_KEY).get();// 检查 code 参数是否存在且非空if (code == null || code.isEmpty()) {throw new IllegalArgumentException("WebSocket token is missing or empty");}return code;}private String getOpenid(String appid, String secret, String code) {Map<String, String> paramMap = new HashMap<>();paramMap.put("appid", appid);paramMap.put("secret", secret);paramMap.put("js_code", code);paramMap.put("grant_type", "authorization_code");String result 
= HttpClientUtil.doGet(WX_LOGIN, paramMap);//获取请求结果JSONObject jsonObject = JSON.parseObject(result);String openid = jsonObject.getString("openid");//判断openid是否存在if (StringUtils.isEmpty(openid)) {throw new WxException(jsonObject.getString("errcode"), jsonObject.getString("errmsg"));}return openid;}@Overridepublic void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {if (cause instanceof AppException appCause) {log.error("AppException caught: {}", appCause.getInfo());} else if (cause instanceof WxException wxCause) {log.error("WxException caught: {}", wxCause.getMessage());} else {log.error("Exception caught: {}", cause.getMessage(), cause);}ctx.close(); // 建议关闭发生异常的连接}}
连接及消息格式:
连接地址:wss://127.0.0.1:21611/ws?code=xxxxxx(code 为小程序端的登录凭证,服务端用它换取用户 openid)
消息格式(JSON):
{"toUserId": "1001", "type": "text", "content": "Hello World!"}
type 仅支持 text 和 image 两种:text 表示 content 为纯文本,image 表示 content 为 Base64 编码的图片数据。
本文功能实现较为简陋,demo内容仅供参考,可能有注释错误或设计不合理的地方