ChatGLM2-6B/web_demo2.py
2023-06-27 14:40:44 +08:00

76 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from transformers import AutoModel, AutoTokenizer
import streamlit as st
from streamlit_chat import message
st.set_page_config(
page_title="ChatGLM2-6b 演示",
page_icon=":robot:",
layout='wide'
)
@st.cache_resource
def get_model():
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
# 多显卡支持使用下面两行代替上面一行将num_gpus改为你实际的显卡数量
# from utils import load_model_on_gpus
# model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
model = model.eval()
return tokenizer, model
MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2
def predict(input, max_length, top_p, temperature, history=None):
tokenizer, model = get_model()
if history is None:
history = []
with container:
if len(history) > 0:
if len(history)>MAX_BOXES:
history = history[-MAX_TURNS:]
for i, (query, response) in enumerate(history):
message(query, avatar_style="big-smile", key=str(i) + "_user")
message(response, avatar_style="bottts", key=str(i))
message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
st.write("AI正在回复:")
with st.empty():
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
temperature=temperature):
query, response = history[-1]
st.write(response)
return history
container = st.container()
# create a prompt text for the text generation
prompt_text = st.text_area(label="用户命令输入",
height = 100,
placeholder="请在这儿输入您的命令")
max_length = st.sidebar.slider(
'max_length', 0, 32768, 8192, step=1
)
top_p = st.sidebar.slider(
'top_p', 0.0, 1.0, 0.8, step=0.01
)
temperature = st.sidebar.slider(
'temperature', 0.0, 1.0, 0.95, step=0.01
)
if 'state' not in st.session_state:
st.session_state['state'] = []
if st.button("发送", key="predict"):
with st.spinner("AI正在思考请稍等........"):
# text generation
st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])