diff --git a/src/rev_ai/__init__.py b/src/rev_ai/__init__.py
index 41c89c2c..30e64704 100644
--- a/src/rev_ai/__init__.py
+++ b/src/rev_ai/__init__.py
@@ -1,10 +1,10 @@
 # -*- coding: utf-8 -*-
 """Top-level package for rev_ai"""
 
-__version__ = '2.20.0'
+__version__ = '2.21.0'
 
 from .models import Job, JobStatus, Account, Transcript, Monologue, Element, MediaConfig, \
-    CaptionType, CustomVocabulary, TopicExtractionJob, TopicExtractionResult, Topic, Informant, \
-    SpeakerName, LanguageIdentificationJob, LanguageIdentificationResult, LanguageConfidence, \
-    SentimentAnalysisResult, SentimentValue, SentimentMessage, SentimentAnalysisJob, \
-    CustomerUrlData, RevAiApiDeploymentConfigMap, RevAiApiDeployment
+    CaptionType, GroupChannelsType, CustomVocabulary, TopicExtractionJob, TopicExtractionResult, \
+    Topic, Informant, SpeakerName, LanguageIdentificationJob, LanguageIdentificationResult, \
+    LanguageConfidence, SentimentAnalysisResult, SentimentValue, SentimentMessage, \
+    SentimentAnalysisJob, CustomerUrlData, RevAiApiDeploymentConfigMap, RevAiApiDeployment
diff --git a/src/rev_ai/apiclient.py b/src/rev_ai/apiclient.py
index 830d6e4a..8f4d91fb 100644
--- a/src/rev_ai/apiclient.py
+++ b/src/rev_ai/apiclient.py
@@ -337,28 +337,45 @@ def get_list_of_jobs(self, limit=None, starting_after=None):
 
         return [Job.from_json(job) for job in response.json()]
 
-    def get_transcript_text(self, id_):
+    def get_transcript_text(self, id_, group_channels_by=None, group_channels_threshold_ms=None):
         """Get the transcript of a specific job as plain text.
 
         :param id_: id of job to be requested
+        :param group_channels_by: optional, GroupChannelsType grouping strategy for
+            multichannel transcripts. None for default.
+        :param group_channels_threshold_ms: optional, grouping threshold in milliseconds.
+            None for default.
         :returns: transcript data as text
         :raises: HTTPError
         """
         if not id_:
             raise ValueError('id_ must be provided')
 
+        url = self._build_transcript_url(
+            id_,
+            group_channels_by=group_channels_by,
+            group_channels_threshold_ms=group_channels_threshold_ms
+        )
+
         response = self._make_http_request(
             "GET",
-            urljoin(self.base_url, 'jobs/{}/transcript'.format(id_)),
+            url,
             headers={'Accept': 'text/plain'}
         )
 
         return response.text
 
-    def get_transcript_text_as_stream(self, id_):
+    def get_transcript_text_as_stream(self,
+                                      id_,
+                                      group_channels_by=None,
+                                      group_channels_threshold_ms=None):
         """Get the transcript of a specific job as a plain text stream.
 
         :param id_: id of job to be requested
+        :param group_channels_by: optional, GroupChannelsType grouping strategy for
+            multichannel transcripts. None for default.
+        :param group_channels_threshold_ms: optional, grouping threshold in milliseconds.
+            None for default.
         :returns: requests.models.Response HTTP response which can be used to stream
             the payload of the response
         :raises: HTTPError
@@ -366,37 +383,63 @@ def get_transcript_text_as_stream(self, id_):
         if not id_:
             raise ValueError('id_ must be provided')
 
+        url = self._build_transcript_url(
+            id_,
+            group_channels_by=group_channels_by,
+            group_channels_threshold_ms=group_channels_threshold_ms
+        )
+
         response = self._make_http_request(
             "GET",
-            urljoin(self.base_url, 'jobs/{}/transcript'.format(id_)),
+            url,
             headers={'Accept': 'text/plain'},
             stream=True
         )
 
         return response
 
-    def get_transcript_json(self, id_):
+    def get_transcript_json(self,
+                            id_,
+                            group_channels_by=None,
+                            group_channels_threshold_ms=None):
         """Get the transcript of a specific job as json.
 
         :param id_: id of job to be requested
