Skip to content

Commit

Permalink
Fix gcp text to speech uri fetch (#42309)
Browse files Browse the repository at this point in the history
- Fix acces to the uri attribute, if it's provided via the RecognitionAudio model.

Co-authored-by: Oleg Kachur <[email protected]>
  • Loading branch information
olegkachur-e and Oleg Kachur authored Sep 27, 2024
1 parent 56ab422 commit dc43d31
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 27 deletions.
17 changes: 8 additions & 9 deletions airflow/providers/google/cloud/operators/speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,14 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

FileDetailsLink.persist(
context=context,
task_instance=self,
# Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
uri=self.audio["uri"][5:],
project_id=self.project_id or hook.project_id,
)

if self.audio.uri:
FileDetailsLink.persist(
context=context,
task_instance=self,
# Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
uri=self.audio.uri[5:],
project_id=self.project_id or hook.project_id,
)
response = hook.recognize_speech(
config=self.config, audio=self.audio, retry=self.retry, timeout=self.timeout
)
Expand Down
15 changes: 8 additions & 7 deletions airflow/providers/google/cloud/operators/translate_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,14 @@ def execute(self, context: Context) -> dict:
raise AirflowException(
f"Wrong response '{recognize_dict}' returned - it should contain {key} field"
)

if self.audio.uri:
FileDetailsLink.persist(
context=context,
task_instance=self,
# Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
uri=self.audio.uri[5:],
project_id=self.project_id or translate_hook.project_id,
)
try:
translation = translate_hook.translate(
values=transcript,
Expand All @@ -179,12 +186,6 @@ def execute(self, context: Context) -> dict:
model=self.model,
)
self.log.info("Translated output: %s", translation)
FileDetailsLink.persist(
context=context,
task_instance=self,
uri=self.audio["uri"][5:],
project_id=self.project_id or translate_hook.project_id,
)
return translation
except ValueError as e:
self.log.error("An error has been thrown from translate speech method:")
Expand Down
28 changes: 25 additions & 3 deletions tests/providers/google/cloud/operators/test_speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@

import pytest
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.speech_v1 import RecognizeResponse
from google.cloud.speech_v1 import RecognitionAudio, RecognitionConfig, RecognizeResponse

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.operators.speech_to_text import CloudSpeechToTextRecognizeSpeechOperator

PROJECT_ID = "project-id"
GCP_CONN_ID = "gcp-conn-id"
IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
CONFIG = {"encoding": "LINEAR16"}
AUDIO = {"uri": "gs://bucket/object"}
CONFIG = RecognitionConfig({"encoding": "LINEAR16"})
AUDIO = RecognitionAudio({"uri": "gs://bucket/object"})


class TestCloudSpeechToTextRecognizeSpeechOperator:
Expand Down Expand Up @@ -80,3 +80,25 @@ def test_missing_audio(self, mock_hook):
err = ctx.value
assert "audio" in str(err)
mock_hook.assert_not_called()

@patch("airflow.providers.google.cloud.operators.speech_to_text.FileDetailsLink.persist")
@patch("airflow.providers.google.cloud.operators.speech_to_text.CloudSpeechToTextHook")
def test_no_audio_uri(self, mock_hook, mock_file_link):
mock_hook.return_value.recognize_speech.return_value = RecognizeResponse()
AUDIO_NO_URI = RecognitionAudio({"content": b"set content data instead of uri"})

op = CloudSpeechToTextRecognizeSpeechOperator(
project_id=PROJECT_ID,
gcp_conn_id=GCP_CONN_ID,
config=CONFIG,
audio=AUDIO_NO_URI,
task_id="id",
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context=MagicMock())

mock_hook.return_value.recognize_speech.assert_called_once_with(
config=CONFIG, audio=AUDIO_NO_URI, retry=DEFAULT, timeout=None
)
assert op.audio.uri == ""
mock_file_link.assert_not_called()
57 changes: 49 additions & 8 deletions tests/providers/google/cloud/operators/test_translate_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import pytest
from google.cloud.speech_v1 import (
RecognitionAudio,
RecognitionConfig,
RecognizeResponse,
SpeechRecognitionAlternative,
SpeechRecognitionResult,
Expand Down Expand Up @@ -54,8 +56,8 @@ def test_minimal_green_path(self, mock_translate_hook, mock_speech_hook):
]

op = CloudTranslateSpeechOperator(
audio={"uri": "gs://bucket/object"},
config={"encoding": "LINEAR16"},
audio=RecognitionAudio({"uri": "gs://bucket/object"}),
config=RecognitionConfig({"encoding": "LINEAR16"}),
target_language="pl",
format_="text",
source_language=None,
Expand All @@ -77,8 +79,8 @@ def test_minimal_green_path(self, mock_translate_hook, mock_speech_hook):
)

mock_speech_hook.return_value.recognize_speech.assert_called_once_with(
audio={"uri": "gs://bucket/object"},
config={"encoding": "LINEAR16"},
audio=RecognitionAudio({"uri": "gs://bucket/object"}),
config=RecognitionConfig({"encoding": "LINEAR16"}),
)

mock_translate_hook.return_value.translate.assert_called_once_with(
Expand All @@ -104,8 +106,8 @@ def test_bad_recognition_response(self, mock_translate_hook, mock_speech_hook):
results=[SpeechRecognitionResult()]
)
op = CloudTranslateSpeechOperator(
audio={"uri": "gs://bucket/object"},
config={"encoding": "LINEAR16"},
audio=RecognitionAudio({"uri": "gs://bucket/object"}),
config=RecognitionConfig({"encoding": "LINEAR16"}),
target_language="pl",
format_="text",
source_language=None,
Expand All @@ -128,8 +130,47 @@ def test_bad_recognition_response(self, mock_translate_hook, mock_speech_hook):
)

mock_speech_hook.return_value.recognize_speech.assert_called_once_with(
audio={"uri": "gs://bucket/object"},
config={"encoding": "LINEAR16"},
audio=RecognitionAudio({"uri": "gs://bucket/object"}),
config=RecognitionConfig({"encoding": "LINEAR16"}),
)

mock_translate_hook.return_value.translate.assert_not_called()

@mock.patch("airflow.providers.google.cloud.operators.translate_speech.FileDetailsLink.persist")
@mock.patch("airflow.providers.google.cloud.operators.translate_speech.CloudSpeechToTextHook")
@mock.patch("airflow.providers.google.cloud.operators.translate_speech.CloudTranslateHook")
def test_no_audio_uri(self, mock_translate_hook, mock_speech_hook, file_link_mock):
mock_speech_hook.return_value.recognize_speech.return_value = RecognizeResponse(
results=[
SpeechRecognitionResult(
alternatives=[SpeechRecognitionAlternative(transcript="test speech recognition result")]
)
]
)
mock_translate_hook.return_value.translate.return_value = [
{
"translatedText": "sprawdzić wynik rozpoznawania mowy",
"detectedSourceLanguage": "en",
"model": "base",
"input": "test speech recognition result",
}
]
op = CloudTranslateSpeechOperator(
audio=RecognitionAudio({"content": b"set content data instead of uri"}),
config=RecognitionConfig({"encoding": "LINEAR16"}),
target_language="pl",
format_="text",
source_language=None,
model="base",
gcp_conn_id=GCP_CONN_ID,
task_id="id",
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context=mock.MagicMock())

mock_speech_hook.return_value.recognize_speech.assert_called_once_with(
audio=RecognitionAudio({"content": b"set content data instead of uri"}),
config=RecognitionConfig({"encoding": "LINEAR16"}),
)
assert op.audio.uri == ""
file_link_mock.assert_not_called()

0 comments on commit dc43d31

Please sign in to comment.