ChatGLM2-6B/web_demo2.py
2023-07-07 16:09:05 +08:00

71 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from transformers import AutoModel, AutoTokenizer
import streamlit as st
# Browser-tab title, tab icon and full-width layout for the demo page.
# Kept ahead of every other st.* call in this script.
_page_settings = {
    "page_title": "ChatGLM2-6b 演示",
    "page_icon": ":robot:",
    "layout": "wide",
}
st.set_page_config(**_page_settings)
@st.cache_resource
def get_model(model_path="THUDM/chatglm2-6b"):
    """Load the ChatGLM2 tokenizer and model for the demo.

    Decorated with ``st.cache_resource`` so the expensive load result is
    memoized and reused across Streamlit reruns instead of reloading the
    weights on every interaction.

    Args:
        model_path: Hugging Face Hub id or local checkpoint directory.
            Defaults to ``"THUDM/chatglm2-6b"``, the previously hard-coded value.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is moved to the GPU with
        ``.cuda()`` and put in evaluation mode with ``.eval()``.
    """
    # trust_remote_code=True allows transformers to execute Python code that
    # ships with the checkpoint (presumably required for ChatGLM2's custom
    # classes) -- only point model_path at checkpoints you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda()
    # Multi-GPU: replace the line above with the two lines below and set
    # num_gpus to the number of GPUs you actually have.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus(model_path, num_gpus=2)
    model = model.eval()
    return tokenizer, model
tokenizer, model = get_model()
st.title("ChatGLM2-6B")

# Generation controls shown in the sidebar; their values are forwarded to
# model.stream_chat when a prompt is sent.
max_length = st.sidebar.slider(
    'max_length', 0, 32768, 8192, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.8, step=0.01
)
# Lower bound is 0.01 rather than 0.0: when sampling, transformers'
# TemperatureLogitsWarper raises ValueError for a non-positive temperature
# ("`temperature` has to be a strictly positive float"), which would crash
# the request as soon as the user dragged the slider to zero.
temperature = st.sidebar.slider(
    'temperature', 0.01, 1.0, 0.8, step=0.01
)
# Seed per-session chat state on the first run only: the list of
# (query, response) turns and the model's past_key_values (None until the
# first reply has been generated).
for _state_key, _initial_value in (("history", []), ("past_key_values", None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Redraw every stored (query, response) turn so the transcript appears above
# the input area. The turn index was never used, so no enumerate() is needed.
for query, response in st.session_state.history:
    with st.chat_message(name="user", avatar="user"):
        st.markdown(query)
    with st.chat_message(name="assistant", avatar="assistant"):
        st.markdown(response)
def _empty_bubble(role):
    """Open a chat bubble for *role* and return an empty placeholder inside it."""
    with st.chat_message(name=role, avatar=role):
        return st.empty()


# Slots for the turn about to be submitted: the user's text and the
# streamed assistant reply are written into these after the button click.
input_placeholder = _empty_bubble("user")
message_placeholder = _empty_bubble("assistant")

prompt_text = st.text_area(
    placeholder="请在这儿输入您的命令",
    height=100,
    label="用户命令输入",
)
button = st.button("发送", key="predict")
if button:
    if not prompt_text.strip():
        # Don't send a blank/whitespace-only prompt to the model (and don't
        # persist the resulting turn); ask the user for input instead.
        st.warning("请输入内容后再发送")
    else:
        input_placeholder.markdown(prompt_text)
        history = st.session_state.history
        past_key_values = st.session_state.past_key_values
        # Each streamed step yields the reply so far plus the updated history
        # and past_key_values; the placeholder is overwritten every step, so
        # the reply appears to grow in place.
        for response, history, past_key_values in model.stream_chat(
            tokenizer,
            prompt_text,
            history,
            past_key_values=past_key_values,
            max_length=max_length,
            top_p=top_p,
            temperature=temperature,
            return_past_key_values=True,
        ):
            message_placeholder.markdown(response)
        # Persist only after streaming completes so the next rerun replays the
        # finished turn and continues from the returned past_key_values.
        st.session_state.history = history
        st.session_state.past_key_values = past_key_values