+        :param group_channels_by: optional, GroupChannelsType grouping strategy for
+            multichannel transcripts. None for default.
+        :param group_channels_threshold_ms: optional, grouping threshold in milliseconds.
+            None for default.
         :returns: transcript data as json
         :raises: HTTPError
         """
         if not id_:
             raise ValueError('id_ must be provided')
 
+        url = self._build_transcript_url(
+            id_,
+            group_channels_by=group_channels_by,
+            group_channels_threshold_ms=group_channels_threshold_ms
+        )
+
         response = self._make_http_request(
             "GET",
-            urljoin(self.base_url, 'jobs/{}/transcript'.format(id_)),
+            url,
             headers={'Accept': self.rev_json_content_type}
         )
 
         return response.json()
 
-    def get_transcript_json_as_stream(self, id_):
+    def get_transcript_json_as_stream(self,
+                                      id_,
+                                      group_channels_by=None,
+                                      group_channels_threshold_ms=None):
         """Get the transcript of a specific job as streamed json.
 
         :param id_: id of job to be requested
+        :param group_channels_by: optional, GroupChannelsType grouping strategy for
+            multichannel transcripts. None for default.
+        :param group_channels_threshold_ms: optional, grouping threshold in milliseconds.
+            None for default.
         :returns: requests.models.Response HTTP response which can be used to stream
             the payload of the response
         :raises: HTTPError
@@ -404,28 +447,44 @@ def get_transcript_json_as_stream(self, id_):
         if not id_:
             raise ValueError('id_ must be provided')
 
+        url = self._build_transcript_url(
+            id_,
+            group_channels_by=group_channels_by,
+            group_channels_threshold_ms=group_channels_threshold_ms
+        )
+
         response = self._make_http_request(
             "GET",
-            urljoin(self.base_url, 'jobs/{}/transcript'.format(id_)),
+            url,
             headers={'Accept': self.rev_json_content_type},
             stream=True
         )
 
         return response
 
-    def get_transcript_object(self, id_):
+    def get_transcript_object(self, id_, group_channels_by=None, group_channels_threshold_ms=None):
         """Get the transcript of a specific job as a python object`.
 
         :param id_: id of job to be requested
+        :param group_channels_by: optional, GroupChannelsType grouping strategy for
+            multichannel transcripts. None for default.
+        :param group_channels_threshold_ms: optional, grouping threshold in milliseconds.
+            None for default.
         :returns: transcript data as a python object
         :raises: HTTPError
         """
         if not id_:
             raise ValueError('id_ must be provided')
 
+        url = self._build_transcript_url(
+            id_,
+            group_channels_by=group_channels_by,
+            group_channels_threshold_ms=group_channels_threshold_ms
+        )
+
         response = self._make_http_request(
             "GET",
-            urljoin(self.base_url, 'jobs/{}/transcript'.format(id_)),
+            url,
             headers={'Accept': self.rev_json_content_type}
         )
 
