Merge pull request #152 from mougua/main
fix: openai_api 的 stream api,服务端全部生成文本后客户端才一次性收到
This commit is contained in:
commit
b99e3d74c9
@ -11,9 +11,9 @@ from pydantic import BaseModel, Field
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from starlette.responses import StreamingResponse
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@ -114,7 +114,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||
|
||||
if request.stream:
|
||||
generate = predict(query, history, request.model)
|
||||
return StreamingResponse(generate, media_type="text/event-stream")
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
response, _ = model.chat(tokenizer, query, history=history)
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
@ -135,7 +135,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
||||
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
||||
|
||||
current_length = 0
|
||||
|
||||
@ -152,7 +152,8 @@ async def predict(query: str, history: List[List[str]], model_id: str):
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
||||
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
||||
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
@ -160,7 +161,9 @@ async def predict(query: str, history: List[List[str]], model_id: str):
|
||||
finish_reason="stop"
|
||||
)
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
||||
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
||||
yield '[DONE]'
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -6,3 +6,4 @@ gradio
|
||||
mdtex2html
|
||||
sentencepiece
|
||||
accelerate
|
||||
sse-starlette
|
Loading…
Reference in New Issue
Block a user