一共有4个类,第一个WebSocketConfig 配置类
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {@Autowiredprivate WebSocketHandler webSocketHandler;@Autowiredprivate WebSocketInterceptor webSocketInterceptor;@Overridepublic void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {registry.addHandler(webSocketHandler, "/ws").addInterceptors(webSocketInterceptor).setAllowedOrigins("*");}
}
第二个,拦截器,这里我区分了pc和app,因为代码需求是同一个id登录的用户要在pc端和app端同时连接websocket,为做区分,在pc的userid后面加了pc两个字母.
@Component
public class WebSocketInterceptor implements HandshakeInterceptor {private final Logger logger = LoggerFactory.getLogger(WebSocketInterceptor.class);@Resourceprivate ISysUserService userService;/*** 握手前* @param request 请求对象* @param response 响应对象* @param wsHandler 请求处理器* @param attributes 属性域* @return true放行,false拒绝* @throws Exception 可能抛出的异常*/@Overridepublic boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, org.springframework.web.socket.WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {// 获得请求参数Map<String, String> paramMap = HttpUtil.decodeParamMap(request.getURI().getQuery(), Charset.defaultCharset());String userId = paramMap.get("userId");if (CharSequenceUtil.isNotBlank(userId)) {if (userId.endsWith("pc")){
// String substring = userId.substring(0, userId.length() - 2);
// // 校验连接人在系统是否存在
// SysUser user = userService.selectUserById(Long.valueOf(substring));
// if (user == null) {
// response.setStatusCode(HttpStatus.UNAUTHORIZED);
// return false;
// }}else {// 校验连接人在系统是否存在SysUser user = userService.selectUserById(Long.valueOf(userId));if (user == null) {response.setStatusCode(HttpStatus.UNAUTHORIZED);return false;}}// 放入属性域attributes.put("userId", userId);logger.info("用户:{}握手成功!", userId);return true;} else {logger.info("接受到一个websocket连接请求但是没有参数!");}return false;}/*** 握手后** @param request 请求独享* @param response 响应对象* @param wsHandler 处理器* @param exception 抛出的异常*/@Overridepublic void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, org.springframework.web.socket.WebSocketHandler wsHandler, Exception exception) {logger.info("握手结束!");}
}
第三个是管理器,其中的add方法,本身是有一个判重机制,如果该连接已存在就把原来的踢下线,重新连接新的,防止出现多个同样的id的问题.但是这又导致了新的频繁关闭重连的问题,所以后来改成了如果已经存在就直接return
@Slf4j
public class WsSessionManager {private WsSessionManager() {}private static final Logger logger = LoggerFactory.getLogger(WsSessionManager.class);/*** 记录当前在线连接数*/private static AtomicInteger onlineCount = new AtomicInteger(0);/*** 保存连接 session 的地方*/private static final ConcurrentHashMap<String, WebSocketSession> SESSION_POOL = new ConcurrentHashMap<>(99999);/*** 添加 session** @param key 键* @param session 值*/public static synchronized void add(String key, WebSocketSession session) {WebSocketSession existingSession = SESSION_POOL.get(key);if (existingSession != null) {if (existingSession.equals(session)) {logger.info("用户 {} 的 WebSocket 已存在,无需重复添加", key);return;}
// if (existingSession.isOpen()) {
// try {
// existingSession.close();
// logger.info("关闭旧的连接, userId: {}", key);
// } catch (IOException e) {
// logger.error("关闭旧的连接时出现异常, userId: {}, 异常: {}", key, e.getMessage());
// }
// }if (existingSession.isOpen()) return;}SESSION_POOL.put(key, session);onlineCount.incrementAndGet();logger.info("新连接已添加, userId: {}, 当前在线人数: {}", key, getOnlineCount());}/*** 删除 session, 会返回删除的 session** @param key 键* @return 值*/public static synchronized WebSocketSession remove(String key) {WebSocketSession session = SESSION_POOL.remove(key);if (session != null) {onlineCount.decrementAndGet();logger.info("连接已移除, userId: {}, 当前在线人数: {}", key, getOnlineCount());}return session;}/*** 删除并同步关闭连接** @param key 键*/public static synchronized void removeAndClose(String key) {WebSocketSession session = remove(key);if (session != null) {try {session.close();logger.warn("关闭WebSocket会话, userId: {}", key);} catch (IOException e) {logger.error("关闭会话时出现异常, userId: {}, 异常: {}, {}", key, e.getMessage(), e);}}}/*** 获得 session** @param key 键* @return 值*/public static WebSocketSession get(String key) {return SESSION_POOL.get(key);}/*** 获取当前在线连接数** @return 在线连接数*/public static int getOnlineCount() {return onlineCount.get();}/*** 获得 Map** @return 值*/public static ConcurrentMap<String, WebSocketSession> getMap() {return SESSION_POOL;}
}
第四个是真正发送消息的处理器
@Component
public class WebSocketHandler extends TextWebSocketHandler {private final Logger logger = LoggerFactory.getLogger(WebSocketHandler.class);private static final String KEY = "userId";/*** socket 建立成功事件* @param session session对象*/@Overridepublic void afterConnectionEstablished(WebSocketSession session) {Object userId = session.getAttributes().get(KEY);if (userId != null) {// 将用户的连接放入 WsSessionManager,会自动关闭之前的旧连接WsSessionManager.add(userId.toString(), session);logger.info("用户连接成功, userId: {}", userId);} else {logger.warn("未能在连接中找到 userId 属性");}logger.info("建立连接了, 当前在线人数: {}, session: {}, 当前map: {}", WsSessionManager.getOnlineCount(), session, WsSessionManager.getMap());}/*** 接收消息事件** @param session session对象* @param message 接收到的消息* @throws Exception 可能抛出的异常*/@Overrideprotected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {// 获得客户端传来的消息
// String payload = message.getPayload();logger.info("收到ws消息: {}", message);// 返回一条确认消息给发消息的用户TextMessage responseMessage = new TextMessage("pong");session.sendMessage(responseMessage);}/*** socket 断开连接时** @param session session对象* @param status 断开状态* @throws Exception 可能抛出的异常*/@Overridepublic void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {logger.info("断开连接了,session为{}", session == null ? "" : session);Object token = session.getAttributes().get(KEY);if (token != null) {// 用户退出,移除缓存WsSessionManager.removeAndClose(token.toString());}}/*** 发送消息给指定设备** @param serialNumber 序列号* @param message 消息内容* @param type 1跳通知 2跳客户 3手机打电话 前端用 8 pc用未读消息数量 9 当前app是否在线,true或者false* @param noticeId 通知id,已读用*/public void sendMessage(String serialNumber, String message, Integer type, Long noticeId) {WebSocketSession webSocketSession = WsSessionManager.get(serialNumber);try {if (webSocketSession != null && webSocketSession.isOpen()) {JSONObject jsonObject;jsonObject = JSONObject.of("type", type, "value", message, "noticeId", noticeId);webSocketSession.sendMessage(new TextMessage(jsonObject.toString()));logger.info("发送消息给{},消息内容为{}", serialNumber, message);}} catch (Exception e) {logger.error("消息发送失败,设备{},失败原因{}{}", webSocketSession.getAttributes().get(KEY), e.getMessage(), e);}}/*** 发送消息给指定设备** @param serialNumber 序列号* @param message 消息内容* @param type 1跳通知 2跳客户 3手机打电话 前端用 8 pc用未读消息数量 9 当前app是否在线,true或者false* @param notice 通知整个对象*/public void sendMessage(String serialNumber, String message, Integer type, ClientNoticeDO notice, Integer other) {WebSocketSession webSocketSession = WsSessionManager.get(serialNumber);try {if (webSocketSession != null && webSocketSession.isOpen()) {JSONObject jsonObject = JSONObject.of("type", type, "value", message, "notice", notice);webSocketSession.sendMessage(new TextMessage(jsonObject.toString()));logger.info("发送消息给{},消息内容为{}", serialNumber, message);} else {logger.warn("WebSocket 会话不可用, userId: {}", serialNumber);}} catch (IOException e) {logger.error("WebSocket 消息发送失败, userId: {}, 原因: {}", serialNumber, e.getMessage(), e);WsSessionManager.remove(serialNumber); // 自动移除无效会话} catch (Exception e) {logger.error("消息发送时发生未知错误, userId: {}, 原因: {}", serialNumber, e.getMessage(), e);}}/*** 广播消息** @param message 消息*/public void sendMessageAll(String message) {WsSessionManager.getMap().keySet().forEach(e -> sendMessage(e, message, 2, (Long) null));}
}
其中的sendMessage方法根据自己的业务需求有一个重载方法,正常一个sendMessage就足够了.日志相关的酌情增减.
心跳:在handleTextMessage方法中,接收到前端任何消息都返回一个pong,前端如果一段时间未收到pong就会发起重连,以此保证连接不中断.如果业务有前端发来的其他消息则加个if判断即可.
最终使用的时候注入
@Autowired
private WebSocketHandler webSocketHandler;//然后调用webSocketHandler.sendMessage(XXX,XXX,XXX)//即可.
连接的地址:ws://IP:端口/?userId=1
其中/ws是在WebSocketConfig配置的,
userId是在WebSocketHandler配置的KEY
最后附上在线连接websocket测试的网站:http://www.websocket-test.com/
以及相关可以直接测试的idea插件:CoolRequest