WebSocket使得客户端和服务器之间的数据交换变得更加简单,允许服务端主动向客户端推送数据。在WebSocket API中,浏览器和服务器只需要完成一次握手,两者之间就可以创建持久性的连接,并进行双向数据传输。
一、为什么需要 WebSocket?
初次接触 WebSocket 的人,都会问同样的问题:我们已经有了 HTTP 协议,为什么还需要另一个协议?它能带来什么好处?
答案很简单,因为 HTTP 协议有一个缺陷:通信只能由客户端发起。
举例来说,我们想了解今天的天气,只能是客户端向服务器发出请求,服务器返回查询结果。HTTP 协议做不到服务器主动向客户端推送信息。
这种单向请求的特点,注定了如果服务器有连续的状态变化,客户端要获知就非常麻烦。我们只能使用"轮询":每隔一段时间,就发出一个询问,了解服务器有没有新的信息。最典型的场景就是聊天室。
轮询的效率低,非常浪费资源(因为必须不停连接,或者 HTTP 连接始终打开)。因此,工程师们一直在思考,有没有更好的方法。WebSocket 就是这样发明的。
二、简介
WebSocket 协议在2008年诞生,2011年成为国际标准。所有浏览器都已经支持了。
它的最大特点就是,服务器可以主动向客户端推送信息,客户端也可以主动向服务器发送信息,是真正的双向平等对话,属于服务器推送技术的一种。
其他特点包括:
(1)建立在 TCP 协议之上,服务器端的实现比较容易。
(2)与 HTTP 协议有着良好的兼容性。默认端口也是80和443,并且握手阶段采用 HTTP 协议,因此握手时不容易屏蔽,能通过各种 HTTP 代理服务器。
(3)数据格式比较轻量,性能开销小,通信高效。
(4)可以发送文本,也可以发送二进制数据。
(5)没有同源限制,客户端可以与任意服务器通信。
(6)协议标识符是ws(如果加密,则为wss),服务器网址就是 URL。
服务端的实现
依赖 spring-boot-starter-websocket 模块实现 WebSocket 实时对话交互。
CustomTextWebSocketHandler,扩展自 TextWebSocketHandler
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;import java.util.concurrent.CountDownLatch;/*** 文本处理器** @see org.springframework.web.socket.handler.TextWebSocketHandler*/
@Slf4j
public class CustomTextWebSocketHandler extends TextWebSocketHandler {/*** 第三方身份,消息身份*/private String thirdPartyId;/*** 回复消息内容*/private String replyContent;private StringBuilder replyContentBuilder;/*** 完成信号*/private final CountDownLatch doneSignal;public CustomTextWebSocketHandler(CountDownLatch doneSignal) {this.doneSignal = doneSignal;}public String getThirdPartyId() {return thirdPartyId;}public String getReplyContent() {return replyContent;}@Overridepublic void afterConnectionEstablished(WebSocketSession session) throws Exception {log.info("connection established, session={}", session);replyContentBuilder = new StringBuilder(16);
// super.afterConnectionEstablished(session);}@Overridepublic void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {super.handleMessage(session, message);}/*** 消息已接收完毕("stop")*/private static final String MESSAGE_DONE = "[DONE]";@Overrideprotected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
// super.handleTextMessage(session, message);String payload = message.getPayload();log.info("payload={}", payload);OpenAiReplyResponse replyResponse = Jsons.fromJson(payload, OpenAiReplyResponse.class);if (replyResponse != null && replyResponse.isSuccess()) {String msg = replyResponse.getMsg();if (Strings.isEmpty(msg)) {return;} else if (msg.startsWith("【超出最大单次回复字数】")) {// {"msg":"【超出最大单次回复字数】该提示由GPT官方返回,非我司限制,请缩减回复字数","code":1,// "extParam":"{\"chatId\":\"10056:8889007174\",\"requestId\":\"b6af5830a5a64fa8a4ca9451d7cb5f6f\",\"bizId\":\"\"}",// "id":"chatcmpl-7LThw6J9KmBUOcwK1SSOvdBP2vK9w"}return;} else if (msg.startsWith("发送内容包含敏感词")) {// {"msg":"发送内容包含敏感词,请修改后重试。不合规汇如下:炸弹","code":1,// "extParam":"{\"chatId\":\"10024:8889006970\",\"requestId\":\"828068d945c8415d8f32598ef6ef4ad6\",\"bizId\":\"430\"}",// "id":"4d4106c3-f7d4-4393-8cce-a32766d43f8b"}matchSensitiveWords = msg;// 请求完成doneSignal.countDown();return;} else if (MESSAGE_DONE.equals(msg)) {// 消息已接收完毕replyContent = replyContentBuilder.toString();thirdPartyId = replyResponse.getId();// 请求完成doneSignal.countDown();log.info("replyContent={}", replyContent);return;}replyContentBuilder.append(msg);}}@Overrideprotected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception {super.handlePongMessage(session, message);}@Overridepublic void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {replyContentBuilder = null;log.info("handle transport error, session={}", session, exception);doneSignal.countDown();
// super.handleTransportError(session, exception);}@Overridepublic void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {replyContentBuilder = null;log.info("connection closed, session={}, status={}", session, status);if (status == CloseStatus.NORMAL) {log.error("connection closed fail, session={}, status={}", session, status);}doneSignal.countDown();
// super.afterConnectionClosed(session, status);}
}
OpenAiHandler
/**
 * OpenAI handler: optional hooks that run around a request.
 *
 * <p>Both hooks default to doing nothing, so implementations override only what they need.
 *
 * @param <Req> request type
 * @param <Rsp> response type
 */
public interface OpenAiHandler<Req, Rsp> {

