ChatGLM-6B Streaming HTTP API
This project mimics the OpenAI Chat Completion API (i.e. the GPT-3.5 API) to provide a streaming HTTP API for ChatGLM-6B.
Table of Contents
- ChatGLM-6B Streaming HTTP API
- Preface
- 1. Download the Code and Install the Environment
- 2. API Service Script
- 3. Launch Command
- Summary
Preface
There are plenty of tutorials on deploying chatglm-6b locally via the command line or a WebUI, but enterprises mostly deploy it as an API. The official API does not offer a streaming endpoint, so responses feel slow. This post walks through how to write a streaming service endpoint, which greatly improves response speed.
1. Download the Code and Install the Environment
Dependencies
The actual versions should follow the official ChatGLM-6B repository, but a few reminders:
After the official update to the stream_chat method, the 4.25.1 transformers package no longer works, so use transformers==4.27.1.
cpm_kernels requires CUDA to be installed on the machine (pay attention to this if the CUDA inside torch is bundled, e.g. by a one-click installer).
For better performance, CUDA 11.6 or 11.7 together with PyTorch 1.13 and torchvision 0.14.1 is recommended.
I am using an RTX 3090 GPU with Python 3.9.
First install protobuf>=3.18,<3.20.1 and transformers==4.27.1; the transformers version must be 4.27.1, otherwise you will get errors.
Install torch with conda rather than pip, otherwise cpm_kernels will throw errors.
Installation command:
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
torch==1.12.1+cu113
torchvision==0.13.1
After the environment above is installed, install the following as well, to be on the safe side (a combined pip command is given after this list):
icetk
cpm_kernels
uvicorn==0.18.1 (this exact version is required, otherwise it will error out)
fastapi
sse-starlette (needed for the EventSourceResponse import in the service script below)
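Assuming pip is used for these remaining packages (torch itself is installed via the conda command above), a possible one-liner is:

pip install "protobuf>=3.18,<3.20.1" transformers==4.27.1 icetk cpm_kernels uvicorn==0.18.1 fastapi sse-starlette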
2. API Service Script
from fastapi import FastAPI, Request
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import torch
from transformers import AutoTokenizer, AutoModel
import argparse
import logging
import os
import json
import sys


def getLogger(name, file_name, use_formatter=True):
    # Logger that writes to stdout and, optionally, to a log file
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    console_handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter('%(asctime)s %(message)s')
    console_handler.setFormatter(formatter)
    console_handler.setLevel(logging.INFO)
    logger.addHandler(console_handler)
    if file_name:
        handler = logging.FileHandler(file_name, encoding='utf8')
        handler.setLevel(logging.INFO)
        if use_formatter:
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s')
            handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


logger = getLogger('ChatGLM', 'chatlog.log')

MAX_HISTORY = 5


class ChatGLM():
    def __init__(self, quantize_level, gpu_id) -> None:
        logger.info("Start initialize model...")
        self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
        self.model = self._model(quantize_level, gpu_id)
        self.model.eval()
        # Warm-up call so the first real request is not slowed down by lazy initialization
        _, _ = self.model.chat(self.tokenizer, "你好", history=[])
        logger.info("Model initialization finished.")

    def _model(self, quantize_level, gpu_id):
        # Load the model on CPU or GPU with the requested quantization level
        model_name = "THUDM/chatglm-6b"
        quantize = int(quantize_level)
        model = None
        if gpu_id == '-1':
            if quantize == 8:
                print('CPU模式下量化等级只能是16或4,使用4')
                model_name = "THUDM/chatglm-6b-int4"
            elif quantize == 4:
                model_name = "THUDM/chatglm-6b-int4"
            model = AutoModel.from_pretrained(model_name, trust_remote_code=True).float()
        else:
            gpu_ids = gpu_id.split(",")
            self.devices = ["cuda:{}".format(id) for id in gpu_ids]
            if quantize == 16:
                model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().cuda()
            else:
                model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().quantize(quantize).cuda()
        return model

    def clear(self) -> None:
        # Release cached GPU memory on every device in use
        if torch.cuda.is_available():
            for device in self.devices:
                with torch.cuda.device(device):
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()

    def answer(self, query: str, history):
        # Blocking (non-streaming) chat call
        response, history = self.model.chat(self.tokenizer, query, history=history)
        history = [list(h) for h in history]
        return response, history

    def stream(self, query, history):
        # Streaming chat: yield the newly generated characters (delta) on every step
        if query is None or history is None:
            yield {"query": "", "response": "", "history": [], "finished": True}
            return
        size = 0
        response = ""
        for response, history in self.model.stream_chat(self.tokenizer, query, history):
            this_response = response[size:]
            history = [list(h) for h in history]
            size = len(response)
            yield {"delta": this_response, "response": response, "finished": False}
        logger.info("Answer - {}".format(response))
        yield {"query": query, "delta": "[EOS]", "response": response, "history": history, "finished": True}


def start_server(quantize_level, http_address: str, port: int, gpu_id: str):
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id

    bot = ChatGLM(quantize_level, gpu_id)

    app = FastAPI()
    app.add_middleware(CORSMiddleware,
                       allow_origins=["*"],
                       allow_credentials=True,
                       allow_methods=["*"],
                       allow_headers=["*"])

    @app.get("/")
    def index():
        return {'message': 'started', 'success': True}

    @app.post("/chat")
    async def answer_question(arg_dict: dict):
        # Non-streaming endpoint: returns the full answer in one JSON response
        result = {"query": "", "response": "", "success": False}
        try:
            text = arg_dict["query"]
            ori_history = arg_dict["history"]
            logger.info("Query - {}".format(text))
            if len(ori_history) > 0:
                logger.info("History - {}".format(ori_history))
            history = ori_history[-MAX_HISTORY:]
            history = [tuple(h) for h in history]
            response, history = bot.answer(text, history)
            logger.info("Answer - {}".format(response))
            ori_history.append((text, response))
            result = {"query": text, "response": response,
                      "history": ori_history, "success": True}
        except Exception as e:
            logger.error(f"error: {e}")
        return result

    @app.post("/stream")
    def answer_question_stream(arg_dict: dict):
        # Streaming endpoint: wraps the generator in Server-Sent Events
        def decorate(generator):
            for item in generator:
                yield ServerSentEvent(json.dumps(item, ensure_ascii=False), event='delta')

        try:
            text = arg_dict["query"]
            ori_history = arg_dict["history"]
            logger.info("Query - {}".format(text))
            if len(ori_history) > 0:
                logger.info("History - {}".format(ori_history))
            history = ori_history[-MAX_HISTORY:]
            history = [tuple(h) for h in history]
            return EventSourceResponse(decorate(bot.stream(text, history)))
        except Exception as e:
            logger.error(f"error: {e}")
            return EventSourceResponse(decorate(bot.stream(None, None)))

    @app.get("/clear")
    def clear():
        # Free GPU memory between conversations
        try:
            bot.clear()
            return {"success": True}
        except Exception as e:
            logger.error(f"error: {e}")
            return {"success": False}

    @app.get("/score")
    def score_answer(score: int):
        logger.info("score: {}".format(score))
        return {'success': True}

    logger.info("starting server...")
    uvicorn.run(app=app, host=http_address, port=port)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Stream API Service for ChatGLM-6B')
    parser.add_argument('--device', '-d', help='device, -1 means cpu, other means gpu ids', default='0')
    parser.add_argument('--quantize', '-q', help='level of quantize, option: 16, 8 or 4', default=16)
    parser.add_argument('--host', '-H', help='host to listen', default='0.0.0.0')
    parser.add_argument('--port', '-P', help='port of this service', default=8800)
    args = parser.parse_args()
    start_server(args.quantize, args.host, int(args.port), args.device)
3. Launch Command
python3 -u chatglm_service_fastapi.py --host 127.0.0.1 --port 8800 --quantize 8 --device 0
Among the arguments, --device -1 means CPU; any other number i means the i-th GPU.
Pick the quantization level according to your GPU: --quantize 16 needs about 12 GB of VRAM; if you have less, switch to 4 or 8.
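For example (an illustrative variant of the command above, assuming the int4 weights fit your card; int4 typically needs roughly 6 GB), a small-VRAM GPU could run:

python3 -u chatglm_service_fastapi.py --host 0.0.0.0 --port 8800 --quantize 4 --device 0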
How to Call the API
The streaming endpoint uses the Server-Sent Events (SSE) technique.
Endpoint URL: http://{host_name}/stream
Request method: POST (JSON body)
Response: a server-sent event stream in Event Stream format
Event name: delta
Data type: JSON
Response fields:

| Field | Type | Description |
| --- | --- | --- |
| delta | string | the newly generated characters |
| query | string | the user's question; to save bandwidth, only returned when finished is true |
| response | string | the reply so far; when finished is true, it is the complete reply |
| history | array[string] | the conversation history; to save bandwidth, only returned when finished is true |
| finished | boolean | true means the stream has ended, false means more data is coming |
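For reference, the raw event stream a client receives looks roughly like this (illustrative values; the field layout follows the stream method in the script above, ending with an [EOS] delta when finished is true):

event: delta
data: {"delta": "这是", "response": "这是", "finished": false}

event: delta
data: {"delta": "一条广告", "response": "这是一条广告", "finished": false}

event: delta
data: {"query": "给我写个广告", "delta": "[EOS]", "response": "这是一条广告", "history": [["给我写个广告", "这是一条广告"]], "finished": true}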
Example curl call
curl --location --request POST 'http://hostname:8800/stream' \
--header 'Host: hostname:8800' \
--header 'User-Agent: python-requests/2.24.0' \
--header 'Accept: */*' \
--header 'Content-Type: application/json' \
--data-raw '{"query": "给我写个广告" ,"history": [] }'
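Besides curl, here is a minimal Python client sketch for consuming the stream (it assumes the service is reachable at http://127.0.0.1:8800 and uses only the requests package; the SSE "data:" lines are parsed by hand):

import json
import requests

def chat_stream(query, history=None, url="http://127.0.0.1:8800/stream"):
    # POST the query and iterate over the SSE response line by line
    payload = {"query": query, "history": history or []}
    with requests.post(url, json=payload, stream=True) as resp:
        resp.raise_for_status()
        for raw in resp.iter_lines():
            if not raw:
                continue  # blank lines separate SSE events
            line = raw.decode("utf-8")
            if not line.startswith("data:"):
                continue  # skip "event: delta" and other non-data lines
            event = json.loads(line[len("data:"):].strip())
            if event.get("finished"):
                # the final event carries the full response and updated history
                return event.get("response", ""), event.get("history", [])
            print(event.get("delta", ""), end="", flush=True)

if __name__ == "__main__":
    response, history = chat_stream("给我写个广告")
    print("\nFull response:", response)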
Summary
That is all for today. For more content on large language models, follow the WeChat official account: CV算法小屋.
Add me on WeChat (Lh1141755859) to join the discussion group; use the note "进群" (join group) in your request.