lmdeploy/serve/openai/api_server.py (103 additions, 2 deletions)
@@ -38,8 +38,9 @@
CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage,
EmbeddingsRequest, EncodeRequest, EncodeResponse, ErrorResponse,
GenerateRequest, LogProbs, ModelCard, ModelList, ModelPermission,
PoolingRequest, PoolingResponse, TopLogprob, UpdateParamsRequest, UsageInfo)
GenerateReqInput, GenerateReqMetaOutput, GenerateReqOutput, GenerateRequest,
LogProbs, ModelCard, ModelList, ModelPermission, PoolingRequest,
PoolingResponse, TopLogprob, UpdateParamsRequest, UsageInfo)
from lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager
from lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager
from lmdeploy.tokenizer import DetokenizeState, Tokenizer
@@ -896,6 +897,106 @@ async def _inner_call(i, generator):
return response


@router.post('/generate', dependencies=[Depends(check_api_key)])
async def generate(request: GenerateReqInput, raw_request: Request = None):
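    """Generate text for the given prompt or input_ids (the non-OpenAI `/generate` interface)."""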

if request.session_id == -1:
VariableInterface.session_id += 1
request.session_id = VariableInterface.session_id
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0:
return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id `{request.session_id}` is occupied.')

gen_config = GenerationConfig(
max_new_tokens=request.max_tokens,
do_sample=True,
logprobs=1 if request.return_logprob else None,
top_k=request.top_k,
top_p=request.top_p,
min_p=request.min_p,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
stop_words=request.stop,
stop_token_ids=request.stop_token_ids,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=request.spaces_between_special_tokens,
include_stop_str_in_output=request.include_stop_str_in_output,
)

result_generator = VariableInterface.async_engine.generate(
messages=request.prompt,
session_id=request.session_id,
input_ids=request.input_ids,
gen_config=gen_config,
stream_response=True, # always use stream to enable batching
sequence_start=True,
sequence_end=True,
do_preprocess=False,
)

def create_finish_reason(finish_reason):
# TODO: add detail info
if not finish_reason:
return None
if finish_reason == 'length':
return dict(type='length')
if finish_reason == 'stop':
return dict(type='stop')
return dict(type='abort')

def create_generate_response_json(res, text, output_ids, logprobs, finish_reason):
meta = GenerateReqMetaOutput(finish_reason=create_finish_reason(finish_reason),
output_token_logprobs=logprobs or None,
prompt_tokens=res.input_token_len,
completion_tokens=res.generate_token_len)
response = GenerateReqOutput(text=text, output_ids=output_ids, meta_info=meta)
return response.model_dump_json()

async def generate_stream_generator():
async for res in result_generator:
text = res.response or ''
output_ids = res.token_ids
logprobs = []
if res.logprobs:
for tok, tok_logprobs in zip(res.token_ids, res.logprobs):
logprobs.append((tok_logprobs[tok], tok))
response_json = create_generate_response_json(res, text, output_ids, logprobs, res.finish_reason)
yield f'data: {response_json}\n\n'
yield 'data: [DONE]\n\n'

if request.stream:
return StreamingResponse(generate_stream_generator(), media_type='text/event-stream')

response = None

async def _inner_call():
text = ''
output_ids = []
logprobs = []
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await VariableInterface.async_engine.stop_session(request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')
text += res.response or ''
output_ids.extend(res.token_ids or [])
if res.logprobs:
for tok, tok_logprobs in zip(res.token_ids, res.logprobs):
logprobs.append((tok_logprobs[tok], tok))
nonlocal response
meta = GenerateReqMetaOutput(finish_reason=create_finish_reason(res.finish_reason),
output_token_logprobs=logprobs or None,
prompt_tokens=res.input_token_len,
completion_tokens=res.generate_token_len)
response = GenerateReqOutput(text=text, output_ids=output_ids, meta_info=meta)

await _inner_call()
return response


@router.post('/v1/embeddings', tags=['unsupported'])
async def create_embeddings(request: EmbeddingsRequest, raw_request: Request = None):
"""Creates embeddings for the text."""
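For reference, a minimal client sketch for the new /generate route above: it posts a GenerateReqInput-style payload with stream=True and consumes the `data: ...` SSE frames produced by generate_stream_generator. The base URL (the api_server default port 23333 is assumed here), the prompt, and max_tokens are illustrative values, not part of this diff.

import json

import requests

url = 'http://0.0.0.0:23333/generate'  # assumed default api_server address
payload = {
    'prompt': 'Hello, my name is',
    'max_tokens': 32,
    'stream': True,
    'return_logprob': True,
}

with requests.post(url, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # Each event is a single line of the form "data: <json>";
        # the stream ends with "data: [DONE]".
        if not line or not line.startswith('data: '):
            continue
        data = line[len('data: '):]
        if data == '[DONE]':
            break
        chunk = json.loads(data)
        print(chunk['text'], end='', flush=True)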
lmdeploy/serve/openai/protocol.py (35 additions, 0 deletions)
@@ -435,3 +435,38 @@ class UpdateParamsRequest(BaseModel):
"""Update weights request."""
serialized_named_tensors: Union[str, List[str], Dict]
finished: bool = False


# /generate input
class GenerateReqInput(BaseModel):
session_id: Optional[int] = -1
prompt: Optional[str] = None
input_ids: Optional[List[int]] = None
return_logprob: Optional[bool] = None
max_tokens: int = 128
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
stream: Optional[bool] = False
temperature: float = 1.0
repetition_penalty: Optional[float] = 1.0
ignore_eos: Optional[bool] = False
top_p: float = 1.0
top_k: int = 0
min_p: float = 0.0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
include_stop_str_in_output: Optional[bool] = False


class GenerateReqMetaOutput(BaseModel):
prompt_tokens: Optional[int] = None
completion_tokens: Optional[int] = None
finish_reason: Optional[Dict[str, Any]] = None
output_token_logprobs: Optional[List[tuple[float, int]]] = None # (logprob, token_id)


# /generate output
class GenerateReqOutput(BaseModel):
text: str
output_ids: List[int]
meta_info: GenerateReqMetaOutput
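
A small round-trip sketch of these models, assuming pydantic v2 (implied by the model_dump_json call in api_server.py above); the field values and the token id are made-up examples:

from lmdeploy.serve.openai.protocol import (GenerateReqInput, GenerateReqMetaOutput, GenerateReqOutput)

# Validate an incoming /generate payload; omitted fields fall back to the
# defaults declared above (session_id=-1, stream=False, ...).
req = GenerateReqInput.model_validate({'prompt': 'Hi', 'max_tokens': 16, 'return_logprob': True})
assert req.session_id == -1 and req.stream is False

# Build one response frame the way the server does and serialize it for the
# SSE stream; output_token_logprobs pairs are (logprob, token_id).
meta = GenerateReqMetaOutput(finish_reason=dict(type='stop'),
                             prompt_tokens=3,
                             completion_tokens=1,
                             output_token_logprobs=[(-0.12, 9906)])
out = GenerateReqOutput(text='Hello', output_ids=[9906], meta_info=meta)
print(out.model_dump_json())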