diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
index acadbb788a90..04242306f873 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -128,6 +128,8 @@ def prompt_language_model(
         """
         Use the Vertex AI PaLM API to generate natural language text.
 
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param prompt: Required. Inputs or queries that a user or a program gives
             to the Vertex AI PaLM API, in order to elicit a specific response.
         :param pretrained_model: A pre-trained model optimized for performing natural
@@ -141,8 +143,6 @@ def prompt_language_model(
             of their probabilities equals the top_p value. Defaults to 0.8.
         :param top_k: A top_k of 1 means the selected token is the most probable among
             all tokens.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         """
         vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
 
@@ -178,11 +178,11 @@ def generate_text_embeddings(
         """
         Use the Vertex AI PaLM API to generate text embeddings.
 
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param prompt: Required. Inputs or queries that a user or a program gives
             to the Vertex AI PaLM API, in order to elicit a specific response.
         :param pretrained_model: A pre-trained model optimized for generating text embeddings.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         """
         vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
         model = self.get_text_embedding_model(pretrained_model)
@@ -210,16 +210,16 @@ def prompt_multimodal_model(
         """
         Use the Vertex AI Gemini Pro foundation model to generate natural language text.
 
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param prompt: Required. Inputs or queries that a user or a program gives
             to the Multi-modal model, in order to elicit a specific response.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param generation_config: Optional. Generation configuration settings.
         :param safety_settings: Optional. Per request settings for blocking unsafe content.
         :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
             supporting prompts with text-only input, including natural language
             tasks, multi-turn text and code chat, and code generation. It can
             output text and code.
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         """
         vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
 
@@ -251,6 +251,8 @@ def prompt_multimodal_model_with_media(
         """
         Use the Vertex AI Gemini Pro foundation model to generate natural language text.
 
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param prompt: Required. Inputs or queries that a user or a program gives
             to the Multi-modal model, in order to elicit a specific response.
         :param generation_config: Optional. Generation configuration settings.
@@ -262,8 +264,6 @@ def prompt_multimodal_model_with_media(
         :param media_gcs_path: A GCS path to a content file such as an image or a video.
             Can be passed to the multi-modal model as part of the prompt. Used with vision models.
         :param mime_type: Validates the media type presented by the file in the media_gcs_path.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         """
         vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
 
@@ -290,6 +290,8 @@ def text_generation_model_predict(
         """
         Use the Vertex AI PaLM API to generate natural language text.
 
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param prompt: Required. Inputs or queries that a user or a program gives
             to the Vertex AI PaLM API, in order to elicit a specific response.
         :param pretrained_model: A pre-trained model optimized for performing natural
@@ -303,8 +305,6 @@ def text_generation_model_predict(
             of their probabilities equals the top_p value. Defaults to 0.8.
         :param top_k: A top_k of 1 means the selected token is the most probable among
             all tokens.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         """
         vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
 
@@ -334,11 +334,11 @@ def text_embedding_model_get_embeddings(
         """
         Use the Vertex AI PaLM API to generate text embeddings.
 
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param prompt: Required. Inputs or queries that a user or a program gives
             to the Vertex AI PaLM API, in order to elicit a specific response.
         :param pretrained_model: A pre-trained model optimized for generating text embeddings.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         """
         vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
         model = self.get_text_embedding_model(pretrained_model)
@@ -355,26 +355,31 @@ def generative_model_generate_content(
         tools: list | None = None,
         generation_config: dict | None = None,
         safety_settings: dict | None = None,
+        system_instruction: str | None = None,
         pretrained_model: str = "gemini-pro",
         project_id: str = PROVIDE_PROJECT_ID,
     ) -> str:
         """
         Use the Vertex AI Gemini Pro foundation model to generate natural language text.
 
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         :param contents: Required. The multi-part content of a message that a user or a program
             gives to the generative model, in order to elicit a specific response.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param generation_config: Optional. Generation configuration settings.
         :param safety_settings: Optional. Per request settings for blocking unsafe content.
+        :param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
+        :param system_instruction: Optional. An instruction given to the model to guide its behavior.
         :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
             supporting prompts with text-only input, including natural language
             tasks, multi-turn text and code chat, and code generation. It can
             output text and code.
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         """
         vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
 
-        model = self.get_generative_model(pretrained_model)
+        model = self.get_generative_model(
+            pretrained_model=pretrained_model, system_instruction=system_instruction
+        )
         response = model.generate_content(
             contents=contents,
             tools=tools,
@@ -400,12 +405,13 @@ def supervised_fine_tuning_train(
         """
         Use the Supervised Fine Tuning API to create a tuning job.
 
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param source_model: Required. A pre-trained model optimized for performing natural
             language tasks such as classification, summarization, extraction, content
             creation, and ideation.
         :param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset must be
             formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param tuned_model_display_name: Optional. Display name of the TunedModel. The name can be up
             to 128 characters long and can consist of any UTF-8 characters.
         :param validation_dataset: Optional. Cloud Storage URI of your training dataset. The dataset must be
@@ -447,18 +453,18 @@ def count_tokens(
         """
         Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request
         to the Gemini API.
 
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param contents: Required. The multi-part content of a message that a user or a program
             gives to the generative model, in order to elicit a specific response.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
             supporting prompts with text-only input, including natural language
             tasks, multi-turn text and code chat, and code generation. It can
             output text and code.
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         """
         vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
 
-        model = self.get_generative_model(pretrained_model)
+        model = self.get_generative_model(pretrained_model=pretrained_model)
         response = model.count_tokens(
             contents=contents,
         )
@@ -484,6 +490,8 @@ def run_evaluation(
         """
         Use the Rapid Evaluation API to evaluate a model.
 
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param pretrained_model: Required. A pre-trained model optimized for performing natural
             language tasks such as classification, summarization, extraction, content
             creation, and ideation.
@@ -492,8 +500,6 @@ def run_evaluation(
         :param experiment_name: Required. The name of the evaluation experiment.
         :param experiment_run_name: Required. The specific run name or ID for this experiment.
         :param prompt_template: Required. The template used to format the model's prompts during evaluation. Adheres to Rapid Evaluation API.
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param generation_config: Optional. A dictionary containing generation parameters for the model.
         :param safety_settings: Optional. A dictionary specifying harm category thresholds for blocking model outputs.
         :param system_instruction: Optional. An instruction given to the model to guide its behavior.
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index b0b9462e2151..fddd5dcf7286 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -21,7 +21,6 @@
 
 from typing import TYPE_CHECKING, Sequence
 
-from google.cloud.aiplatform_v1 import types as types_v1
 from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -510,12 +509,14 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
 
     :param project_id: Required. The ID of the Google Cloud project that the service belongs
         to (templated).
-    :param contents: Required. The multi-part content of a message that a user or a program
-        gives to the generative model, in order to elicit a specific response.
     :param location: Required. The ID of the Google Cloud location that the service belongs
         to (templated).
+    :param contents: Required. The multi-part content of a message that a user or a program
+        gives to the generative model, in order to elicit a specific response.
     :param generation_config: Optional. Generation configuration settings.
     :param safety_settings: Optional. Per request settings for blocking unsafe content.
+    :param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
+    :param system_instruction: Optional. An instruction given to the model to guide its behavior.
     :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
         supporting prompts with text-only input, including natural language
         tasks, multi-turn text and code chat, and code generation. It can
@@ -537,11 +538,12 @@ def __init__(
         self,
         *,
         project_id: str,
-        contents: list,
         location: str,
+        contents: list,
         tools: list | None = None,
         generation_config: dict | None = None,
         safety_settings: dict | None = None,
+        system_instruction: str | None = None,
         pretrained_model: str = "gemini-pro",
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
@@ -554,6 +556,7 @@ def __init__(
         self.tools = tools
         self.generation_config = generation_config
         self.safety_settings = safety_settings
+        self.system_instruction = system_instruction
         self.pretrained_model = pretrained_model
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
@@ -570,6 +573,7 @@ def execute(self, context: Context):
             tools=self.tools,
             generation_config=self.generation_config,
             safety_settings=self.safety_settings,
+            system_instruction=self.system_instruction,
             pretrained_model=self.pretrained_model,
         )
 
@@ -583,14 +587,14 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
     """
    Use the Supervised Fine Tuning API to create a tuning job.
 
+    :param project_id: Required. The ID of the Google Cloud project that the
+        service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
     :param source_model: Required. A pre-trained model optimized for performing natural
         language tasks such as classification, summarization, extraction, content
         creation, and ideation.
     :param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset must be
         formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
-    :param project_id: Required. The ID of the Google Cloud project that the
-        service belongs to.
-    :param location: Required. The ID of the Google Cloud location that the service belongs to.
     :param tuned_model_display_name: Optional. Display name of the TunedModel. The name can be up
         to 128 characters long and can consist of any UTF-8 characters.
     :param validation_dataset: Optional. Cloud Storage URI of your training dataset. The dataset must be
@@ -617,10 +621,10 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
     def __init__(
         self,
         *,
-        source_model: str,
-        train_dataset: str,
         project_id: str,
         location: str,
+        source_model: str,
+        train_dataset: str,
         tuned_model_display_name: str | None = None,
         validation_dataset: str | None = None,
         epochs: int | None = None,
@@ -631,6 +635,8 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
         self.source_model = source_model
         self.train_dataset = train_dataset
         self.tuned_model_display_name = tuned_model_display_name
@@ -638,8 +644,6 @@ def __init__(
         self.epochs = epochs
         self.adapter_size = adapter_size
         self.learning_rate_multiplier = learning_rate_multiplier
-        self.project_id = project_id
-        self.location = location
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
@@ -649,10 +653,10 @@ def execute(self, context: Context):
             impersonation_chain=self.impersonation_chain,
         )
         response = self.hook.supervised_fine_tuning_train(
-            source_model=self.source_model,
-            train_dataset=self.train_dataset,
             project_id=self.project_id,
             location=self.location,
+            source_model=self.source_model,
+            train_dataset=self.train_dataset,
             validation_dataset=self.validation_dataset,
             epochs=self.epochs,
             adapter_size=self.adapter_size,
@@ -666,7 +670,12 @@ def execute(self, context: Context):
         self.xcom_push(context, key="tuned_model_name", value=response.tuned_model_name)
         self.xcom_push(context, key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
 
-        return types_v1.TuningJob.to_dict(response)
+        result = {
+            "tuned_model_name": response.tuned_model_name,
+            "tuned_model_endpoint_name": response.tuned_model_endpoint_name,
+        }
+
+        return result
 
 
 class CountTokensOperator(GoogleCloudBaseOperator):
@@ -675,12 +684,10 @@ class CountTokensOperator(GoogleCloudBaseOperator):
 
     :param project_id: Required. The ID of the Google Cloud project that the service belongs
         to (templated).
-    :param contents: Required. The multi-part content of a message that a user or a program
-        gives to the generative model, in order to elicit a specific response.
     :param location: Required. The ID of the Google Cloud location that the service belongs
         to (templated).
-    :param system_instruction: Optional. Instructions for the model to steer it toward better
-        performance. For example, "Answer as concisely as possible"
+    :param contents: Required. The multi-part content of a message that a user or a program
+        gives to the generative model, in order to elicit a specific response.
     :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
         supporting prompts with text-only input, including natural language
         tasks, multi-turn text and code chat, and code generation. It can
@@ -702,8 +709,8 @@ def __init__(
         self,
         *,
         project_id: str,
-        contents: list,
         location: str,
+        contents: list,
         pretrained_model: str = "gemini-pro",
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
@@ -742,6 +749,8 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
     """
     Use the Rapid Evaluation API to evaluate a model.
 
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
     :param pretrained_model: Required. A pre-trained model optimized for performing natural
         language tasks such as classification, summarization, extraction, content
         creation, and ideation.
@@ -750,8 +759,6 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
     :param experiment_name: Required. The name of the evaluation experiment.
     :param experiment_run_name: Required. The specific run name or ID for this experiment.
     :param prompt_template: Required. The template used to format the model's prompts during evaluation. Adheres to Rapid Evaluation API.
-    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
-    :param location: Required. The ID of the Google Cloud location that the service belongs to.
     :param generation_config: Optional. A dictionary containing generation parameters for the model.
     :param safety_settings: Optional. A dictionary specifying harm category thresholds for blocking model outputs.
     :param system_instruction: Optional. An instruction given to the model to guide its behavior.
@@ -781,14 +788,14 @@ def __init__(
         self,
         *,
+        project_id: str,
+        location: str,
         pretrained_model: str,
         eval_dataset: dict,
         metrics: list,
         experiment_name: str,
         experiment_run_name: str,
         prompt_template: str,
-        project_id: str,
-        location: str,
         generation_config: dict | None = None,
         safety_settings: dict | None = None,
         system_instruction: str | None = None,
@@ -799,18 +806,18 @@ def __init__(
     ) -> None:
         super().__init__(**kwargs)
 
+        self.project_id = project_id
+        self.location = location
         self.pretrained_model = pretrained_model
         self.eval_dataset = eval_dataset
         self.metrics = metrics
         self.experiment_name = experiment_name
         self.experiment_run_name = experiment_run_name
         self.prompt_template = prompt_template
-        self.system_instruction = system_instruction
         self.generation_config = generation_config
         self.safety_settings = safety_settings
+        self.system_instruction = system_instruction
         self.tools = tools
-        self.project_id = project_id
-        self.location = location
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
@@ -820,17 +827,17 @@ def execute(self, context: Context):
             impersonation_chain=self.impersonation_chain,
         )
         response = self.hook.run_evaluation(
+            project_id=self.project_id,
+            location=self.location,
             pretrained_model=self.pretrained_model,
             eval_dataset=self.eval_dataset,
             metrics=self.metrics,
             experiment_name=self.experiment_name,
             experiment_run_name=self.experiment_run_name,
             prompt_template=self.prompt_template,
-            project_id=self.project_id,
-            location=self.location,
-            system_instruction=self.system_instruction,
             generation_config=self.generation_config,
             safety_settings=self.safety_settings,
+            system_instruction=self.system_instruction,
             tools=self.tools,
         )
diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
index 2004eec29fef..19723a51b1dd 100644
--- a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
@@ -222,7 +222,10 @@ def test_generative_model_generate_content(self, mock_model) -> None:
             safety_settings=TEST_SAFETY_SETTINGS,
             pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
         )
-        mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL)
+        mock_model.assert_called_once_with(
+            pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
+            system_instruction=None,
+        )
         mock_model.return_value.generate_content.assert_called_once_with(
             contents=TEST_CONTENTS,
             tools=TEST_TOOLS,
@@ -257,7 +260,9 @@ def test_count_tokens(self, mock_model) -> None:
             location=GCP_LOCATION,
             pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
         )
-        mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL)
+        mock_model.assert_called_once_with(
+            pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
+        )
         mock_model.return_value.count_tokens.assert_called_once_with(
             contents=TEST_CONTENTS,
         )
diff --git a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
index 3745850e7210..e8efb9601fce 100644
--- a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -356,6 +356,7 @@ def test_execute(self, mock_hook):
             HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
         }
         generation_config = {"max_output_tokens": 256, "top_p": 0.8, "temperature": 0.0}
+        system_instruction = "be concise."
 
         op = GenerativeModelGenerateContentOperator(
             task_id=TASK_ID,
@@ -366,6 +367,7 @@ def test_execute(self, mock_hook):
             generation_config=generation_config,
             safety_settings=safety_settings,
             pretrained_model=pretrained_model,
+            system_instruction=system_instruction,
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
@@ -382,6 +384,7 @@ def test_execute(self, mock_hook):
             generation_config=generation_config,
             safety_settings=safety_settings,
             pretrained_model=pretrained_model,
+            system_instruction=system_instruction,
         )
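
Usage illustration (not part of the patch): a minimal sketch of how the new system_instruction parameter of GenerativeModelGenerateContentOperator could be exercised from a DAG, based solely on the operator signature added above. The DAG ID, project, location, prompt contents, and instruction text are hypothetical placeholders.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
        GenerativeModelGenerateContentOperator,
    )

    with DAG(
        dag_id="example_gemini_system_instruction",  # hypothetical DAG ID
        start_date=datetime(2024, 1, 1),
        schedule=None,
        catchup=False,
    ):
        # The operator forwards system_instruction through the hook to
        # get_generative_model(), steering the model's tone and format
        # independently of the per-request contents.
        generate_content = GenerativeModelGenerateContentOperator(
            task_id="generate_content",
            project_id="my-project",  # placeholder project
            location="us-central1",  # placeholder location
            contents=["Explain Apache Airflow in one paragraph."],
            system_instruction="Always answer in plain, non-technical language.",
            generation_config={"max_output_tokens": 256, "temperature": 0.0},
            pretrained_model="gemini-pro",
        )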