@@ -814,3 +873,26 @@ def _create_job_options_payload(
 
     def _create_captions_query(self, speaker_channel):
         return '' if speaker_channel is None else '?speaker_channel={}'.format(speaker_channel)
+
+    def _build_transcript_url(self, id_, group_channels_by=None, group_channels_threshold_ms=None):
+        """Build the get transcript url.
+
+        :param id_: id of job to be requested
+        :param group_channels_by: optional, GroupChannelsType grouping strategy for
+            multichannel transcripts. None for default.
+        :param group_channels_threshold_ms: optional, grouping threshold in milliseconds.
+            None for default.
+        :returns: url for getting the transcript
+        """
+        params = []
+        if group_channels_by is not None:
+            # Send the enum's wire value: on Python 3.11+ format() of a (str, Enum)
+            # member yields 'GroupChannelsType.X' rather than 'x'. Plain strings pass through.
+            grouping = getattr(group_channels_by, 'value', group_channels_by)
+            params.append('group_channels_by={}'.format(grouping))
+        if group_channels_threshold_ms is not None:
+            params.append('group_channels_threshold_ms={}'.format(group_channels_threshold_ms))
+
+        # Only emit '?' when there is at least one parameter; do not rely on urljoin cleanup.
+        query = '?{}'.format('&'.join(params)) if params else ''
+        return urljoin(self.base_url, 'jobs/{}/transcript{}'.format(id_, query))
diff --git a/src/rev_ai/models/__init__.py b/src/rev_ai/models/__init__.py
index 23398a11..77e80976 100644
--- a/src/rev_ai/models/__init__.py
+++ b/src/rev_ai/models/__init__.py
@@ -4,7 +4,7 @@
 from .customvocabulary import CustomVocabulary
 from .streaming import MediaConfig
 from .asynchronous import Job, JobStatus, Account, Transcript, Monologue, Element, CaptionType, \
-    SpeakerName
+    SpeakerName, GroupChannelsType
 from .insights import TopicExtractionJob, TopicExtractionResult, Topic, Informant, \
     SentimentAnalysisResult, SentimentValue, SentimentMessage, SentimentAnalysisJob
 from .language_id import LanguageIdentificationJob, LanguageIdentificationResult, LanguageConfidence
diff --git a/src/rev_ai/models/asynchronous/__init__.py b/src/rev_ai/models/asynchronous/__init__.py
index 090f2672..92f764cf 100644
--- a/src/rev_ai/models/asynchronous/__init__.py
+++ b/src/rev_ai/models/asynchronous/__init__.py
@@ -7,3 +7,4 @@
 from .account import Account
 from .transcript import Transcript, Monologue, Element
 from .speaker_name import SpeakerName
+from .group_channels_type import GroupChannelsType
diff --git a/src/rev_ai/models/asynchronous/group_channels_type.py b/src/rev_ai/models/asynchronous/group_channels_type.py
new file mode 100644
index 00000000..a28cbd76
--- /dev/null
+++ b/src/rev_ai/models/asynchronous/group_channels_type.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+"""Enum for group_channels_by types"""
+
+from enum import Enum
+
+
+class GroupChannelsType(str, Enum):
+    SPEAKER = 'speaker'
+    SENTENCE = 'sentence'
+    WORD = 'word'
+
+    @classmethod
+    def from_string(cls, status):
+        return cls[status.upper()]
diff --git a/tests/test_transcript.py b/tests/test_transcript.py
index 5e45f3bf..2acc0212 100644
--- a/tests/test_transcript.py
+++ b/tests/test_transcript.py
@@ -6,6 +6,7 @@
 from src.rev_ai.apiclient import RevAiAPIClient
 from src.rev_ai.models import RevAiApiDeploymentConfigMap, RevAiApiDeployment
 from src.rev_ai.models.asynchronous import Transcript, Monologue, Element
+from src.rev_ai.models.asynchronous.group_channels_type import GroupChannelsType
 
 try:
     from urllib.parse import urljoin
@@ -20,19 +21,32 @@
 
 @pytest.mark.usefixtures('mock_session', 'make_mock_response')
 class TestTranscriptEndpoints():
-    def test_get_transcript_text(self, mock_session, make_mock_response):
+    @pytest.mark.parametrize(
+        'group_channels_by, group_channels_threshold_ms',
+        [(None, None), (GroupChannelsType.SENTENCE, 5000), (GroupChannelsType.WORD, 2000)]
+    )
+    def test_get_transcript_text(
+        self,
+        mock_session,
+        make_mock_response,
+        group_channels_by,
+        group_channels_threshold_ms
+    ):
         data = 'Test'
         client = RevAiAPIClient(TOKEN)
         expected_headers = {'Accept': 'text/plain'}
         expected_headers.update(client.default_headers)
         response = make_mock_response(url=URL, text=data)
         mock_session.request.return_value = response
+        expected_url = URL
+        if group_channels_by and group_channels_threshold_ms:
+            expected_url += f"?group_channels_by={group_channels_by.value}&group_channels_threshold_ms={group_channels_threshold_ms}"
 
-        res = client.get_transcript_text(JOB_ID)
+        res = client.get_transcript_text(JOB_ID, group_channels_by=group_channels_by, group_channels_threshold_ms=group_channels_threshold_ms)
 
         assert res == data
         mock_session.request.assert_called_once_with("GET",
-                                                     URL,
+                                                     expected_url,
                                                      headers=expected_headers)
 
     @pytest.mark.parametrize('id', [None, ''])
@@ -40,19 +54,32 @@ def test_get_transcript_text_with_no_job_id(self, id, mock_session):
         with pytest.raises(ValueError, match='id_ must be provided'):
             RevAiAPIClient(TOKEN).get_transcript_text(id)
 
-    def test_get_transcript_text_as_stream(self, mock_session, make_mock_response):
+    @pytest.mark.parametrize(
+        'group_channels_by, group_channels_threshold_ms',
+        [(None, None), (GroupChannelsType.SENTENCE, 5000), (GroupChannelsType.WORD, 2000)]
+    )
+    def test_get_transcript_text_as_stream(
+        self,
+        mock_session,
+        make_mock_response,
+        group_channels_by,
+        group_channels_threshold_ms
+    ):
         data = 'Test'
         client = RevAiAPIClient(TOKEN)
         expected_headers = {'Accept': 'text/plain'}
         expected_headers.update(client.default_headers)
         response = make_mock_response(url=URL, text=data)
         mock_session.request.return_value = response
+        expected_url = URL
+        if group_channels_by and group_channels_threshold_ms:
+            expected_url += f"?group_channels_by={group_channels_by.value}&group_channels_threshold_ms={group_channels_threshold_ms}"
 
-        res = client.get_transcript_text_as_stream(JOB_ID)
+        res = client.get_transcript_text_as_stream(JOB_ID, group_channels_by=group_channels_by, group_channels_threshold_ms=group_channels_threshold_ms)
 
         assert res.content == data
         mock_session.request.assert_called_once_with("GET",
-                                                     URL,
+                                                     expected_url,
                                                      headers=expected_headers,
                                                      stream=True)
 
@@ -61,7 +88,17 @@ def test_get_transcript_text_as_stream_with_no_job_id(self, id, mock_session):
         with pytest.raises(ValueError, match='id_ must be provided'):
             RevAiAPIClient(TOKEN).get_transcript_text_as_stream(id)
 
-    def test_get_transcript_json(self, mock_session, make_mock_response):
+    @pytest.mark.parametrize(
+        'group_channels_by, group_channels_threshold_ms',
+        [(None, None), (GroupChannelsType.SENTENCE, 5000), (GroupChannelsType.WORD, 2000)]
+    )
+    def test_get_transcript_json(
+        self,
+        mock_session,
+        make_mock_response,
+        group_channels_by,
+        group_channels_threshold_ms
+    ):
         data = {
             'monologues': [{
                 'speaker': 1,
@@ -80,19 +117,32 @@ def test_get_transcript_json(self, mock_session, make_mock_response):
         expected_headers.update(client.default_headers)
         response = make_mock_response(url=URL, json_data=data)
         mock_session.request.return_value = response
+        expected_url = URL
+        if group_channels_by and group_channels_threshold_ms:
+            expected_url += f"?group_channels_by={group_channels_by.value}&group_channels_threshold_ms={group_channels_threshold_ms}"
 
-        res = client.get_transcript_json(JOB_ID)
+        res = client.get_transcript_json(JOB_ID, group_channels_by=group_channels_by, group_channels_threshold_ms=group_channels_threshold_ms)
 
         assert res == expected
         mock_session.request.assert_called_once_with(
-            "GET", URL, headers=expected_headers)
+            "GET", expected_url, headers=expected_headers)
 
     @pytest.mark.parametrize('id', [None, ''])
     def test_get_transcript_json_with_no_job_id(self, id, mock_session):
         with pytest.raises(ValueError, match='id_ must be provided'):
             RevAiAPIClient(TOKEN).get_transcript_json(id)
 
-    def test_get_transcript_json_as_stream(self, mock_session, make_mock_response):
+    @pytest.mark.parametrize(
+        'group_channels_by, group_channels_threshold_ms',
+        [(None, None), (GroupChannelsType.SENTENCE, 5000), (GroupChannelsType.WORD, 2000)]
+    )
+    def test_get_transcript_json_as_stream(
+        self,
+        mock_session,
+        make_mock_response,
+        group_channels_by,
+        group_channels_threshold_ms
+    ):
         data = {
             'monologues': [{
                 'speaker': 1,
@@ -111,19 +161,32 @@ def test_get_transcript_json_as_stream(self, mock_session, make_mock_response):
         expected_headers.update(client.default_headers)
         response = make_mock_response(url=URL, json_data=data)
         mock_session.request.return_value = response
+        expected_url = URL
+        if group_channels_by and group_channels_threshold_ms:
+            expected_url += f"?group_channels_by={group_channels_by.value}&group_channels_threshold_ms={group_channels_threshold_ms}"
 
-        res = client.get_transcript_json_as_stream(JOB_ID)
+        res = client.get_transcript_json_as_stream(JOB_ID, group_channels_by=group_channels_by, group_channels_threshold_ms=group_channels_threshold_ms)
 
         assert json.loads(res.content.decode('utf-8').replace("\'", "\"")) == expected
         mock_session.request.assert_called_once_with(
-            "GET", URL, headers=expected_headers, stream=True)
+            "GET", expected_url, headers=expected_headers, stream=True)
 
     @pytest.mark.parametrize('id', [None, ''])
     def test_get_transcript_json_as_stream_with_no_job_id(self, id, mock_session):
         with pytest.raises(ValueError, match='id_ must be provided'):
             RevAiAPIClient(TOKEN).get_transcript_json_as_stream(id)
 
-    def test_get_transcript_object_with_success(self, mock_session, make_mock_response):
+    @pytest.mark.parametrize(
+        'group_channels_by, group_channels_threshold_ms',
+        [(None, None), (GroupChannelsType.SENTENCE, 5000), (GroupChannelsType.WORD, 2000)]
+    )
+    def test_get_transcript_object_with_success(
+        self,
+        mock_session,
+        make_mock_response,
+        group_channels_by,
+        group_channels_threshold_ms
+    ):
         data = {
             'monologues': [{
                 'speaker': 1,
@@ -142,12 +205,15 @@ def test_get_transcript_object_with_success(self, mock_session, make_mock_respon
         expected_headers.update(client.default_headers)
         response = make_mock_response(url=URL, json_data=data)
         mock_session.request.return_value = response
+        expected_url = URL
+        if group_channels_by and group_channels_threshold_ms:
+            expected_url += f"?group_channels_by={group_channels_by.value}&group_channels_threshold_ms={group_channels_threshold_ms}"
 
-        res = client.get_transcript_object(JOB_ID)
+        res = client.get_transcript_object(JOB_ID, group_channels_by=group_channels_by, group_channels_threshold_ms=group_channels_threshold_ms)
 
         assert res == expected
         mock_session.request.assert_called_once_with(
-            "GET", URL, headers=expected_headers)
+            "GET", expected_url, headers=expected_headers)
 
     @pytest.mark.parametrize('id', [None, ''])
     def test_get_transcript_object_with_no_job_id(self, id, mock_session):