Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Commit ec3002a

Browse files
authored
Composer integration for %%bq pipeline (#682)
* Composer integration for %%bq pipeline * Addressing code-review feedback
1 parent 60d015b commit ec3002a

File tree

9 files changed

+347
-3
lines changed

9 files changed

+347
-3
lines changed

google/datalab/bigquery/commands/_bigquery.py

+16
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,12 @@ def _create_pipeline_subparser(parser):
899899
help='The Google Cloud Storage bucket for the Airflow dags.')
900900
pipeline_parser.add_argument('-f', '--gcs_dag_file_path', type=str,
901901
help='The file path suffix for the Airflow dags.')
902+
pipeline_parser.add_argument('-e', '--environment', type=str,
903+
help='The name of the Google Cloud Composer environment.')
904+
pipeline_parser.add_argument('-l', '--location', type=str,
905+
help='The location of the Google Cloud Composer environment. '
906+
'Refer https://cloud.google.com/about/locations/ for further '
907+
'details.')
902908
pipeline_parser.add_argument('-g', '--debug', type=str,
903909
help='Debug output with the airflow spec.')
904910
return pipeline_parser
@@ -937,6 +943,16 @@ def _pipeline_cell(args, cell_body):
937943
except AttributeError:
938944
return "Perhaps you're missing: import google.datalab.contrib.pipeline.airflow"
939945

946+
location = args.get('location')
947+
environment = args.get('environment')
948+
949+
if location and environment:
950+
try:
951+
composer = google.datalab.contrib.pipeline.composer.Composer(location, environment)
952+
composer.deploy(name, airflow_spec)
953+
except AttributeError:
954+
return "Perhaps you're missing: import google.datalab.contrib.pipeline.composer"
955+
940956
if args.get('debug'):
941957
error_message += '\n\n' + airflow_spec
942958

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright 2018 Google Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4+
# in compliance with the License. You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software distributed under the License
9+
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10+
# or implied. See the License for the specific language governing permissions and limitations under
11+
# the License.
12+
from ._composer import Composer # noqa
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2018 Google Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4+
# in compliance with the License. You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software distributed under the License
9+
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10+
# or implied. See the License for the specific language governing permissions and limitations under
11+
# the License.
12+
13+
"""Implements Composer HTTP API wrapper."""
14+
import google.datalab.utils
15+
16+
17+
class Api(object):
  """A helper class to issue Composer HTTP requests."""

  _ENDPOINT = 'https://composer.googleapis.com/v1alpha1'
  _ENVIRONMENTS_PATH_FORMAT = '/projects/%s/locations/%s/environments/%s'

  @staticmethod
  def get_environment_details(zone, environment):
    """ Issues a request to Composer to get the environment details.

    Args:
      zone: GCP zone of the composer environment
      environment: name of the Composer environment
    Returns:
      A parsed result object.
    Raises:
      Exception if there is an error performing the operation.
    """
    context = google.datalab.Context.default()
    # Build the fully-qualified environments resource path for this project.
    resource_path = Api._ENVIRONMENTS_PATH_FORMAT % (context.project_id, zone, environment)
    return google.datalab.utils.Http.request(Api._ENDPOINT + resource_path,
                                             credentials=context.credentials)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2018 Google Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4+
