ChatGLM2-6B/web_demo2.py

71 lines
2.4 KiB
Python
Raw Normal View History

2023-06-26 17:07:24 +08:00
from transformers import AutoModel, AutoTokenizer
import streamlit as st
# Browser-tab title/icon and a full-width layout for the chat page.
# Issued before any other Streamlit element is rendered, as Streamlit's
# documentation asks for st.set_page_config.
st.set_page_config(
    layout="wide",
    page_icon=":robot:",
    page_title="ChatGLM2-6b 演示",
)
@st.cache_resource
def get_model(model_path: str = "THUDM/chatglm2-6b"):
    """Load the ChatGLM2 tokenizer and model once and reuse them across reruns.

    ``st.cache_resource`` caches the returned objects, keyed on the function
    arguments. The slow weight loading therefore happens only once rather than
    on every Streamlit script rerun.

    Args:
        model_path: Hugging Face Hub repo id or local directory handed to
            ``from_pretrained``. Defaults to the public ``THUDM/chatglm2-6b``
            checkpoint; pass a local path to load a downloaded copy instead.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been moved to CUDA with
        ``.cuda()`` and put in inference mode with ``.eval()``.
    """
    # trust_remote_code lets transformers run the custom model/tokenizer code
    # stored alongside the checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda()
    # For multi-GPU support, replace the line above with the two lines below
    # and set num_gpus to the number of GPUs you actually have.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus(model_path, num_gpus=2)
    model = model.eval()
    return tokenizer, model
2023-07-07 16:09:00 +08:00
# Cached by st.cache_resource inside get_model, so this is cheap on reruns.
tokenizer, model = get_model()

st.title("ChatGLM2-6B")

# Sidebar generation controls; the chosen values are forwarded to
# model.stream_chat when the user sends a message.
max_length = st.sidebar.slider(
    'max_length', 0, 32768, 8192, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.8, step=0.01
)
# Lower bound is 0.01, not 0.0: generation samples with this temperature, and
# transformers' TemperatureLogitsWarper raises ValueError ("`temperature` has
# to be a strictly positive float") for 0.0, which would crash the request.
temperature = st.sidebar.slider(
    'temperature', 0.01, 1.0, 0.8, step=0.01
)

# Per-session chat state that survives Streamlit reruns:
# - history: list of (query, response) pairs from earlier turns
# - past_key_values: the cache returned by stream_chat for the last turn,
#   fed back into the next call (None before the first turn)
if 'history' not in st.session_state:
    st.session_state.history = []
if 'past_key_values' not in st.session_state:
    st.session_state.past_key_values = None
2023-06-26 17:07:24 +08:00
2023-07-07 16:09:00 +08:00
# Redraw the conversation so far: Streamlit reruns the whole script on every
# interaction, so earlier turns must be re-rendered from session state.
for query, response in st.session_state.history:
    with st.chat_message(name="user", avatar="user"):
        st.markdown(query)
    with st.chat_message(name="assistant", avatar="assistant"):
        st.markdown(response)

# Empty slots for the turn about to be generated; filled in below.
with st.chat_message(name="user", avatar="user"):
    input_placeholder = st.empty()
with st.chat_message(name="assistant", avatar="assistant"):
    message_placeholder = st.empty()

# Prompt box (label: "user command input", placeholder: "enter your command here").
prompt_text = st.text_area(label="用户命令输入",
                           height=100,
                           placeholder="请在这儿输入您的命令")

# "Send" button.
button = st.button("发送", key="predict")

# Only run the model when there is something to send; an empty or
# whitespace-only prompt would otherwise be generated on and stored in history.
if button and prompt_text.strip():
    input_placeholder.markdown(prompt_text)
    history, past_key_values = st.session_state.history, st.session_state.past_key_values
    # stream_chat yields (partial response, updated history, updated cache)
    # tuples as generation proceeds; redraw the assistant bubble each time.
    for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history,
                                                                past_key_values=past_key_values,
                                                                max_length=max_length, top_p=top_p,
                                                                temperature=temperature,
                                                                return_past_key_values=True):
        message_placeholder.markdown(response)

    # Persist the final history and cache once streaming has finished, so
    # the next turn continues from this one.
    st.session_state.history = history
    st.session_state.past_key_values = past_key_values