Merge pull request #152 from mougua/main

fix: openai_api 的 stream api,服务端全部生成文本后客户端才一次性收到
This commit is contained in:
Zhengxiao Du 2023-07-04 12:07:01 +08:00 committed by GitHub
commit b99e3d74c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 6 deletions

View File

@ -11,9 +11,9 @@ from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from starlette.responses import StreamingResponse
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
@asynccontextmanager @asynccontextmanager
@ -114,7 +114,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
if request.stream: if request.stream:
generate = predict(query, history, request.model) generate = predict(query, history, request.model)
return StreamingResponse(generate, media_type="text/event-stream") return EventSourceResponse(generate, media_type="text/event-stream")
response, _ = model.chat(tokenizer, query, history=history) response, _ = model.chat(tokenizer, query, history=history)
choice_data = ChatCompletionResponseChoice( choice_data = ChatCompletionResponseChoice(
@ -135,7 +135,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason=None finish_reason=None
) )
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False)) yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
current_length = 0 current_length = 0
@ -152,7 +152,8 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason=None finish_reason=None
) )
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False)) yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=0, index=0,
@ -160,7 +161,9 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason="stop" finish_reason="stop"
) )
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False)) yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield '[DONE]'
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -5,4 +5,5 @@ torch>=2.0
gradio gradio
mdtex2html mdtex2html
sentencepiece sentencepiece
accelerate accelerate
sse-starlette