# in compliance with the License. You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software distributed under the License
9+
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10+
# or implied. See the License for the specific language governing permissions and limitations under
11+
# the License.
12+
13+
import google.datalab.storage as storage
14+
from google.datalab.contrib.pipeline.composer._api import Api
15+
import re
16+
17+
18+
class Composer(object):
  """ Represents a Composer object that encapsulates a set of functionality relating to the
  Cloud Composer service.

  This object can be used to generate the python airflow spec.
  """

  # Matches well-formed Google Cloud Storage locations ('gs://...').
  gcs_file_regexp = re.compile('gs://.*')

  def __init__(self, zone, environment):
    """ Initializes an instance of a Composer object.

    Args:
      zone: Zone in which Composer environment has been created.
      environment: Name of the Composer environment.
    """
    self._zone = zone
    self._environment = environment
    self._gcs_dag_location = None  # lazily fetched/cached by the gcs_dag_location property

  def deploy(self, name, dag_string):
    """ Uploads the airflow spec for a pipeline into the environment's GCS dag folder.

    Args:
      name: Name of the pipeline; used as the base name of the uploaded .py dag file.
      dag_string: The python airflow spec to write.
    """
    # gcs_dag_location always looks like 'gs://bucket/optional/path/' (the property
    # validates the scheme and appends a trailing '/'), so splitting on '/' with
    # maxsplit=3 yields ['gs:', '', bucket, rest-of-path]; we keep the last two.
    bucket_name, file_path = self.gcs_dag_location.split('/', 3)[2:]
    file_name = '{0}{1}.py'.format(file_path, name)

    bucket = storage.Bucket(bucket_name)
    file_object = bucket.object(file_name)
    file_object.write_stream(dag_string, 'text/plain')

  @property
  def gcs_dag_location(self):
    """ The environment's dag folder as a 'gs://...' location ending in '/'.

    Fetched from the Composer API on first access and cached thereafter.

    Raises:
      ValueError if the environment does not report a usable dag location.
    """
    if not self._gcs_dag_location:
      environment_details = Api.get_environment_details(self._zone, self._environment)
      self._gcs_dag_location = Composer._extract_gcs_dag_location(environment_details,
                                                                  self._environment)
    return self._gcs_dag_location

  @staticmethod
  def _extract_gcs_dag_location(environment_details, environment):
    """ Extracts and validates config.gcsDagLocation from an environment-details response.

    Args:
      environment_details: Parsed response object from the Composer environments API.
      environment: Name of the Composer environment (used only in error messages).
    Returns:
      The gcs dag location, normalized to end with '/'.
    Raises:
      ValueError if the location is absent or not a 'gs://...' path.
    """
    if ('config' not in environment_details or
       'gcsDagLocation' not in environment_details.get('config')):
      raise ValueError('Dag location unavailable from Composer environment {0}'.format(
          environment))
    gcs_dag_location = environment_details['config']['gcsDagLocation']

    if gcs_dag_location is None or not Composer.gcs_file_regexp.match(gcs_dag_location):
      raise ValueError(
          'Dag location {0} from Composer environment {1} is in incorrect format'.format(
              gcs_dag_location, environment))

    # Normalize with a trailing '/' so deploy() can treat it as a folder prefix.
    # ('is False' comparison replaced with idiomatic 'not'.)
    if not gcs_dag_location.endswith('/'):
      gcs_dag_location = gcs_dag_location + '/'
    return gcs_dag_location

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
'google.datalab.contrib.mlworkbench.commands',
3434
'google.datalab.contrib.pipeline',
3535
'google.datalab.contrib.pipeline.airflow',
36+
'google.datalab.contrib.pipeline.composer',
3637
'google.datalab.contrib.pipeline.commands',
3738
'google.datalab.data',
3839
'google.datalab.kernel',

tests/bigquery/pipeline_tests.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -577,21 +577,25 @@ def compare_parameters(self, actual_parameters, user_parameters):
577577
for item in user_parameters}
578578
self.assertDictEqual(actual_paramaters_dict, user_parameters_dict)
579579

