Update openai_api.py

This commit is contained in:
hoshi-hiyouga 2023-06-26 18:13:21 +08:00 committed by GitHub
parent 9edf43fa53
commit 892691c3b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,6 @@
# coding=utf-8 # coding=utf-8
# Implements API for ChatGLM2-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat) # Implements API for ChatGLM2-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py # Usage: python api_demo.py
# Visit http://localhost:8000/docs for documents. # Visit http://localhost:8000/docs for documents.
@ -26,6 +26,21 @@ async def lifespan(app: FastAPI): # collects GPU memory
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"] role: Literal["user", "assistant", "system"]
content: str content: str
@ -64,6 +79,13 @@ class ChatCompletionResponse(BaseModel):
created: Optional[int] = Field(default_factory=lambda: int(time.time())) created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.get("/v1/models", response_model=ModelList)
async def list_models():
global model_args
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest): async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer global model, tokenizer