文章目录
- 准备工作
- 整体思路
- 接入大模型
- 服务端和大模型连接
- 客户端和服务端的连接
- 测试
准备工作
- 到讯飞星火大模型平台上,根据官方的提示申请tokens。
  申请成功后可以获得对应的APISecret、APIKey,以及之前创建的应用的appId,这些就是我们后面要用到的信息。
- 搭建项目
整体思路
考虑到敏感信息等安全性问题,这里和大模型的交互都放到后端去做。
客户端、服务端、星火大模型之间均通过WebSocket建立连接:用户提问时,客户端向SpringBoot服务端发送消息;服务端收到后创建与星火大模型的连接并发起请求,获取到响应结果后再转发给客户端。
如果想实现根据上下文问答,就要把历史问题和历史回答结果全部传回大模型服务端
请求参数的构建和响应参数的解析,请参照官方的Web API文档。
接入大模型
服务端和大模型连接
/*** 与大模型建立Socket连接** @author gwj*/
@Slf4j
public class BigModelNew extends WebSocketListener {public static final String appid = "appid";// 对话历史存储集合public static Map<Long,List<RoleContent>> hisMap = new ConcurrentHashMap<>();public static String totalAnswer = ""; // 大模型的答案汇总private static String newAsk = "";public static synchronized void ask(String question) {newAsk = question;}public static final Gson gson = new Gson();// 项目中需要用到的参数private Long userId;private Boolean wsCloseFlag;// 构造函数public BigModelNew(Long userId, Boolean wsCloseFlag) {this.userId = userId;this.wsCloseFlag = wsCloseFlag;}// 由于历史记录最大上线1.2W左右,需要判断是能能加入历史public boolean canAddHistory() {int len = 0;List<RoleContent> list = hisMap.get(userId);for (RoleContent temp : list) {len = len + temp.getContent().length();}if (len > 12000) {list.remove(0);list.remove(1);list.remove(2);list.remove(3);list.remove(4);return false;} else {return true;}}// 线程来发送参数class ModelThread extends Thread {private WebSocket webSocket;private Long userId;public ModelThread(WebSocket webSocket, Long userId) {this.webSocket = webSocket;this.userId = userId;}public void run() {try {JSONObject requestJson = new JSONObject();JSONObject header = new JSONObject(); // header参数header.put("app_id", appid);header.put("uid", userId+UUID.randomUUID().toString().substring(0,16));JSONObject parameter = new JSONObject(); // parameter参数JSONObject chat = new JSONObject();chat.put("domain", "4.0Ultra");chat.put("temperature", 0.5);chat.put("max_tokens", 4096);parameter.put("chat", chat);JSONObject payload = new JSONObject(); // payload参数JSONObject message = new JSONObject();JSONArray text = new JSONArray();// 历史问题获取List<RoleContent> list = hisMap.get(userId);if (list != null && !list.isEmpty()) {//log.info("his:{}",list);for (RoleContent tempRoleContent : list) {text.add(JSON.toJSON(tempRoleContent));}}// 最新问题RoleContent roleContent = new RoleContent();roleContent.setRole("user");roleContent.setContent(newAsk);text.add(JSON.toJSON(roleContent));hisMap.computeIfAbsent(userId, k -> new 
ArrayList<>());hisMap.get(userId).add(roleContent);message.put("text", text);payload.put("message", message);requestJson.put("header", header);requestJson.put("parameter", parameter);requestJson.put("payload", payload);// System.out.println(requestJson);webSocket.send(requestJson.toString());// 等待服务端返回完毕后关闭while (true) {// System.err.println(wsCloseFlag + "---");Thread.sleep(200);if (wsCloseFlag) {break;}}webSocket.close(1000, "");} catch (Exception e) {log.error("【大模型】发送消息错误,{}",e.getMessage());}}}@Overridepublic void onOpen(WebSocket webSocket, Response response) {super.onOpen(webSocket, response);log.info("上线");ModelThread modelThread = new ModelThread(webSocket,userId);modelThread.start();}@Overridepublic void onMessage(WebSocket webSocket, String text) {JsonParse json = gson.fromJson(text, JsonParse.class);if (json.getHeader().getCode() != 0) {log.error("发生错误,错误码为:{} sid为:{}", json.getHeader().getCode(),json.getHeader().getSid());//System.out.println(json);webSocket.close(1000, "");}List<Text> textList = json.getPayload().getChoices().getText();for (Text temp : textList) {// 向客户端发送回答信息,如有存储问答需求,在此处存储ModelChatEndpoint.sendMsgByUserId(userId,temp.getContent());totalAnswer = totalAnswer + temp.getContent();}if (json.getHeader().getStatus() == 2) {// 可以关闭连接,释放资源if (canAddHistory()) {RoleContent roleContent = new RoleContent();roleContent.setRole("assistant");roleContent.setContent(totalAnswer);hisMap.get(userId).add(roleContent);} else {hisMap.get(userId).remove(0);RoleContent roleContent = new RoleContent();roleContent.setRole("assistant");roleContent.setContent(totalAnswer);hisMap.get(userId).add(roleContent);}//收到响应后让等待的线程停止等待wsCloseFlag = true;}}@Overridepublic void onFailure(WebSocket webSocket, Throwable t, Response response) {super.onFailure(webSocket, t, response);try {if (null != response) {int code = response.code();System.out.println("onFailure code:" + code);System.out.println("onFailure body:" + response.body().string());if (101 != code) 
{System.out.println("connection failed");System.exit(0);}}} catch (IOException e) {e.printStackTrace();}}// 鉴权方法public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception {URL url = new URL(hostUrl);// 时间SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);format.setTimeZone(TimeZone.getTimeZone("GMT"));String date = format.format(new Date());// 拼接String preStr = "host: " + url.getHost() + "\n" + "date: " + date + "\n" + "GET " + url.getPath() + " HTTP/1.1";// System.err.println(preStr);// SHA256加密Mac mac = Mac.getInstance("hmacsha256");SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");mac.init(spec);byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));// Base64加密String sha = Base64.getEncoder().encodeToString(hexDigits);// System.err.println(sha);// 拼接String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);// 拼接地址HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().//addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))).//addQueryParameter("date", date).//addQueryParameter("host", url.getHost()).//build();return httpUrl.toString();}
}
其中,用来接收响应参数的相关实体类如下:
/**
 * Root object of one response frame from the model, as deserialized by Gson in
 * {@code BigModelNew.onMessage}.
 */
