Merge pull request #152 from mougua/main

fix: openai_api 的 stream api,服务端全部生成文本后客户端才一次性收到
This commit is contained in:
Zhengxiao Du 2023-07-04 12:07:01 +08:00 committed by GitHub
commit b99e3d74c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 6 deletions

View File

@ -11,9 +11,9 @@ from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from starlette.responses import StreamingResponse
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
@asynccontextmanager
@ -114,7 +114,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
if request.stream:
generate = predict(query, history, request.model)
return StreamingResponse(generate, media_type="text/event-stream")
return EventSourceResponse(generate, media_type="text/event-stream")
response, _ = model.chat(tokenizer, query, history=history)
choice_data = ChatCompletionResponseChoice(
@ -135,7 +135,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
current_length = 0
@ -152,7 +152,8 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
@ -160,7 +161,9 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield '[DONE]'
if __name__ == "__main__":

View File

@ -6,3 +6,4 @@ gradio
mdtex2html
sentencepiece
accelerate
sse-starlette