-
Notifications
You must be signed in to change notification settings - Fork 192
feat: Add prompt routing to AmazonBedrockChatGenerator
#2220
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
...edrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
Outdated
Show resolved
Hide resolved
...edrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
Outdated
Show resolved
Hide resolved
...edrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
Outdated
Show resolved
Hide resolved
...edrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
Outdated
Show resolved
Hide resolved
Hey @medsriha it would be great if you could check if this solution works for your use case! |
@Amnah199 Taking a look at the code you added it seems like we are enabling two features here.
Looking at the issue originally opened by @medsriha it seems like we really only need to support the second case currently. Basically allow a user to pass a |
def test_to_dict_with_prompt_router_config(self, mock_boto3_session, boto3_config): | ||
""" | ||
Test that the to_dict method returns the correct dictionary without aws credentials | ||
""" | ||
generator = AmazonBedrockChatGenerator( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In addition to this test it would be good to add an integration test where we actually run the generator with the prompt router and check that it returns the expected output.
@sjrl I see what you mean. For simplicity, if we just allow passing router ARNs than the user can simply pass the ARNs for default routers or the ones they have configured in the AWS console. On the other hand, in the context issue we also have:
If we want to support this aspect within our component, we can keep my implementation. I am open to both depending on what are the cons of keeping this support for router_config. |
@medsriha could you weigh-in and let us know if you need both of these features? |
Thanks for pushing this forward, @Amnah199 and @sjrl. The ask was a bit ambiguous, I apologize. I think to keep it simple is to support reusing an existing prompt router via ARN for now. That way, users can create and manage their routers directly in AWS, while Haystack just consumes them. We can always revisit “create-on-the-fly” support later if there’s a demand. |
resolved_router_arn = resolve_secret(self.prompt_router_arn) | ||
bedrock_client = session.client("bedrock", config=config) | ||
prompt_router = bedrock_client.get_prompt_router( | ||
promptRouterArn=resolved_router_arn, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sjrl I believe we can show a specific error message in case of invalid ARN. If you agree I can raise an exception here with the message.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that could be helpful. What did you have in mind?
prompt_router = bedrock_client.get_prompt_router( | ||
promptRouterArn=resolved_router_arn, | ||
) | ||
self.model = prompt_router["promptRouterArn"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to clarify, is this value prompt_router["promptRouterArn"]
the same as the one provided by the user so prompt_router_arn
at init time?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes it would be the same. Here we verify if it actually is a valid ARN.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay I see. Perhaps I could propose an alternative approach. If self.prompt_router_arn
works in the same way as we treat self.model
do we really need to make a new variable?
Would specifying an prompt_router_arn
in the model
field at init time already work with the existing integration?
@@ -158,6 +163,8 @@ def __init__( | |||
streaming_callback: Optional[StreamingCallbackT] = None, | |||
boto3_config: Optional[Dict[str, Any]] = None, | |||
tools: Optional[Union[List[Tool], Toolset]] = None, | |||
*, | |||
prompt_router_arn: Optional[Secret] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to double check, does prompt_router_arn
need to be a Secret? What was the motivation to make it a secret?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The custom router ARN contains your region and AWS account ID, which probably we dont want to expose.
@medsriha We realized that the existing generator implementation already supports passing a router ARN to the The only potential reason to introduce a separate |
@Amnah199 turns out your first solution was exactly what we needed. I had thought we weren’t supposed to set up a new ARN via Haystack, but what we actually need is a JSON configuration that tells Bedrock which model to use and when. |
Related Issues
Proposed Changes:
prompt_router_arn
during initializationHow did you test it?
Notes for the reviewer
Checklist
fix:
,feat:
,build:
,chore:
,ci:
,docs:
,style:
,refactor:
,perf:
,test:
.