    /**
     * Hook invoked before the request is sent. The default implementation does nothing.
     *
     * @param req the request
     */
    default void beforeRequest(Req req) {
        // no-op by default
    }

    /**
     * Hook invoked after the response has been obtained. The default implementation does nothing.
     *
     * @param req the request
     * @param rsp the response
     */
    default void afterResponse(Req req, Rsp rsp) {
        // no-op by default
    }
}
OpenAiService
/**
 * OpenAI service.
 *
 * <pre>
 * API reference introduction
 * https://platform.openai.com/docs/api-reference/introduction
 * </pre>
 *
 * @param <Req> request type
 * @param <Rsp> response type
 */
public interface OpenAiService<Req, Rsp> extends OpenAiHandler<Req, Rsp> {

    /**
     * Performs the actual completion call, without the surrounding hooks.
     *
     * @param req the request
     * @return the response
     */
    Rsp doCompletions(Req req);

    /**
     * Runs a completion: calls {@code beforeRequest}, then {@link #doCompletions(Object)},
     * then {@code afterResponse} with the obtained response, and returns that response.
     *
     * @param req the request
     * @return the response produced by {@link #doCompletions(Object)}
     */
    default Rsp completions(Req req) {
        this.beforeRequest(req);
        final Rsp response = this.doCompletions(req);
        this.afterResponse(req, response);
        return response;
    }
}
OpenAiServiceImpl
import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;
import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;import javax.annotation.Nullable;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;/*** OpenAI服务实现*/
@Slf4j
@Configuration(proxyBeanMethods = false)
@EnableConfigurationProperties(OpenAiProperties.class)
@Service("openAiService")
public class OpenAiServiceImpl implements OpenAiService<CompletionReq, CompletionRsp> {private final OpenAiProperties properties;/*** 套接字客户端*/private final WebSocketClient webSocketClient;/*** 模型请求记录服务*/private final ModelRequestRecordService modelRequestRecordService;private static final String THREAD_NAME_PREFIX = "gpt.openai";public OpenAiServiceImpl(OpenAiProperties properties,ModelRequestRecordService modelRequestRecordService) {this.properties = properties;this.modelRequestRecordService = modelRequestRecordService;webSocketClient = WebSocketUtil.applyWebSocketClient(THREAD_NAME_PREFIX);log.info("create OpenAiServiceImpl instance");}@Overridepublic void beforeRequest(CompletionReq req) {// 请求身份if (Strings.isEmpty(req.getRequestId())) {req.setRequestId(UuidUtil.getUuid());}}@Overridepublic void afterResponse(CompletionReq req, CompletionRsp rsp) {if (rsp == null || Strings.isEmpty(rsp.getReplyContent())) {return;}// 三方敏感词检测String matchSensitiveWords = rsp.getMatchSensitiveWords();if (Strings.isNotEmpty(matchSensitiveWords)) {// 敏感词命中rsp.setMatchSensitiveWords(matchSensitiveWords);return;}// 阶段任务耗时统计StopWatch stopWatch = new StopWatch(req.getRequestId());try {// 敏感词检测stopWatch.start("checkSensitiveWord");String replyContent = rsp.getReplyContent();
// ApiResult<String> apiResult = checkMsg(replyContent, false);
// stopWatch.stop();
// if (!apiResult.isSuccess() && Strings.isNotEmpty(apiResult.getData())) {
// // 敏感词命中
// rsp.setMatchSensitiveWords(apiResult.getData());
// return;
// }// 记录落库stopWatch.start("saveModelRequestRecord");ModelRequestRecord entity = applyModelRequestRecord(req, rsp);modelRequestRecordService.save(entity);} finally {if (stopWatch.isRunning()) {stopWatch.stop();}log.info("afterResponse execute time, {}", stopWatch);}}private static ModelRequestRecord applyModelRequestRecord(CompletionReq req, CompletionRsp rsp) {Long orgId = req.getOrgId();Long userId = req.getUserId();String chatId = applyChatId(orgId, userId);return new ModelRequestRecord().setOrgId(orgId).setUserId(userId).setModelType(req.getModelType()).setRequestId(req.getRequestId()).setBizId(req.getBizId()).setChatId(chatId).setThirdPartyId(rsp.getThirdPartyId()).setInputMessage(req.getMessage()).setReplyContent(rsp.getReplyContent());}private static String applyChatId(Long orgId, Long userId) {return orgId + ":" + userId;}private static String applySessionId(String appId, String chatId) {return appId + '_' + chatId;}private static final String URI_TEMPLATE = "wss://socket.******.com/websocket/{sessionId}";@Nullable@Overridepublic CompletionRsp doCompletions(CompletionReq req) {// 阶段任务耗时统计StopWatch stopWatch = new StopWatch(req.getRequestId());stopWatch.start("doHandshake");// 闭锁,相当于一扇门(同步工具类)CountDownLatch doneSignal = new CountDownLatch(1);CustomTextWebSocketHandler webSocketHandler = new CustomTextWebSocketHandler(doneSignal);String chatId = applyChatId(req.getOrgId(), req.getUserId());String sessionId = applySessionId(properties.getAppId(), chatId);ListenableFuture<WebSocketSession> listenableFuture = webSocketClient.doHandshake(webSocketHandler, URI_TEMPLATE, sessionId);stopWatch.stop();stopWatch.start("getWebSocketSession");long connectionTimeout = properties.getConnectionTimeout().getSeconds();try (WebSocketSession webSocketSession = listenableFuture.get(connectionTimeout, TimeUnit.SECONDS)) {stopWatch.stop();stopWatch.start("sendMessage");OpenAiParam param = applyParam(chatId, req);webSocketSession.sendMessage(new 
TextMessage(Jsons.toJson(param)));long requestTimeout = properties.getRequestTimeout().getSeconds();// wait for all to finishboolean await = doneSignal.await(requestTimeout, TimeUnit.SECONDS);if (!await) {log.error("await doneSignal fail, req={}", req);}String replyContent = webSocketHandler.getReplyContent();String matchSensitiveWords = webSocketHandler.getMatchSensitiveWords();if (Strings.isEmpty(replyContent) && Strings.isEmpty(matchSensitiveWords)) {// 消息回复异常return null;}String delimiters = properties.getDelimiters();replyContent = StrUtil.replaceFirst(replyContent, delimiters, "");replyContent = StrUtil.replaceLast(replyContent, delimiters, "");String thirdPartyId = webSocketHandler.getThirdPartyId();return new CompletionRsp().setThirdPartyId(thirdPartyId).setReplyContent(replyContent).setMatchSensitiveWords(matchSensitiveWords);} catch (InterruptedException | ExecutionException | TimeoutException e) {log.error("get WebSocketSession fail, req={}", req, e);} catch (IOException e) {log.error("sendMessage fail, req={}", req, e);} finally {if (stopWatch.isRunning()) {stopWatch.stop();}log.info("doCompletions execute time, {}", stopWatch);}return null;}private static final int MIN_TOKENS = 11;/*** 限制单次最大回复单词数(tokens)*/private static int applyMaxTokens(int reqMaxTokens, int maxTokensConfig) {if (reqMaxTokens < MIN_TOKENS || maxTokensConfig < reqMaxTokens) {return maxTokensConfig;}return reqMaxTokens;}private OpenAiParam applyParam(String chatId, CompletionReq req) {OpenAiDataExtParam extParam = new OpenAiDataExtParam().setChatId(chatId).setRequestId(req.getRequestId()).setBizId(req.getBizId());// 提示String prompt = req.getPrompt();// 分隔符String delimiters = properties.getDelimiters();String message = prompt + delimiters + req.getMessage() + delimiters;int maxTokens = applyMaxTokens(req.getMaxTokens(), properties.getMaxTokens());OpenAiData data = new OpenAiData().setMsg(message).setContext(properties.getContext()).setLimitTokens(maxTokens).setExtParam(extParam);String 
sign = OpenAiUtil.applySign(message, properties.getSecret());return new OpenAiParam().setData(data).setSign(sign);}
}
WebSocketUtil
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;/*** WebSocket辅助方法*/
public final class WebSocketUtil {

