diff --git a/README.md b/README.md index 80b5ec5..e5b8fae 100644 --- a/README.md +++ b/README.md @@ -137,7 +137,11 @@ git clone https://github.com/THUDM/ChatGLM2-6B cd ChatGLM2-6B ``` -然后使用 pip 安装依赖:`pip install -r requirements.txt`,其中 `transformers` 库版本推荐为 `4.30.2`,`torch` 推荐使用 2.0 以上的版本,以获得最佳的推理性能。 +然后使用 pip 安装依赖: +``` +pip install -r requirements.txt +``` +其中 `transformers` 库版本推荐为 `4.30.2`,`torch` 推荐使用 2.0 及以上的版本,以获得最佳的推理性能。 ### 代码调用 @@ -188,23 +192,17 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm2-6b ![web-demo](resources/web-demo.gif) -首先安装 Gradio:`pip install gradio`,然后运行仓库中的 [web_demo.py](web_demo.py): - +可以通过以下命令启动基于 Streamlit 的网页版 demo: ```shell -python web_demo.py +streamlit run web_demo2.py ``` 程序会运行一个 Web Server,并输出地址。在浏览器中打开输出的地址即可使用。 -> 默认使用了 `share=False` 启动,不会生成公网链接。如有需要公网访问的需求,可以修改为 `share=True` 启动。 -> -感谢 [@AdamBear](https://github.com/AdamBear) 实现了基于 Streamlit 的网页版 Demo `web_demo2.py`。使用时首先需要额外安装以下依赖: + +[web_demo.py](./web_demo.py) 中提供了旧版基于 Gradio 的 web demo,可以通过如下命令运行: ```shell -pip install streamlit streamlit-chat -``` -然后通过以下命令运行: -```shell -streamlit run web_demo2.py +python web_demo.py ``` 经测试,如果输入的 prompt 较长的话,使用基于 Streamlit 的网页版 Demo 会更流畅。 diff --git a/requirements.txt b/requirements.txt index c5c9158..265b8eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ gradio mdtex2html sentencepiece accelerate -sse-starlette \ No newline at end of file +sse-starlette +streamlit>=1.24.0 \ No newline at end of file diff --git a/resources/web-demo.gif b/resources/web-demo.gif index 5394c8a..28eb1ad 100644 Binary files a/resources/web-demo.gif and b/resources/web-demo.gif differ diff --git a/web_demo2.py b/web_demo2.py index 6c66308..203cbdc 100644 --- a/web_demo2.py +++ b/web_demo2.py @@ -1,6 +1,5 @@ from transformers import AutoModel, AutoTokenizer import streamlit as st -from streamlit_chat import message st.set_page_config( @@ -21,40 +20,9 @@ def get_model(): return tokenizer, model -MAX_TURNS = 20 -MAX_BOXES = MAX_TURNS * 2 +tokenizer, model = get_model() - -def predict(input, max_length, top_p, temperature, history=None): - tokenizer, model = get_model() - if history is None: - history = [] - - with container: - if len(history) > 0: - if len(history)>MAX_BOXES: - history = history[-MAX_TURNS:] - for i, (query, response) in enumerate(history): - message(query, avatar_style="big-smile", key=str(i) + "_user") - message(response, avatar_style="bottts", key=str(i)) - - message(input, avatar_style="big-smile", key=str(len(history)) + "_user") - st.write("AI正在回复:") - with st.empty(): - for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, - temperature=temperature): - query, response = history[-1] - st.write(response) - - return history - - -container = st.container() - -# create a prompt text for the text generation -prompt_text = st.text_area(label="用户命令输入", - height = 100, - placeholder="请在这儿输入您的命令") +st.title("ChatGLM2-6B") max_length = st.sidebar.slider( 'max_length', 0, 32768, 8192, step=1 @@ -63,13 +31,40 @@ top_p = st.sidebar.slider( 'top_p', 0.0, 1.0, 0.8, step=0.01 ) temperature = st.sidebar.slider( - 'temperature', 0.0, 1.0, 0.95, step=0.01 + 'temperature', 0.0, 1.0, 0.8, step=0.01 ) -if 'state' not in st.session_state: - st.session_state['state'] = [] +if 'history' not in st.session_state: + st.session_state.history = [] -if st.button("发送", key="predict"): - with st.spinner("AI正在思考,请稍等........"): - # text generation - st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"]) +if 'past_key_values' not in st.session_state: + st.session_state.past_key_values = None + +for i, (query, response) in enumerate(st.session_state.history): + with st.chat_message(name="user", avatar="user"): + st.markdown(query) + with st.chat_message(name="assistant", avatar="assistant"): + st.markdown(response) +with st.chat_message(name="user", avatar="user"): + input_placeholder = st.empty() +with st.chat_message(name="assistant", avatar="assistant"): + message_placeholder = st.empty() + +prompt_text = st.text_area(label="用户命令输入", + height=100, + placeholder="请在这儿输入您的命令") + +button = st.button("发送", key="predict") + +if button: + input_placeholder.markdown(prompt_text) + history, past_key_values = st.session_state.history, st.session_state.past_key_values + for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history, + past_key_values=past_key_values, + max_length=max_length, top_p=top_p, + temperature=temperature, + return_past_key_values=True): + message_placeholder.markdown(response) + + st.session_state.history = history + st.session_state.past_key_values = past_key_values