Merge pull request #152 from mougua/main
fix: openai_api stream API: the client only received the generated text all at once, after the server had finished generating it
commit b99e3d74c9
openai_api.py
@@ -11,9 +11,9 @@ from pydantic import BaseModel, Field
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
-from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
 from transformers import AutoTokenizer, AutoModel
+from sse_starlette.sse import ServerSentEvent, EventSourceResponse
 
 
 @asynccontextmanager
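Note on the import change: Starlette's generic StreamingResponse is replaced by sse-starlette's EventSourceResponse, which implements the Server-Sent Events protocol itself (the data: framing, the text/event-stream headers, and flushing each event as it is produced) instead of leaving that to the generator. A minimal sketch, not part of this PR, of how EventSourceResponse is typically used (route name and payloads are made up for illustration):

# Minimal sketch, not from this PR: each string yielded by the generator
# becomes the data field of one Server-Sent Event and is flushed to the
# client as soon as it is produced.
import asyncio
from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

@app.get("/demo-stream")            # hypothetical route, for illustration only
async def demo_stream():
    async def event_generator():
        for i in range(3):
            yield f"chunk {i}"      # reaches the client as: data: chunk 0, etc.
            await asyncio.sleep(0.5)
    return EventSourceResponse(event_generator())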
@@ -114,7 +114,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if request.stream:
         generate = predict(query, history, request.model)
-        return StreamingResponse(generate, media_type="text/event-stream")
+        return EventSourceResponse(generate, media_type="text/event-stream")
 
     response, _ = model.chat(tokenizer, query, history=history)
     choice_data = ChatCompletionResponseChoice(
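This is the core of the fix described in the commit title: with StreamingResponse the hand-framed events were, in practice, only reaching the client after the whole answer had been generated, whereas EventSourceResponse sends each event as soon as predict() yields it and sets the usual SSE response headers. A rough way to confirm that chunks now arrive incrementally, sketched with httpx; the /v1/chat/completions route, port 8000 and model id "chatglm2-6b" are assumptions based on the script's defaults:

# Sketch for manual verification only; endpoint, port and model id are
# assumptions, adjust to match how the server is actually launched.
import httpx

payload = {
    "model": "chatglm2-6b",
    "messages": [{"role": "user", "content": "你好"}],
    "stream": True,
}
with httpx.stream("POST", "http://127.0.0.1:8000/v1/chat/completions",
                  json=payload, timeout=None) as resp:
    for line in resp.iter_lines():
        if line:
            print(line)   # each "data: {...}" line should appear as it is generated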
@@ -135,7 +135,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason=None
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
     current_length = 0
 
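Since EventSourceResponse writes the SSE envelope itself, the generator now yields only the JSON payload; keeping the old hand-written "data: ...\n\n" wrapper would have produced doubly framed events. Roughly, for a single chunk (the payload below is invented, not real model output):

# Illustration only; the payload string is made up.
payload = '{"model": "chatglm2-6b", "choices": [{"delta": {"content": "你"}}], "object": "chat.completion.chunk"}'

# Old generator framed the event by hand before handing it to StreamingResponse:
#     yield "data: {}\n\n".format(payload)
# New generator yields the bare payload; EventSourceResponse emits the frame,
# so the client still receives one event per chunk, roughly:
#     data: {"model": "chatglm2-6b", ...}
#     (blank line)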
@@ -152,7 +152,8 @@ async def predict(query: str, history: List[List[str]], model_id: str):
             finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
+
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
@@ -160,7 +161,9 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason="stop"
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield '[DONE]'
+
 
 
 if __name__ == "__main__":
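The added yield '[DONE]' makes the final event match the data: [DONE] sentinel that OpenAI's streaming API uses to signal the end of a completion, so OpenAI-compatible clients stop reading cleanly. A hedged sketch of consuming the fixed endpoint with the legacy openai<1.0 client; base URL, API key handling and model id are assumptions:

# Sketch, assuming the server runs on localhost:8000 and reports model id
# "chatglm2-6b"; the API key is a placeholder.
import openai

openai.api_base = "http://127.0.0.1:8000/v1"
openai.api_key = "none"

for chunk in openai.ChatCompletion.create(
    model="chatglm2-6b",
    messages=[{"role": "user", "content": "你好"}],
    stream=True,
):
    delta = chunk.choices[0].delta
    print(delta.get("content", ""), end="", flush=True)
print()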
requirements.txt
@@ -5,4 +5,5 @@ torch>=2.0
 gradio
 mdtex2html
 sentencepiece
 accelerate
+sse-starlette