580+
@mock.patch('google.datalab.contrib.pipeline.composer._api.Api.get_environment_details')
580581
@mock.patch('google.datalab.Context.default')
581582
@mock.patch('google.datalab.utils.commands.notebook_environment')
582583
@mock.patch('google.datalab.utils.commands.get_notebook_item')
583584
@mock.patch('google.datalab.bigquery.Table.exists')
584585
@mock.patch('google.datalab.bigquery.commands._bigquery._get_table')
585586
@mock.patch('google.datalab.storage.Bucket')
586587
def test_pipeline_cell_golden(self, mock_bucket_class, mock_get_table, mock_table_exists,
587-
mock_notebook_item, mock_environment, mock_default_context):
588+
mock_notebook_item, mock_environment, mock_default_context,
589+
mock_composer_env):
588590
import google.datalab.contrib.pipeline.airflow
589591
table = google.datalab.bigquery.Table('project.test.table')
590592
mock_get_table.return_value = table
591593
mock_table_exists.return_value = True
592594
context = TestCases._create_context()
593595
mock_default_context.return_value = context
594-
596+
mock_composer_env.return_value = {
597+
'config': {'gcsDagLocation': 'gs://foo_bucket/dags'}
598+
}
595599
env = {
596600
'endpoint': 'Interact2',
597601
'job_id': '1234',
@@ -720,6 +724,6 @@ def test_pipeline_cell_golden(self, mock_bucket_class, mock_get_table, mock_tabl
720724
name, cell_body_dict)
721725

722726
mock_bucket_class.assert_called_with('foo_bucket')
723-
mock_bucket_class.return_value.object.assert_called_with('foo_file_path/bq_pipeline_test.py')
727+
mock_bucket_class.return_value.object.assert_called_with('dags/bq_pipeline_test.py')
724728
mock_bucket_class.return_value.object.return_value.write_stream.assert_called_with(
725729
expected_airflow_spec, 'text/plain')

tests/main.py

+4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
import mlworkbench_magic.ml_tests
5757
import mlworkbench_magic.shell_process_tests
5858
import pipeline.airflow_tests
59+
import pipeline.composer_tests
60+
import pipeline.composer_api_tests
5961
import pipeline.pipeline_tests
6062
import stackdriver.commands.monitoring_tests
6163
import stackdriver.monitoring.group_tests
@@ -104,6 +106,8 @@
104106
ml.metrics_tests,
105107
ml.summary_tests,
106108
mlworkbench_magic.ml_tests,
109+
pipeline.composer_api_tests,
110+
pipeline.composer_tests,
107111
pipeline.airflow_tests,
108112
pipeline.pipeline_tests,
109113
stackdriver.commands.monitoring_tests,

tests/pipeline/composer_api_tests.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2018 Google Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4+
# in compliance with the License. You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software distributed under the License
9+
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10+
# or implied. See the License for the specific language governing permissions and limitations under
11+
# the License.
12+
13+
import unittest
14+
import mock
15+
16+
import google.auth
17+
import google.datalab.utils
18+
from google.datalab.contrib.pipeline.composer._api import Api
19+
20+
21+
class TestCases(unittest.TestCase):
22+
23+
TEST_PROJECT_ID = 'test_project'
24+
25+
def validate(self, mock_http_request, expected_url, expected_args=None, expected_data=None,
26+
expected_headers=None, expected_method=None):
27+
url = mock_http_request.call_args[0][0]
28+
kwargs = mock_http_request.call_args[1]
29+
self.assertEquals(expected_url, url)
30+
if expected_args is not None:
31+
self.assertEquals(expected_args, kwargs['args'])
32+
else:
33+
self.assertNotIn('args', kwargs)
34+
if expected_data is not None:
35+
self.assertEquals(expected_data, kwargs['data'])
36+
else:
37+
self.assertNotIn('data', kwargs)
38+
if expected_headers is not None:
39+
self.assertEquals(expected_headers, kwargs['headers'])
40+
else:
41+
self.assertNotIn('headers', kwargs)
42+
if expected_method is not None:
43+
self.assertEquals(expected_method, kwargs['method'])
44+
else:
45+
self.assertNotIn('method', kwargs)
46+
47+
@mock.patch('google.datalab.Context.default')
48+
@mock.patch('google.datalab.utils.Http.request')
49+
def test_environment_details_get(self, mock_http_request, mock_context_default):
50+
mock_context_default.return_value = TestCases._create_context()
51+
Api.get_environment_details('ZONE', 'ENVIRONMENT')
52+
self.validate(mock_http_request,
53+
'https://composer.googleapis.com/v1alpha1/projects/test_project/locations/ZONE/'
54+
'environments/ENVIRONMENT')
55+
56+
@staticmethod
57+
def _create_context():
58+
project_id = TestCases.TEST_PROJECT_ID
59+
creds = mock.Mock(spec=google.auth.credentials.Credentials)
60+
return google.datalab.Context(project_id, creds)

0 commit comments

Comments
 (0)