@Data
public class JsonParse {
    /** Frame metadata: result code, stream status and session id. */
    Header header;

    /** Frame content carrying the answer fragments. */
    Payload payload;
}
/** Header section of a response frame. */
@Data
public class Header {
    /** Result code; {@code BigModelNew.onMessage} treats any non-zero value as an error. */
    int code;

    /** Stream status; {@code BigModelNew.onMessage} treats the value 2 as the final frame. */
    int status;

    /** Session id, logged together with the error code when a request fails. */
    String sid;
}
/** Payload section of a response frame. */
@Data
public class Payload {
    /** Container for the answer fragments of this frame. */
    Choices choices;
}
/** The {@code choices} object of a response payload. */
@Data
public class Choices {
    /** Answer fragments contained in this frame, in the order they were received. */
    List<Text> text;
}
/** One answer fragment inside a response frame. */
@Data
public class Text {
    /** Role attached to the fragment. */
    String role;

    /** Fragment text; the fragments of all frames concatenated form the full answer. */
    String content;
}
/** One conversation entry, stored in the history and sent to the model as context. */
@Data
public class RoleContent {
    /** Who produced the entry; the code uses {@code "user"} and {@code "assistant"}. */
    String role;

    /** Text of the question or of the answer. */
    String content;
}
客户端和服务端的连接
/*** 接收客户端请求** @author gwj* @date 2024/10/29 16:51*/
@ServerEndpoint(value = "/ws/model", configurator = GetUserConfigurator.class)
@Component
@Slf4j
public class ModelChatEndpoint {private static AtomicInteger online = new AtomicInteger(0);private static final ConcurrentHashMap<Long,ModelChatEndpoint> wsMap = new ConcurrentHashMap<>();private static BigModelConfig config;@Resourceprivate BigModelConfig modelConfig;@PostConstructpublic void init() {config = modelConfig;}private Session session;private Long userId;@OnOpenpublic void onOpen(EndpointConfig config, Session session) {String s = config.getUserProperties().get("id").toString();userId = Long.parseLong(s);this.session = session;wsMap.put(userId,this);online.incrementAndGet();log.info("用户{},连接成功,在线人数:{}",userId,online);}@OnClosepublic void onClose() {wsMap.remove(userId);online.incrementAndGet();log.info("{},退出,在线人数:{}",userId,online);}@OnErrorpublic void onError(Session session, Throwable error) {log.error("连接出错,{}", error.getMessage());}@OnMessagepublic void onMessage(String message,Session session) throws Exception {BigModelNew.ask(message);//构建鉴权urlString authUrl = BigModelNew.getAuthUrl(config.getHostUrl(), config.getApiKey(), config.getApiSecret());OkHttpClient client = new OkHttpClient.Builder().build();String url = authUrl.replace("http://", "ws://").replace("https://", "wss://");Request request = new Request.Builder().url(url).build();WebSocket webSocket = client.newWebSocket(request,new BigModelNew(this.userId, false));log.info("收到客户端{}的消息:{}", userId, message);}private void sendMsg(String message) {try {this.session.getBasicRemote().sendText(message);} catch (IOException e) {log.error("客户端{}发送{}失败",userId,message);}}/*** 根据userId向用户发送消息** @param userId 用户id* @param message 消息*/public static void sendMsgByUserId(Long userId,String message) {if (userId != null && wsMap.containsKey(userId)) {wsMap.get(userId).sendMsg(message);}}}
测试
这样就简单实现了一个ai问答功能