AI Dou Dizhu: Quantized Offline Deployment of Inference Models
Now that large language models are arriving in waves, chasing leaderboard scores alone is not very meaningful. What we really want to see is how well LLMs handle practical problems, and whether, without feedback from human experts, they can score and iterate on one another. For this experiment we deploy Baichuan-13B, gpt-3.5-turbo, Llama2-13B, and ChatGLM2-6B; ChatGLM2-6B already ships an OpenAI-style backend interface, so it can be called simply by swapping in a different openai.api_base. Without further ado, let's look at how to deploy them.
GPU memory is a major consideration when deploying LLMs. Most open-source models offer quantization options to reduce memory use at inference time, but 10-20 GB of VRAM is still typically required, so here we spin up an instance with 4x V100.
After trying several deployment approaches, I settled on wrapping each model behind an OpenAI-style backend API, with each model served on its own port. Each service can also be started with its own CUDA_VISIBLE_DEVICES, so instead of wrestling with a pile of models inside Streamlit, I can reach any model simply by setting openai.api_base.
This deployment style has several advantages: FastAPI + uvicorn can handle concurrent requests from multiple clients, the model itself can easily be spread across multiple GPUs and scaled out as traffic grows, and the services can directly back any other application that speaks the ChatGPT API.
You can try it out at http://39.98.39.50/.
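To make the pattern concrete, here is a minimal client sketch. It assumes the openai<1.0 Python client used elsewhere in this post, and the endpoints and model names that appear in the front-end code below (Llama2 on 39.98.39.50:50005, Baichuan on 47.92.248.156:50001, ChatGLM2 on 39.98.39.50:50001); switching models is just a matter of repointing openai.api_base:

```python
import openai

# One OpenAI-compatible backend per model (addresses/names follow the
# deployment described later in this post; adjust to your own services).
ENDPOINTS = {
    "llama2-2-7b-chat": "http://39.98.39.50:50005/v1",
    "baichuan-13b": "http://47.92.248.156:50001/v1",
    "chatglm2-6b": "http://39.98.39.50:50001/v1",
}

def ask(model_name: str, prompt: str) -> str:
    openai.api_base = ENDPOINTS[model_name]  # pick the backend by address/port
    openai.api_key = "none"                  # local services ignore the key
    result = openai.ChatCompletion.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        stream=False,
    )
    return result.choices[0].message.content

print(ask("baichuan-13b", "你好"))
```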
Baichuan-13B-int8
Download the model
git clone https://github.com/baichuan-inc/Baichuan-13B.git
cd Baichuan-13B
pip install -r requirements.txt
streamlit run web_demo.py
If access to Hugging Face fails, consider adding a proxy in web_demo.py:
import os
os.environ["HTTP_PROXY"] = "http://ga.dp.tech:8118"
os.environ["HTTPS_PROXY"] = "http://ga.dp.tech:8118"
GPU memory usage:
| Precision | GPU Mem (GB) |
| --- | --- |
| bf16 / fp16 | 26.0 |
| int8 | 15.8 |
| int4 | 9.7 |
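The API script later in this post loads Baichuan in fp16. To actually hit the int8 number in this table, the Baichuan-13B README documents an online-quantization path roughly like the sketch below (quantize(4) gives the int4 row); it could be swapped into the deployment script if memory is tight:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

# Load in fp16 first, then quantize to int8 on the fly (per the Baichuan-13B README).
tokenizer = AutoTokenizer.from_pretrained(
    "baichuan-inc/Baichuan-13B-Chat", use_fast=False, trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan-13B-Chat", torch_dtype=torch.float16, trust_remote_code=True
)
model = model.quantize(8).cuda()   # model.quantize(4) for the int4 footprint
model.generation_config = GenerationConfig.from_pretrained("baichuan-inc/Baichuan-13B-Chat")

messages = [{"role": "user", "content": "你好"}]
print(model.chat(tokenizer, messages))
```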
ChatGLM2-6B
Download and deploy the model
git clone https://github.com/THUDM/ChatGLM2-6B.git
cd ChatGLM2-6B
pip install -r requirements.txt
Change the default port: because of Bohrium's security-group policy, not every port is open to the public internet, so the service has to listen on a port between 50001 and 50005. Concretely, just change the port argument:
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device_map="cuda:1").cuda()
# For multi-GPU support, use the following lines instead of the line above, and set num_gpus to the number of GPUs you actually have
#from utils import load_model_on_gpus
#model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
#model.eval()
uvicorn.run(app, host='0.0.0.0', port=50001, workers=1)
API deployment
python openai_api.py
Llama2-13B
Request the Llama download token at https://ai.meta.com/resources/models-and-libraries/llama-downloads/
You will then receive an email from Meta; the link beginning with https://download.llamameta.net/?Policy= * is the download token used below.
git clone https://github.com/facebookresearch/llama.git
cd llama
(base) root@bohrium-157-1021597:~/llama# bash download.sh
Enter the URL from email: https://download.llamameta.net/*?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoiP1x1MDA4ZD9cdTAwMDFcdTAyODkiLCJSZXNvdXJjZSI6Imh0dHBzOlwvXC9kb3dubG9hZC5sbGFtYW1ldGEubmV0XC8qIiwiQ29uZGl0aW9uIjp7IkRhdGVMZXNzVGhhbiI6eyJBV1M6RXBvY2hUaW1lIjoxNjkwNDUwOTYxfX19XX0_&Signature=LLo%7EkUxZwBFeltZFhQ4RUzfcuJ-1zvUGqjYHnr0i%7EfbFKDdLcY%7EfWM2GBsAHLcUt0w9nDc2ecnzQPvYj81%7E1f03mwY7csPkzMB6W-5QNMF-dOnvNnJ7u8Rg%7EJrXqmB426trMVENyM4568bUh8Kdq6to4o6BTq3YJrTqBehWEcLIcEes4T6yrauQe78qjVuydUnXsBJSSVIWbjFGSgECQlBBONqbGn31ungHzvlESiRXJbXv25sbAPzIv2qzi1BxybBShIXroK-a0jA30yfHPb4Dbs-OGGaBN5tbVybKYpn-nWqRcGmnawT-HWDsM5JRRzIcXA98bhLCtB22YHEePCA__&Key-Pair-Id=K15QRJLYKIFSLZ
Enter the list of models to download without spaces (7B,13B,70B,7B-chat,13B-chat,70B-chat), or press Enter for all: 13B-chat
Test whether the model works; if the reply prints successfully, the setup is OK. Note that Llama.build initializes torch.distributed, which is why the RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT environment variables are set manually below (Meta's example scripts normally launch this via torchrun).
import threading
import ipywidgets as widgets
from IPython.display import display
from typing import Optional
from llama import Llama
import os
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
class CustomModel:
def __init__(self, ckpt_dir="/root/llama/llama-2-7b-chat", tokenizer_path="/root/llama/tokenizer.model", temperature=0.6, top_p=0.9, max_seq_len=600, max_batch_size=4, max_gen_len=None):
# store the generation parameters
self.ckpt_dir = ckpt_dir
self.tokenizer_path = tokenizer_path
self.temperature = temperature
self.top_p = top_p
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
self.max_gen_len = max_gen_len
# build the Llama generator with these parameters
self.generator = Llama.build(
ckpt_dir=self.ckpt_dir,
tokenizer_path=self.tokenizer_path,
max_seq_len=self.max_seq_len,
max_batch_size=self.max_batch_size,
)
def generate_response(self, input_text):
# use the generator to produce a reply
results = self.generator.chat_completion(
input_text,
max_gen_len = self.max_gen_len,
temperature = self.temperature,
top_p = self.top_p
)
return results
if __name__ == "__main__":
custom_model = CustomModel()
input_text = [[{"role": "user", "content": "What is the meaning of life?"}]]
response = custom_model.generate_response(input_text)
print("Generated response: ", response)
Backend logic
Directory layout
(base) root@iZ8vb6e2ofgbi3044gsrkaZ:~/backend# tree
.
├── baichuan_api.py
└── llama_api.py
0 directories, 2 files
baichuan_api.py
# coding=utf-8
# Implements an OpenAI-format API for Baichuan-13B-Chat, adapted from ChatGLM2-6B's openai_api.py. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python baichuan_api.py
# Visit http://localhost:50001/docs for the docs.
import time
import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
import os
os.environ["HTTP_PROXY"] = "http://ga.dp.tech:8118"
os.environ["HTTPS_PROXY"] = "http://ga.dp.tech:8118"
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"]
content: str
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = None
top_p: Optional[float] = None
max_length: Optional[int] = None
stream: Optional[bool] = False
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length"]
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]]
class ChatCompletionResponse(BaseModel):
model: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.get("/v1/models", response_model=ModelList)
async def list_models():
global model_args
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer
if request.messages[-1].role != "user":
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
query = prev_messages.pop(0).content + query
history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
history.append([prev_messages[i].content, prev_messages[i+1].content])
messages = [{"role": i.role, "content": i.content} for i in request.messages]
if request.stream:
generate = predict(messages, request.model)
return EventSourceResponse(generate, media_type="text/event-stream")
print(request.messages)
print(messages)
response = model.chat(tokenizer, messages)
print(response)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop"
)
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
async def predict(messages: List[Dict[str, str]], model_id: str):
global model, tokenizer
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
current_length = 0
for new_response in model.chat(tokenizer, messages):
print(new_response)
# if len(new_response) == current_length:
# continue
new_text = new_response[current_length:]
current_length = len(new_response)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=new_text),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield '[DONE]'
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(
"baichuan-inc/Baichuan-13B-Chat",
use_fast=False,
trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
"baichuan-inc/Baichuan-13B-Chat",
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(
"baichuan-inc/Baichuan-13B-Chat"
)
uvicorn.run(app, host='0.0.0.0', port=50001, workers=1)
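Once baichuan_api.py is running, a quick smoke test is to POST to the chat-completions route directly, without going through the openai client. A minimal sketch, assuming the service is reachable on port 50001 of the local machine:

```python
import requests

# Minimal smoke test against the OpenAI-compatible endpoint started above.
resp = requests.post(
    "http://127.0.0.1:50001/v1/chat/completions",
    json={
        "model": "baichuan-13b",
        "messages": [{"role": "user", "content": "你好,介绍一下你自己"}],
        "stream": False,
    },
    timeout=300,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```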
llama_api.py
# coding=utf-8
# Implements an OpenAI-format API for Llama-2-chat, adapted from ChatGLM2-6B's openai_api.py. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python llama_api.py
# Visit http://localhost:50005/docs for the docs.
import time
import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
from llama import Llama
import os
os.environ["HTTP_PROXY"] = "http://ga.dp.tech:8118"
os.environ["HTTPS_PROXY"] = "http://ga.dp.tech:8118"
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"]
content: str
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = None
top_p: Optional[float] = None
max_length: Optional[int] = None
stream: Optional[bool] = False
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length"]
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]]
class ChatCompletionResponse(BaseModel):
model: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.get("/v1/models", response_model=ModelList)
async def list_models():
global model_args
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer
if request.messages[-1].role != "user":
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
query = prev_messages.pop(0).content + query
history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
history.append([prev_messages[i].content, prev_messages[i+1].content])
messages = [{"role": i.role, "content": i.content} for i in request.messages]
if request.stream:
generate = predict(messages, request.model)
return EventSourceResponse(generate, media_type="text/event-stream")
print(request.messages)
print(messages)
response = model.chat_completion([messages])
print(response)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response[0]["generation"]["content"]),
finish_reason="stop"
)
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
async def predict(messages: List[Dict[str, str]], model_id: str):
global model, tokenizer
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
current_length = 0
for new_response in model.chat_completion( [messages]):
print(new_response)
# if len(new_response) == current_length:
# continue
new_text = new_response[current_length:]
current_length = len(new_response)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=new_text),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield '[DONE]'
if __name__ == "__main__":
model = Llama.build(
ckpt_dir="/root/llama/llama-2-7b-chat",
tokenizer_path="/root/llama/tokenizer.model",
max_seq_len=1024,
max_batch_size=4
)
uvicorn.run(app, host='0.0.0.0', port=50005, workers=1)
For ChatGLM2-6B we can use the openai_api.py it ships with directly; just remember to change the port.
# coding=utf-8
# Implements API for ChatGLM2-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:50001/docs for documents.
import time
import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"]
content: str
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = None
top_p: Optional[float] = None
max_length: Optional[int] = None
stream: Optional[bool] = False
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length"]
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]]
class ChatCompletionResponse(BaseModel):
model: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.get("/v1/models", response_model=ModelList)
async def list_models():
global model_args
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer
if request.messages[-1].role != "user":
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
query = prev_messages.pop(0).content + query
history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
history.append([prev_messages[i].content, prev_messages[i+1].content])
if request.stream:
generate = predict(query, history, request.model)
return EventSourceResponse(generate, media_type="text/event-stream")
response, _ = model.chat(tokenizer, query, history=history)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop"
)
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
async def predict(query: str, history: List[List[str]], model_id: str):
global model, tokenizer
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
current_length = 0
for new_response, _ in model.stream_chat(tokenizer, query, history):
if len(new_response) == current_length:
continue
new_text = new_response[current_length:]
current_length = len(new_response)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=new_text),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield '[DONE]'
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device_map="cuda:1").cuda()
# For multi-GPU support, use the following lines instead of the line above, and set num_gpus to the number of GPUs you actually have
#from utils import load_model_on_gpus
#model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
#model.eval()
uvicorn.run(app, host='0.0.0.0', port=50001, workers=1)
Also note that pydantic must be version 1.10.12; a newer version will raise errors.
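If you would rather fail fast than debug obscure validation errors, a small guard such as the following (not part of the original scripts) can be placed near the top of each API file:

```python
import pydantic

# The response models above were written against pydantic 1.10.x;
# bail out early if a pydantic 2.x installation slipped in.
if not pydantic.VERSION.startswith("1.10"):
    raise RuntimeError(
        f"pydantic {pydantic.VERSION} detected; install pydantic==1.10.12 for these scripts"
    )
```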
Streamlit front-end logic
import json
import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
import openai
import threading
import ipywidgets as widgets
from IPython.display import display
from typing import Optional
from llama import Llama
import os
class LlamaModel:
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
def generate_response(self, messages):
openai.api_base = "http://39.98.39.50:50005/v1"
openai.api_key = "none"
openai.api_type = "open_ai"
openai.api_version = None
result = openai.ChatCompletion.create(
model="llama2-2-7b-chat",
messages=messages,
stream=False
)
print(result)
return result.choices[0].message.content
class BaichuanModel:
def generate_response(self, messages):
openai.api_base = "http://47.92.248.156:50001/v1"
openai.api_key = "none"
openai.api_type = "open_ai"
openai.api_version = None
result = openai.ChatCompletion.create(
model="baichuan-13b",
messages=messages,
stream=False
)
print(result)
return result.choices[0].message.content
class ChatGLMModel:
def generate_response(self, messages):
import openai
openai.api_base = "http://39.98.39.50:50001/v1"
openai.api_key = "none"
openai.api_type = "open_ai"
openai.api_version = None
print(openai.api_key)
print(openai.api_type)
print(openai.api_version)
# openai.api_type = ""
# openai.api_version = ""
result = ""
for chunk in openai.ChatCompletion.create(
model="chatglm2-6b",
messages=messages,
stream=True
):
if hasattr(chunk.choices[0].delta, "content"):
result += chunk.choices[0].delta.content
# print(chunk.choices[0].delta.content, end="", flush=True)
print(result)
return result
class ChatGPTModel:
def generate_response(self, messages):
openai.api_base = "https://your-domain.openai.azure.com/"
openai.api_key = ""
openai.api_type = "azure"
openai.api_version = "2023-05-15"
result = ""
result = openai.ChatCompletion.create(
engine="gpt-35-turbo",
messages=messages,
stream=False
)
print(result)
return result.choices[0].message.content
st.set_page_config(page_title="Chat-AI斗地主")
st.title("Chat-AI斗地主")
@st.cache_resource
def init_model():
llama_model = LlamaModel()
baichuan_model = BaichuanModel()
chatglm_model = ChatGLMModel()
chatgpt_model = ChatGPTModel()
return llama_model, baichuan_model, chatglm_model, chatgpt_model
def clear_chat_history():
del st.session_state.messages_llama
del st.session_state.messages_baichuan
del st.session_state.messages_chatgpt
del st.session_state.messages_chatglm
del st.session_state.messages # the combined transcript used for front-end display
def init_chat_history():
with st.chat_message("llama", avatar='🚶♂️'):
st.markdown("您好,我是llama2,很高兴为您服务🥰")
with st.chat_message("baichuan", avatar='👀'):
st.markdown("您好,我是百川大模型,很高兴为您服务🥰")
with st.chat_message("chatgpt", avatar='✋'):
st.markdown("您好,我是gpt-3.5-turbo,很高兴为您服务🥰")
with st.chat_message("chatglm", avatar='😡'):
st.markdown("您好,chatglm-6B,很高兴为您服务🥰")
if "messages" in st.session_state:
for message in st.session_state.messages:
if message["role"] == "user":
avatar = '🧑💻'
elif message["role"] == "llama":
avatar = "🚶♂️"
elif message["role"] == "baichuan":
avatar = "👀"
elif message["role"] == "chatgpt":
avatar = "✋"
elif message["role"] == "chatglm":
avatar = "😡"
with st.chat_message(message["role"], avatar=avatar):
st.markdown(message["content"])
else:
st.session_state.messages = []
if "messages_baichuan" not in st.session_state:
st.session_state.messages_baichuan = []
if "messages_chatgpt" not in st.session_state:
st.session_state.messages_chatgpt = []
if "messages_chatglm" not in st.session_state:
st.session_state.messages_chatglm = []
if "messages_llama" not in st.session_state:
st.session_state.messages_llama = []
return st.session_state.messages, st.session_state.messages_baichuan, st.session_state.messages_llama, st.session_state.messages_chatglm, st.session_state.messages_chatgpt
def main():
llama_model, baichuan_model, chatglm_model, chatgpt_model = init_model()
messages, messages_baichuan, messages_llama, messages_chatglm, messages_chatgpt = init_chat_history()
if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):
with st.chat_message("user", avatar='🧑💻'):
st.markdown(prompt)
messages.append({"role": "user", "content": prompt})
messages_baichuan.append({"role": "user", "content": prompt})
messages_chatglm.append({"role": "user", "content": prompt})
messages_chatgpt.append({"role": "user", "content": prompt})
messages_llama.append({"role": "user", "content": prompt})
print(f"[user] {prompt}", flush=True)
with st.chat_message("llama", avatar='🚶♂️'):
placeholder = st.empty()
print(messages_llama)
response = llama_model.generate_response(messages_llama)
placeholder.markdown(response)
messages_llama.append({"role": "assistant", "content": response})
messages.append({"role": "llama", "content": response})
with st.chat_message("baichuan", avatar='👀'):
placeholder = st.empty()
response = baichuan_model.generate_response(messages_baichuan)
placeholder.markdown(response)
messages_baichuan.append({"role": "assistant", "content": response})
messages.append({"role": "baichuan", "content": response})
with st.chat_message("chatglm", avatar='😡'):
placeholder = st.empty()
response = chatglm_model.generate_response(messages_chatglm)
placeholder.markdown(response)
messages_chatglm.append({"role": "assistant", "content": response})
messages.append({"role": "chatglm", "content": response})
with st.chat_message("chatgpt", avatar='✋'):
placeholder = st.empty()
response = chatgpt_model.generate_response(messages_chatgpt)
placeholder.markdown(response)
messages_chatgpt.append({"role": "assistant", "content": response})
messages.append({"role": "chatgpt", "content": response})
print(json.dumps(messages, ensure_ascii=False), flush=True)
st.button("清空对话", on_click=clear_chat_history)
if __name__ == "__main__":
# chatgpt_model = ChatGPTModel()
# messages=[
# {"role": "user", "content": "如何评价文化大革命"}
# ]
# chatgpt_model.generate_response(messages)
# messages = [
# {'role': 'user', 'content': '你好'},
# {'role': 'assistant', 'content': ' Hello! 你好 (nǐ hǎo) means "Hello" in Chinese. It\'s great to meet you! I\'m here to help answer any questions you may have, while being safe and respectful. Please feel free to ask me anything, and I\'ll do my best to provide helpful and accurate information. If a question doesn\'t make sense or is not factually coherent, I\'ll explain why instead of answering something not correct. And if I don\'t know the answer to a question, I\'ll let you know instead of sharing false information. Please go ahead and ask me anything!'},
# {'role': 'user', 'content': '你是谁'}
# ]
# llama_model = LlamaModel()
# llama_model.generate_response(messages)
# chatgpt_model = ChatGPTModel()
# chatgpt_model.generate_response(messages)
# chatglm_model = ChatGLMModel()
# chatglm_model.generate_response(messages)
# baichuan_model = BaichuanModel()
# baichuan_model.generate_response(messages)
main()
Next steps (TODO)
- Implement streaming responses
- Only single-turn conversation is really supported for now; once history is carried along, Llama quickly runs out of GPU memory. Options for the future include deploying on an 80 GB A100 or using a quantized model to reduce memory usage; a simple history-truncation stopgap is sketched below.
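For the history-related memory issue, a stopgap that needs no extra hardware is to truncate the history sent to the backend. A minimal sketch with an illustrative, hypothetical helper (the window size is arbitrary):

```python
def truncate_history(messages, max_turns=3):
    """Keep the system prompt (if any) plus only the last few turns, so the
    prompt sent to the backend stays within Llama's context/VRAM budget."""
    system = [m for m in messages if m["role"] == "system"][:1]
    rest = [m for m in messages if m["role"] != "system"]
    return system + rest[-(2 * max_turns + 1):]

# e.g. in the Streamlit app, before calling a backend:
# response = llama_model.generate_response(truncate_history(messages_llama))
```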
yufeng
Roger
dingzh@dp.tech