llama_stack/providers/remote/inference/bedrock/bedrock.py: 27 changes (15 additions, 12 deletions)
@@ -87,9 +87,9 @@ def _to_inference_profile_id(model_id: str, region: str = None) -> str:
 
 class BedrockInferenceAdapter(
     ModelRegistryHelper,
-    Inference,
     OpenAIChatCompletionToLlamaStackMixin,
     OpenAICompletionToLlamaStackMixin,
+    Inference,
 ):
     def __init__(self, config: BedrockConfig) -> None:
         ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
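
Moving Inference to the end of the base list changes Python's method resolution order so that the two OpenAI-compat mixins are consulted before Inference when a method is defined in both. A minimal sketch of that effect, using hypothetical class names rather than the real adapter:

# Hypothetical illustration of why mixin order matters in the MRO.
class Base:
    def completion(self) -> str:
        raise NotImplementedError

class CompatMixin:
    def completion(self) -> str:
        return "handled by the mixin"

class Adapter(CompatMixin, Base):  # mixin listed first, as in the new base order
    pass

print([c.__name__ for c in Adapter.__mro__])  # ['Adapter', 'CompatMixin', 'Base', 'object']
print(Adapter().completion())                 # handled by the mixin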
@@ -155,7 +155,7 @@ async def chat_completion(
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params_for_chat_completion(request)
         res = self.client.invoke_model(**params)
-        chunk = next(res["body"])
+        chunk = res["body"].read()
         result = json.loads(chunk.decode("utf-8"))
 
         choice = OpenAICompatCompletionChoice(
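
Switching to read() matches how boto3 returns the non-streaming payload: res["body"] is a botocore StreamingBody holding the full JSON response, not an iterator of event chunks. A rough standalone sketch, with the model id and request body as placeholder assumptions:

import json
import boto3  # assumes AWS credentials and bedrock-runtime access are configured

client = boto3.client("bedrock-runtime", region_name="us-east-1")
res = client.invoke_model(
    modelId="meta.llama3-8b-instruct-v1:0",  # placeholder model id
    body=json.dumps({"prompt": "Hello", "max_gen_len": 64}),
)
# read() returns the complete response bytes in one call.
result = json.loads(res["body"].read().decode("utf-8"))
print(result.get("generation"), result.get("stop_reason"))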
@@ -172,14 +172,16 @@ async def _stream_chat_completion(self, request: ChatCompletionRequest) -> Async
         event_stream = res["body"]
 
         async def _generate_and_convert_to_openai_compat():
-            for chunk in event_stream:
-                chunk = chunk["chunk"]["bytes"]
-                result = json.loads(chunk.decode("utf-8"))
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=result["stop_reason"],
-                    text=result["generation"],
-                )
-                yield OpenAICompatCompletionResponse(choices=[choice])
+            for event in event_stream:
+                if "chunk" in event:
+                    chunk_data = event["chunk"]["bytes"]
+                    result = json.loads(chunk_data.decode("utf-8"))
+                    if "generation" in result:
+                        choice = OpenAICompatCompletionChoice(
+                            finish_reason=result.get("stop_reason"),
+                            text=result["generation"],
+                        )
+                        yield OpenAICompatCompletionResponse(choices=[choice])
 
         stream = _generate_and_convert_to_openai_compat()
         async for chunk in process_chat_completion_stream_response(stream, request):
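
The added guards reflect the shape of the Bedrock response stream: each event yielded by invoke_model_with_response_stream is a dict that normally carries a "chunk" of raw bytes, and the decoded JSON does not always include "generation" (trailing events may carry only metadata such as token counts). A hedged sketch of consuming that stream directly, again with placeholder model id and prompt:

import json
import boto3  # assumes bedrock-runtime access is configured

client = boto3.client("bedrock-runtime", region_name="us-east-1")
res = client.invoke_model_with_response_stream(
    modelId="meta.llama3-8b-instruct-v1:0",  # placeholder model id
    body=json.dumps({"prompt": "Hello", "max_gen_len": 64}),
)
for event in res["body"]:
    if "chunk" not in event:  # skip any event without a payload
        continue
    payload = json.loads(event["chunk"]["bytes"].decode("utf-8"))
    if "generation" in payload:  # final events may carry only metadata
        print(payload["generation"], payload.get("stop_reason"))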
@@ -193,8 +195,9 @@ async def _get_params_for_chat_completion(self, request: ChatCompletionRequest)
 
         if sampling_params.max_tokens:
             options["max_gen_len"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            options["repetition_penalty"] = sampling_params.repetition_penalty
+        # Note: repetition_penalty is not supported by AWS Bedrock Llama models
+        # if sampling_params.repetition_penalty > 0:
+        #     options["repetition_penalty"] = sampling_params.repetition_penalty
 
         prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
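
Dropping repetition_penalty leaves only parameters that the Bedrock Llama request schema defines. An illustrative example of the kind of body this request ends up carrying (values are made up; whether temperature and top_p are present depends on the sampling strategy):

# Illustrative Bedrock Llama request body; repetition_penalty is intentionally
# absent because the Bedrock Llama schema does not define such a field.
body = {
    "prompt": "<formatted chat prompt>",
    "max_gen_len": 512,   # from sampling_params.max_tokens
    "temperature": 0.7,
    "top_p": 0.9,
}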