    /** Thread name prefix used when the caller does not supply one. */
    private static final String DEFAULT_THREAD_NAME_PREFIX = "gpt.web.socket";

    /** Upper bound of the executor's thread pool. */
    private static final int MAX_POOL_SIZE = 200;

    private WebSocketUtil() {
        // static helpers only
    }

    /**
     * Creates a new WebSocket client backed by its own daemon thread pool.
     *
     * <p>The core pool size is the number of available processors, capped at the maximum
     * pool size so that the executor can still be initialised on hosts with more
     * processors than {@value #MAX_POOL_SIZE}.
     *
     * @param threadNamePrefix prefix for the pool's thread names; a default prefix is used
     *                         when it is {@code null}, empty or blank
     * @return a client whose task executor has been initialised
     */
    public static WebSocketClient applyWebSocketClient(String threadNamePrefix) {
        int cpuNum = Runtime.getRuntime().availableProcessors();

        ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
        taskExecutor.setCorePoolSize(Math.min(cpuNum, MAX_POOL_SIZE));
        taskExecutor.setMaxPoolSize(MAX_POOL_SIZE);
        taskExecutor.setDaemon(true);
        taskExecutor.setThreadNamePrefix(
                StringUtils.hasText(threadNamePrefix) ? threadNamePrefix : DEFAULT_THREAD_NAME_PREFIX);
        taskExecutor.initialize();

        StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
        webSocketClient.setTaskExecutor(taskExecutor);
        return webSocketClient;
    }
}
OpenAiUtil
import org.springframework.util.DigestUtils;import java.nio.charset.StandardCharsets;/*** OpenAi辅助方法*/
public final class OpenAiUtil {

    private OpenAiUtil() {
        // static helpers only
    }

    /**
     * Signs a message by hashing it together with the secret.
     *
     * <p>The signature is the MD5 digest of {@code message + secret}, encoded as UTF-8,
     * rendered as a hexadecimal string. MD5 is a one-way hash, not encryption.
     *
     * @param message message content
     * @param secret  signing secret appended to the message
     * @return the hexadecimal MD5 digest of the concatenated input
     */
    public static String applySign(String message, String secret) {
        byte[] signedBytes = (message + secret).getBytes(StandardCharsets.UTF_8);
        return DigestUtils.md5DigestAsHex(signedBytes);
    }
}
参考资料
- WebSocket - 维基百科
- WebSocket 教程 - 阮一峰
- 使用WebSocket - 廖雪峰
- WebSocket Support - Spring Framework
- Messaging WebSockets - Spring Boot
- Create WebSocket Endpoints Using @ServerEndpoint - “How-to” Guides - Spring Boot