diff --git a/README.rst b/README.rst index b914bc1dee1d..7bfb8780b2bc 100644 --- a/README.rst +++ b/README.rst @@ -42,6 +42,7 @@ Cloud Platform services: - `Google Cloud Vision`_ (`Vision README`_) - `Google Cloud Bigtable - HappyBase`_ (`HappyBase README`_) - `Google Cloud Runtime Configuration`_ (`Runtime Config README`_) +- `Cloud Spanner`_ (`Cloud Spanner README`_) **Alpha** indicates that the client library for a particular service is still a work-in-progress and is more likely to get backwards-incompatible @@ -79,6 +80,8 @@ updates. See `versioning`_ for more details. .. _HappyBase README: https://github.com/GoogleCloudPlatform/google-cloud-python-happybase .. _Google Cloud Runtime Configuration: https://cloud.google.com/deployment-manager/runtime-configurator/ .. _Runtime Config README: https://github.com/GoogleCloudPlatform/google-cloud-python/tree/master/runtimeconfig +.. _Cloud Spanner: https://cloud.google.com/spanner/ +.. _Cloud Spanner README: https://github.com/GoogleCloudPlatform/google-cloud-python/tree/master/spanner .. _versioning: https://github.com/GoogleCloudPlatform/google-cloud-python/blob/master/CONTRIBUTING.rst#versioning If you need support for other Google APIs, check out the diff --git a/bigtable/google/cloud/bigtable/client.py b/bigtable/google/cloud/bigtable/client.py index 38a103c7c005..06b35c6d9e94 100644 --- a/bigtable/google/cloud/bigtable/client.py +++ b/bigtable/google/cloud/bigtable/client.py @@ -18,12 +18,14 @@ In the hierarchy of API concepts -* a :class:`Client` owns an :class:`.Instance` -* an :class:`.Instance` owns a :class:`~google.cloud.bigtable.table.Table` +* a :class:`~google.cloud.bigtable.client.Client` owns an + :class:`~google.cloud.bigtable.instance.Instance` +* an :class:`~google.cloud.bigtable.instance.Instance` owns a + :class:`~google.cloud.bigtable.table.Table` * a :class:`~google.cloud.bigtable.table.Table` owns a :class:`~.column_family.ColumnFamily` -* a :class:`~google.cloud.bigtable.table.Table` owns a :class:`~.row.Row` - (and all the cells in the row) +* a :class:`~google.cloud.bigtable.table.Table` owns a + :class:`~google.cloud.bigtable.row.Row` (and all the cells in the row) """ @@ -342,7 +344,7 @@ def instance(self, instance_id, location=_EXISTING_INSTANCE_LOCATION_ID, :param serve_nodes: (Optional) The number of nodes in the instance's cluster; used to set up the instance's cluster. - :rtype: :class:`.Instance` + :rtype: :class:`~google.cloud.bigtable.instance.Instance` :returns: an instance owned by this client. """ return Instance(instance_id, self, location, @@ -353,8 +355,9 @@ def list_instances(self): :rtype: tuple :returns: A pair of results, the first is a list of - :class:`.Instance` objects returned and the second is a - list of strings (the failed locations in the request). + :class:`~google.cloud.bigtable.instance.Instance` objects + returned and the second is a list of strings (the failed + locations in the request). """ request_pb = bigtable_instance_admin_pb2.ListInstancesRequest( parent=self.project_name) diff --git a/bigtable/google/cloud/bigtable/cluster.py b/bigtable/google/cloud/bigtable/cluster.py index c2418576dde9..80b9068958db 100644 --- a/bigtable/google/cloud/bigtable/cluster.py +++ b/bigtable/google/cloud/bigtable/cluster.py @@ -72,7 +72,7 @@ class Cluster(object): :type cluster_id: str :param cluster_id: The ID of the cluster. - :type instance: :class:`.instance.Instance` + :type instance: :class:`~google.cloud.bigtable.instance.Instance` :param instance: The instance where the cluster resides. 
:type serve_nodes: int
@@ -104,7 +104,7 @@ def from_pb(cls, cluster_pb, instance):
        :type cluster_pb: :class:`instance_pb2.Cluster`
        :param cluster_pb: A cluster protobuf object.
-        :type instance: :class:`.instance.Instance>`
+        :type instance: :class:`~google.cloud.bigtable.instance.Instance`
        :param instance: The instance that owns the cluster.
        :rtype: :class:`Cluster`
diff --git a/bigtable/google/cloud/bigtable/table.py b/bigtable/google/cloud/bigtable/table.py
index f2120ddc5416..3fbd198d6b65 100644
--- a/bigtable/google/cloud/bigtable/table.py
+++ b/bigtable/google/cloud/bigtable/table.py
@@ -49,7 +49,7 @@ class Table(object):
    :type table_id: str
    :param table_id: The ID of the table.
-    :type instance: :class:`Instance <.instance.Instance>`
+    :type instance: :class:`~google.cloud.bigtable.instance.Instance`
    :param instance: The instance that owns the table.
    """
diff --git a/datastore/google/cloud/datastore/client.py b/datastore/google/cloud/datastore/client.py
index b76a8cce7fc1..87ab8f6ee0c6 100644
--- a/datastore/google/cloud/datastore/client.py
+++ b/datastore/google/cloud/datastore/client.py
@@ -253,7 +253,8 @@ def get(self, key, missing=None, deferred=None, transaction=None):
        :param deferred: (Optional) If a list is passed, the keys returned
                      by the backend as "deferred" will be copied into it.
-        :type transaction: :class:`~.transaction.Transaction`
+        :type transaction:
+            :class:`~google.cloud.datastore.transaction.Transaction`
        :param transaction: (Optional) Transaction to use for read
                            consistency. If not passed, uses current
                            transaction, if set.
@@ -281,7 +282,8 @@ def get_multi(self, keys, missing=None, deferred=None, transaction=None):
                       by the backend as "deferred" will be copied into it.
                       If the list is not empty, an error will occur.
-        :type transaction: :class:`~.transaction.Transaction`
+        :type transaction:
+            :class:`~google.cloud.datastore.transaction.Transaction`
        :param transaction: (Optional) Transaction to use for read
                            consistency. If not passed, uses current
                            transaction, if set.
diff --git a/docs/index.rst b/docs/index.rst
index 3913c55b7ea1..0b5ab3be38df 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -197,6 +197,32 @@
   runtimeconfig-config
   runtimeconfig-variable
 
+.. toctree::
+   :maxdepth: 0
+   :hidden:
+   :caption: Spanner
+
+   spanner-usage
+   spanner-client-usage
+   spanner-instance-usage
+   spanner-database-usage
+   spanner-session-crud-usage
+   spanner-session-implicit-txn-usage
+   spanner-session-pool-usage
+   spanner-batch-usage
+   spanner-snapshot-usage
+   spanner-transaction-usage
+
+   spanner-client-api
+   spanner-instance-api
+   spanner-database-api
+   spanner-session-api
+   spanner-keyset-api
+   spanner-snapshot-api
+   spanner-batch-api
+   spanner-transaction-api
+   spanner-streamed-api
+
 .. toctree::
   :maxdepth: 0
   :hidden:
diff --git a/docs/spanner-batch-api.rst b/docs/spanner-batch-api.rst
new file mode 100644
index 000000000000..0d0b927e9654
--- /dev/null
+++ b/docs/spanner-batch-api.rst
@@ -0,0 +1,8 @@
+Batch API
+=========
+
+.. automodule:: google.cloud.spanner.batch
+    :members:
+    :show-inheritance:
+
+
diff --git a/docs/spanner-batch-usage.rst b/docs/spanner-batch-usage.rst
new file mode 100644
index 000000000000..3bf4917958d2
--- /dev/null
+++ b/docs/spanner-batch-usage.rst
@@ -0,0 +1,179 @@
+Batching Modifications
+######################
+
+A :class:`~google.cloud.spanner.batch.Batch` represents a set of data
+modification operations to be performed on tables in a dataset.
Use of a
+``Batch`` does not require creating an explicit
+:class:`~google.cloud.spanner.snapshot.Snapshot` or
+:class:`~google.cloud.spanner.transaction.Transaction`. Until
+:meth:`~google.cloud.spanner.batch.Batch.commit` is called on a ``Batch``,
+no changes are propagated to the back-end.
+
+
+Starting a Batch
+----------------
+
+.. code:: python
+
+    batch = session.batch()
+
+
+Inserting records using a Batch
+-------------------------------
+
+:meth:`Batch.insert` adds one or more new records to a table. Fails if
+any of the records already exists.
+
+.. code:: python
+
+    batch.insert(
+        'citizens', columns=['email', 'first_name', 'last_name', 'age'],
+        values=[
+            ['phred@example.com', 'Phred', 'Phlyntstone', 32],
+            ['bharney@example.com', 'Bharney', 'Rhubble', 31],
+        ])
+
+.. note::
+
+    Ensure that data being sent for ``STRING`` columns uses a text string
+    (``str`` in Python 3; ``unicode`` in Python 2).
+
+    Additionally, if you are writing data intended for a ``BYTES`` column, you
+    must base64 encode it.
+
+
+Update records using a Batch
+----------------------------
+
+:meth:`Batch.update` updates one or more existing records in a table. Fails
+if any of the records does not already exist.
+
+.. code:: python
+
+    batch.update(
+        'citizens', columns=['email', 'age'],
+        values=[
+            ['phred@example.com', 33],
+            ['bharney@example.com', 32],
+        ])
+
+.. note::
+
+    Ensure that data being sent for ``STRING`` columns uses a text string
+    (``str`` in Python 3; ``unicode`` in Python 2).
+
+    Additionally, if you are writing data intended for a ``BYTES`` column, you
+    must base64 encode it.
+
+
+Insert or update records using a Batch
+--------------------------------------
+
+:meth:`Batch.insert_or_update` inserts *or* updates one or more records in a
+table. Existing rows have values for the supplied columns overwritten; other
+column values are preserved.
+
+.. code:: python
+
+    batch.insert_or_update(
+        'citizens', columns=['email', 'first_name', 'last_name', 'age'],
+        values=[
+            ['phred@example.com', 'Phred', 'Phlyntstone', 31],
+            ['wylma@example.com', 'Wylma', 'Phlyntstone', 29],
+        ])
+
+.. note::
+
+    Ensure that data being sent for ``STRING`` columns uses a text string
+    (``str`` in Python 3; ``unicode`` in Python 2).
+
+    Additionally, if you are writing data intended for a ``BYTES`` column, you
+    must base64 encode it.
+
+
+Replace records using a Batch
+-----------------------------
+
+:meth:`Batch.replace` inserts *or* updates one or more records in a
+table. Existing rows have values for the supplied columns overwritten; other
+column values are set to null.
+
+.. code:: python
+
+    batch.replace(
+        'citizens', columns=['email', 'first_name', 'last_name', 'age'],
+        values=[
+            ['bharney@example.com', 'Bharney', 'Rhubble', 30],
+            ['bhettye@example.com', 'Bhettye', 'Rhubble', 30],
+        ])
+
+.. note::
+
+    Ensure that data being sent for ``STRING`` columns uses a text string
+    (``str`` in Python 3; ``unicode`` in Python 2).
+
+    Additionally, if you are writing data intended for a ``BYTES`` column, you
+    must base64 encode it.
+
+
+Delete records using a Batch
+----------------------------
+
+:meth:`Batch.delete` removes one or more records from a table. Non-existent
+rows do not cause errors.
+
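+The ``keyset`` argument identifies the rows to remove; the example below
+passes a plain list of keys, as elsewhere in this guide. If you need more
+structure, the package also exports a
+:class:`~google.cloud.spanner.keyset.KeySet` class; the following is a
+minimal sketch of building one, assuming a ``keys`` constructor argument
+that takes one list of key-column values per row (an illustration, not a
+verified signature).
+
+.. code:: python
+
+    from google.cloud.spanner import KeySet
+
+    # Assumed signature: each key is a list with one value per
+    # primary-key column; here the table is keyed by ``email`` alone.
+    to_delete = KeySet(keys=[
+        ['bharney@example.com'],
+        ['nonesuch@example.com'],
+    ])
+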
+.. code:: python
+
+    batch.delete('citizens',
+                 keyset=['bharney@example.com', 'nonesuch@example.com'])
+
+
+Commit changes for a Batch
+--------------------------
+
+After describing the modifications to be made to table data via the
+:meth:`Batch.insert`, :meth:`Batch.update`, :meth:`Batch.insert_or_update`,
+:meth:`Batch.replace`, and :meth:`Batch.delete` methods above, send them to
+the back-end by calling :meth:`Batch.commit`, which makes the ``Commit``
+API call.
+
+.. code:: python
+
+    batch.commit()
+
+
+Use a Batch as a Context Manager
+--------------------------------
+
+Rather than calling :meth:`Batch.commit` manually, you can use the
+:class:`Batch` instance as a context manager, and have :meth:`Batch.commit`
+called automatically when the ``with`` block exits without raising an
+exception.
+
+.. code:: python
+
+    with session.batch() as batch:
+
+        batch.insert(
+            'citizens', columns=['email', 'first_name', 'last_name', 'age'],
+            values=[
+                ['phred@example.com', 'Phred', 'Phlyntstone', 32],
+                ['bharney@example.com', 'Bharney', 'Rhubble', 31],
+            ])
+
+        batch.update(
+            'citizens', columns=['email', 'age'],
+            values=[
+                ['phred@example.com', 33],
+                ['bharney@example.com', 32],
+            ])
+
+        ...
+
+        batch.delete('citizens',
+                     keyset=['bharney@example.com', 'nonesuch@example.com'])
+
+
+Next Step
+---------
+
+Next, learn about :doc:`spanner-snapshot-usage`.
diff --git a/docs/spanner-client-api.rst b/docs/spanner-client-api.rst
new file mode 100644
index 000000000000..0716ee742e3a
--- /dev/null
+++ b/docs/spanner-client-api.rst
@@ -0,0 +1,7 @@
+Spanner Client
+==============
+
+.. automodule:: google.cloud.spanner.client
+    :members:
+    :show-inheritance:
+
diff --git a/docs/spanner-client-usage.rst b/docs/spanner-client-usage.rst
new file mode 100644
index 000000000000..f42782cb7d85
--- /dev/null
+++ b/docs/spanner-client-usage.rst
@@ -0,0 +1,68 @@
+Base for Everything
+===================
+
+To use the API, the :class:`~google.cloud.spanner.client.Client`
+class defines a high-level interface which handles authorization
+and creating other objects:
+
+.. code:: python
+
+    from google.cloud.spanner.client import Client
+    client = Client()
+
+Long-lived Defaults
+-------------------
+
+When creating a :class:`~google.cloud.spanner.client.Client`, the
+``user_agent`` and ``timeout_seconds`` arguments have sensible
+defaults
+(:data:`~google.cloud.spanner.client.DEFAULT_USER_AGENT` and
+:data:`~google.cloud.spanner.client.DEFAULT_TIMEOUT_SECONDS`).
+However, you may override them, and your values will then be used for all
+API requests made with the ``client`` you create.
+
+Configuration
+-------------
+
+- For an overview of authentication in ``google-cloud-python``,
+  see :doc:`google-cloud-auth`.
+
+- In addition to any authentication configuration, you can also set the
+  :envvar:`GCLOUD_PROJECT` environment variable for the Google Cloud Console
+  project you'd like to interact with. If your code is running in Google App
+  Engine or Google Compute Engine the project will be detected automatically.
+  (Setting this environment variable is not required; you may instead pass the
+  ``project`` explicitly when constructing a
+  :class:`~google.cloud.spanner.client.Client`.)
+
+- After configuring your environment, create a
+  :class:`~google.cloud.spanner.client.Client`
+
+  .. code::
+
+     >>> from google.cloud import spanner
+     >>> client = spanner.Client()
+
+  or pass in ``credentials`` and ``project`` explicitly
+
+  .. code::
+
+     >>> from google.cloud import spanner
+     >>> client = spanner.Client(project='my-project', credentials=creds)
+
+.. tip::
+
+    Be sure to use the **Project ID**, not the **Project Number**.
+
+
+Next Step
+---------
+
+After creating a :class:`~google.cloud.spanner.client.Client`, the next
+highest-level object is an :class:`~google.cloud.spanner.instance.Instance`.
+You'll need one before you can interact with databases.
+
+Next, learn about :doc:`spanner-instance-usage`.
+
+.. _Instance Admin: https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.instance.v1
+.. _Database Admin: https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.database.v1
diff --git a/docs/spanner-database-api.rst b/docs/spanner-database-api.rst
new file mode 100644
index 000000000000..1eeed674e7d6
--- /dev/null
+++ b/docs/spanner-database-api.rst
@@ -0,0 +1,8 @@
+Database API
+============
+
+.. automodule:: google.cloud.spanner.database
+    :members:
+    :show-inheritance:
+
+
diff --git a/docs/spanner-database-usage.rst b/docs/spanner-database-usage.rst
new file mode 100644
index 000000000000..3d1a51c6ed9a
--- /dev/null
+++ b/docs/spanner-database-usage.rst
@@ -0,0 +1,124 @@
+Database Admin API
+==================
+
+After creating a :class:`~google.cloud.spanner.instance.Instance`, you can
+interact with individual databases for that instance.
+
+
+List Databases
+--------------
+
+To list all existing databases for an instance, use its
+:meth:`~google.cloud.spanner.instance.Instance.list_databases` method:
+
+.. code:: python
+
+    databases, token = instance.list_databases()
+
+
+Database Factory
+----------------
+
+To create a :class:`~google.cloud.spanner.database.Database` object:
+
+.. code:: python
+
+    database = instance.database(database_id, ddl_statements)
+
+- ``ddl_statements`` is a list of strings, each containing a single DDL
+  statement for the new database.
+
+You can also use :meth:`Instance.database` to create a local wrapper for
+a database that has already been created:
+
+.. code:: python
+
+    database = instance.database(existing_database_id)
+
+
+Create a new Database
+---------------------
+
+After creating the database object, use its
+:meth:`~google.cloud.spanner.database.Database.create` method to
+trigger its creation on the server:
+
+.. code:: python
+
+    operation = database.create()
+
+.. note::
+
+    Creating a database triggers a "long-running operation" and
+    returns an :class:`google.cloud.spanner.database.Operation`
+    object. See :ref:`check-on-current-database-operation` for polling
+    to find out if the operation has completed.
+
+
+Update an existing Database
+---------------------------
+
+After creating the database object, you can apply additional DDL statements
+via its :meth:`~google.cloud.spanner.database.Database.update_ddl` method:
+
+.. code:: python
+
+    operation = database.update_ddl(ddl_statements, operation_id)
+
+- ``ddl_statements`` is a list of strings containing DDL to be applied to
+  the database.
+
+- ``operation_id`` is a string ID for the long-running operation.
+
+.. note::
+
+    Updating a database triggers a "long-running operation" and
+    returns a :class:`google.cloud.spanner.database.Operation`
+    object. See :ref:`check-on-current-database-operation` for polling
+    to find out if the operation has completed.
+
+
+Drop a Database
+---------------
+
+Drop a database using its
+:meth:`~google.cloud.spanner.database.Database.drop` method:
+
+.. code:: python
+
+    database.drop()
+
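+Putting these calls together, the following is a minimal sketch of the full
+database lifecycle using only the methods described on this page. The
+polling loop and one-second sleep are illustrative choices, not API
+requirements, and ``DDL_STATEMENTS`` stands in for whatever list of DDL
+strings you supply.
+
+.. code:: python
+
+    import time
+
+    # Create the database on the server and wait for the long-running
+    # operation to complete (``finished`` is covered in the next section).
+    database = instance.database('example-db',
+                                 ddl_statements=DDL_STATEMENTS)
+    operation = database.create()
+    while not operation.finished():
+        time.sleep(1)  # back off between polls
+
+    # ... interact with the database via sessions ...
+
+    database.drop()
+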
+.. _check-on-current-database-operation:
+
+Check on Current Database Operation
+-----------------------------------
+
+The :meth:`~google.cloud.spanner.database.Database.create` and
+:meth:`~google.cloud.spanner.database.Database.update_ddl` methods of the
+database object trigger long-running operations on the server, and return
+instances of the :class:`~google.cloud.spanner.database.Operation` class.
+
+You can check if a long-running operation has finished
+by using its :meth:`~google.cloud.spanner.database.Operation.finished`
+method:
+
+.. code:: python
+
+    >>> operation = database.create()
+    >>> operation.finished()
+    True
+
+.. note::
+
+    Once an :class:`~google.cloud.spanner.database.Operation` object
+    has returned :data:`True` from its
+    :meth:`~google.cloud.spanner.database.Operation.finished` method, the
+    object should not be re-used. Subsequent calls to
+    :meth:`~google.cloud.spanner.database.Operation.finished`
+    will result in a :exc:`ValueError` being raised.
+
+
+Next Step
+---------
+
+Next, learn about :doc:`spanner-session-crud-usage`.
diff --git a/docs/spanner-instance-api.rst b/docs/spanner-instance-api.rst
new file mode 100644
index 000000000000..181bed686b28
--- /dev/null
+++ b/docs/spanner-instance-api.rst
@@ -0,0 +1,8 @@
+Instance API
+============
+
+.. automodule:: google.cloud.spanner.instance
+    :members:
+    :show-inheritance:
+
+
diff --git a/docs/spanner-instance-usage.rst b/docs/spanner-instance-usage.rst
new file mode 100644
index 000000000000..035f09c38b81
--- /dev/null
+++ b/docs/spanner-instance-usage.rst
@@ -0,0 +1,181 @@
+Instance Admin API
+==================
+
+After creating a :class:`~google.cloud.spanner.client.Client`, you can
+interact with individual instances for a project.
+
+Instance Configurations
+-----------------------
+
+Each instance within a project maps to a named "instance configuration",
+specifying the location and other parameters for a set of instances. These
+configurations are defined by the server, and cannot be changed.
+
+To list all instance configurations available to your project, use the
+:meth:`~google.cloud.spanner.client.Client.list_instance_configurations`
+method of the client:
+
+.. code:: python
+
+    configs, token = client.list_instance_configurations()
+
+
+To fetch a single instance configuration, use the
+:meth:`~google.cloud.spanner.client.Client.get_instance_configuration`
+method of the client:
+
+.. code:: python
+
+    config = client.get_instance_configuration('config-name')
+
+
+List Instances
+--------------
+
+If you want a comprehensive list of all existing instances, use the
+:meth:`~google.cloud.spanner.client.Client.list_instances` method of
+the client:
+
+.. code:: python
+
+    instances, token = client.list_instances()
+
+
+Instance Factory
+----------------
+
+To create a :class:`~google.cloud.spanner.instance.Instance` object:
+
+.. code:: python
+
+    config = configs[0]
+    instance = client.instance(instance_id,
+                               configuration_name=config.name,
+                               node_count=10,
+                               display_name='My Instance')
+
+- ``configuration_name`` is the name of the instance configuration to which the
+  instance will be bound. It must be one of the names configured for your
+  project, discoverable via
+  :meth:`google.cloud.spanner.client.Client.list_instance_configurations`.
+
+- ``node_count`` is a positive integer count of the number of nodes used
+  by the instance. More nodes allow for higher performance, but at a higher
+  billing cost.
+
+- ``display_name`` is optional.
When not provided, ``display_name`` defaults
+  to the ``instance_id`` value.
+
+You can also use :meth:`Client.instance` to create a local wrapper for
+an instance that has already been created:
+
+.. code:: python
+
+    instance = client.instance(existing_instance_id)
+    instance.reload()
+
+
+Create a new Instance
+---------------------
+
+After creating the instance object, use its
+:meth:`~google.cloud.spanner.instance.Instance.create` method to
+trigger its creation on the server:
+
+.. code:: python
+
+    instance.display_name = 'My very own instance'
+    operation = instance.create()
+
+.. note::
+
+    Creating an instance triggers a "long-running operation" and
+    returns an :class:`google.cloud.spanner.instance.Operation`
+    object. See :ref:`check-on-current-instance-operation` for polling
+    to find out if the operation has completed.
+
+
+Refresh metadata for an existing Instance
+-----------------------------------------
+
+After creating the instance object, reload its server-side configuration
+using its :meth:`~google.cloud.spanner.instance.Instance.reload` method:
+
+.. code:: python
+
+    instance.reload()
+
+This will load ``display_name``, ``config_name``, and ``node_count``
+for the existing ``instance`` object from the back-end.
+
+
+Update an existing Instance
+---------------------------
+
+After creating the instance object, you can update its metadata via
+its :meth:`~google.cloud.spanner.instance.Instance.update` method:
+
+.. code:: python
+
+    instance.display_name = 'New display_name'
+    operation = instance.update()
+
+.. note::
+
+    Updating an instance triggers a "long-running operation" and
+    returns a :class:`google.cloud.spanner.instance.Operation`
+    object. See :ref:`check-on-current-instance-operation` for polling
+    to find out if the operation has completed.
+
+
+Delete an existing Instance
+---------------------------
+
+Delete an instance using its
+:meth:`~google.cloud.spanner.instance.Instance.delete` method:
+
+.. code:: python
+
+    instance.delete()
+
+
+.. _check-on-current-instance-operation:
+
+Check on Current Instance Operation
+-----------------------------------
+
+The :meth:`~google.cloud.spanner.instance.Instance.create` and
+:meth:`~google.cloud.spanner.instance.Instance.update` methods of the
+instance object trigger long-running operations on the server, and return
+instances of the :class:`~google.cloud.spanner.instance.Operation` class.
+
+You can check if a long-running operation has finished
+by using its :meth:`~google.cloud.spanner.instance.Operation.finished`
+method:
+
+.. code:: python
+
+    >>> operation = instance.create()
+    >>> operation.finished()
+    True
+
+.. note::
+
+    Once an :class:`~google.cloud.spanner.instance.Operation` object
+    has returned :data:`True` from its
+    :meth:`~google.cloud.spanner.instance.Operation.finished` method, the
+    object should not be re-used. Subsequent calls to
+    :meth:`~google.cloud.spanner.instance.Operation.finished`
+    will result in a :exc:`ValueError` being raised.
+
+Next Step
+---------
+
+Now we go down the hierarchy from
+:class:`~google.cloud.spanner.instance.Instance` to a
+:class:`~google.cloud.spanner.database.Database`.
+
+Next, learn about :doc:`spanner-database-usage`.
+
+
+.. _Instance Admin API: https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.instance.v1
diff --git a/docs/spanner-keyset-api.rst b/docs/spanner-keyset-api.rst
new file mode 100644
index 000000000000..3f46c6dc95f6
--- /dev/null
+++ b/docs/spanner-keyset-api.rst
@@ -0,0 +1,8 @@
+Keyset API
+==========
+
+.. automodule:: google.cloud.spanner.keyset
+    :members:
+    :show-inheritance:
+
+
diff --git a/docs/spanner-session-api.rst b/docs/spanner-session-api.rst
new file mode 100644
index 000000000000..b41c3e6b8c20
--- /dev/null
+++ b/docs/spanner-session-api.rst
@@ -0,0 +1,15 @@
+Session API
+===========
+
+.. automodule:: google.cloud.spanner.session
+    :members:
+    :show-inheritance:
+
+
+Session Pools API
+=================
+
+.. automodule:: google.cloud.spanner.pool
+    :members:
+    :show-inheritance:
+
diff --git a/docs/spanner-session-crud-usage.rst b/docs/spanner-session-crud-usage.rst
new file mode 100644
index 000000000000..43e983f787d8
--- /dev/null
+++ b/docs/spanner-session-crud-usage.rst
@@ -0,0 +1,80 @@
+Session Creation / Deletion
+===========================
+
+Outside of the admin APIs, all work with actual table data in a database
+occurs in the context of a session.
+
+
+Session Factory
+---------------
+
+To create a :class:`~google.cloud.spanner.session.Session` object:
+
+.. code:: python
+
+    session = database.session()
+
+
+Create a new Session
+--------------------
+
+After creating the session object, use its
+:meth:`~google.cloud.spanner.session.Session.create` method to
+trigger its creation on the server:
+
+.. code:: python
+
+    session.create()
+
+
+Test for the existence of a Session
+-----------------------------------
+
+After creating the session object, use its
+:meth:`~google.cloud.spanner.session.Session.exists` method to determine
+whether the session still exists on the server:
+
+.. code:: python
+
+    assert session.exists()
+
+
+Delete a Session
+----------------
+
+Once done with the session object, use its
+:meth:`~google.cloud.spanner.session.Session.delete` method to free up
+its resources on the server:
+
+.. code:: python
+
+    session.delete()
+
+
+Using a Session as a Context Manager
+------------------------------------
+
+Rather than calling the Session's
+:meth:`~google.cloud.spanner.session.Session.create` and
+:meth:`~google.cloud.spanner.session.Session.delete` methods directly,
+you can use the session as a Python context manager:
+
+.. code:: python
+
+    with database.session() as session:
+
+        assert session.exists()
+        # perform session operations here
+
+.. note::
+
+    At the beginning of the ``with`` block, the session's
+    :meth:`~google.cloud.spanner.session.Session.create` method is called.
+    At the end of the ``with`` block, the session's
+    :meth:`~google.cloud.spanner.session.Session.delete` method is called.
+
+
+Next Step
+---------
+
+Next, learn about :doc:`spanner-session-implicit-txn-usage`.
diff --git a/docs/spanner-session-implicit-txn-usage.rst b/docs/spanner-session-implicit-txn-usage.rst
new file mode 100644
index 000000000000..e814d2552887
--- /dev/null
+++ b/docs/spanner-session-implicit-txn-usage.rst
@@ -0,0 +1,92 @@
+Implicit Transactions
+#####################
+
+The following operations on a session do not require creating an explicit
+:class:`~google.cloud.spanner.snapshot.Snapshot` or
+:class:`~google.cloud.spanner.transaction.Transaction`.
+
+
+Read Table Data
+---------------
+
+Read data for selected rows from a table in the session's database. Calls
+the ``Read`` API, which returns all rows specified in ``key_set``, or else
+fails if the result set is too large.
+
+.. code:: python
+
+    result = session.read(
+        table='table-name', columns=['first_name', 'last_name', 'age'],
+        key_set=['phred@example.com', 'bharney@example.com'])
+
+    for row in result.rows:
+        print(row)
+
+
+Read Streaming Table Data
+-------------------------
+
+Read data for selected rows from a table in the session's database. Calls
+the ``StreamingRead`` API, which returns partial result sets.
+:meth:`Session.read_streaming` coalesces these partial result sets as its
+result object's rows are iterated.
+
+.. code:: python
+
+    result = session.read_streaming(
+        table='table-name', columns=['first_name', 'last_name', 'age'],
+        key_set=VERY_LONG_LIST_OF_KEYS)
+
+    for row in result.rows:
+        print(row)
+
+.. note::
+
+    If streaming a chunk fails due to a "resumable" error,
+    :meth:`Session.read_streaming` retries the ``StreamingRead`` API request,
+    passing the ``resume_token`` from the last partial result streamed.
+
+
+Execute a SQL Select Statement
+------------------------------
+
+Read data from a query against tables in the session's database. Calls
+the ``ExecuteSql`` API, which returns all rows matching the query, or else
+fails if the result set is too large.
+
+.. code:: python
+
+    QUERY = (
+        'SELECT e.first_name, e.last_name, p.telephone '
+        'FROM employees as e, phones as p '
+        'WHERE p.employee_id = e.employee_id')
+    result = session.execute_sql(QUERY)
+
+    for row in result.rows:
+        print(row)
+
+
+Execute a Streaming SQL Select Statement
+----------------------------------------
+
+Read data from a query against tables in the session's database. Calls
+the ``ExecuteStreamingSql`` API, which returns partial result sets.
+:meth:`Session.execute_streaming_sql` coalesces these partial result sets as
+its result object's rows are iterated.
+
+.. code:: python
+
+    QUERY = (
+        'SELECT e.first_name, e.last_name, p.telephone '
+        'FROM employees as e, phones as p '
+        'WHERE p.employee_id = e.employee_id')
+    result = session.execute_streaming_sql(QUERY)
+
+    for row in result.rows:
+        print(row)
+
+
+Next Step
+---------
+
+Next, learn about :doc:`spanner-batch-usage`.
diff --git a/docs/spanner-session-pool-usage.rst b/docs/spanner-session-pool-usage.rst
new file mode 100644
index 000000000000..883bb6d720b2
--- /dev/null
+++ b/docs/spanner-session-pool-usage.rst
@@ -0,0 +1,198 @@
+Session Pools
+#############
+
+In order to minimize the latency of session creation, you can set up a
+session pool on your database. For instance, to use a pool which does *not*
+block when exhausted, and which pings each session at checkout:
+
+Configuring a session pool for a database
+-----------------------------------------
+
+.. code-block:: python
+
+    from google.cloud.spanner import Client
+    from google.cloud.spanner import FixedSizePool
+    client = Client()
+    instance = client.instance(INSTANCE_NAME)
+    database = instance.database(DATABASE_NAME)
+    pool = FixedSizePool(database, size=10, default_timeout=5)
+
+Note that creating the pool presumes that its database already exists, as
+it may need to pre-create sessions (rather than creating them on demand).
+
+You can supply your own pool implementation, which must satisfy the
+contract laid out in
+:class:`~google.cloud.spanner.session.AbstractSessionPool`:
+
+.. code-block:: python
+
+    from google.cloud.spanner import AbstractSessionPool
+
+    class MyCustomPool(AbstractSessionPool):
+
+        def __init__(self, database, custom_param):
+            super(MyCustomPool, self).__init__(database)
+            self.custom_param = custom_param
+
+        def get(self, read_only=False):
+            ...
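+
+        # (Sketch only: ``get`` should return a ``Session`` checked out
+        # of the pool, and ``put`` below should return one to it. See
+        # ``AbstractSessionPool`` for the authoritative contract.)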
+
+        def put(self, session, discard_if_full=True):
+            ...
+
+    database = instance.database(DATABASE_NAME)
+    pool = MyCustomPool(database, custom_param=42)
+
+
+Checking out sessions from the pool
+-----------------------------------
+
+No matter what kind of pool you create for the database, you can check out
+a session from the pool, rather than creating it manually. The
+:meth:`~google.cloud.spanner.session.AbstractSessionPool.session` method
+returns an object designed to be used as a context manager, checking the
+session out from the pool and returning it automatically:
+
+.. code-block:: python
+
+    with pool.session() as session:
+
+        snapshot = session.snapshot()
+
+        result = snapshot.read(
+            table='table-name', columns=['first_name', 'last_name', 'age'],
+            key_set=['phred@example.com', 'bharney@example.com'])
+
+        for row in result.rows:
+            print(row)
+
+Some pool implementations may allow additional keyword arguments when checked
+out:
+
+.. code-block:: python
+
+    with pool.session(read_only=True) as session:
+
+        snapshot = session.snapshot()
+
+        result = snapshot.read(
+            table='table-name', columns=['first_name', 'last_name', 'age'],
+            key_set=['phred@example.com', 'bharney@example.com'])
+
+        for row in result.rows:
+            print(row)
+
+
+Lowering latency for read / query operations
+--------------------------------------------
+
+Some applications may need to minimize latency for read operations, including
+particularly the overhead of making an API request to create or refresh a
+session. :class:`~google.cloud.spanner.pool.PingingPool` is designed for such
+applications, which need to configure a background thread to do the work of
+keeping the sessions fresh.
+
+Create an instance of :class:`~google.cloud.spanner.pool.PingingPool`:
+
+.. code-block:: python
+
+    from google.cloud.spanner import Client
+    from google.cloud.spanner import PingingPool
+
+    client = Client()
+    instance = client.instance(INSTANCE_NAME)
+    pool = PingingPool(size=10, default_timeout=5, ping_interval=300)
+    database = instance.database(DATABASE_NAME, pool=pool)
+
+Set up a background thread to ping the pool's sessions, keeping them
+from becoming stale:
+
+.. code-block:: python
+
+    import threading
+
+    background = threading.Thread(target=pool.ping, name='ping-pool')
+    background.daemon = True
+    background.start()
+
+``database.execute_sql()`` is a shortcut, which checks out a session, creates a
+snapshot, and uses the snapshot to execute a query:
+
+.. code-block:: python
+
+    QUERY = """\
+    SELECT first_name, last_name, age FROM table-name
+    WHERE email IN ("phred@example.com", "bharney@example.com")
+    """
+    result = database.execute_sql(QUERY)
+
+    for row in result:
+        do_something_with(row)
+
+
+Lowering latency for mixed read-write operations
+------------------------------------------------
+
+Some applications may need to minimize latency for read-write operations,
+including particularly the overhead of making an API request to create or
+refresh a session or to begin a session's transaction.
+:class:`~google.cloud.spanner.pool.TransactionPingingPool` is designed for
+such applications, which need to configure a background thread to do the work
+of keeping the sessions fresh and starting their transactions after use.
+
+Create an instance of
+:class:`~google.cloud.spanner.pool.TransactionPingingPool`:
+
+.. code-block:: python
+
+    from google.cloud.spanner import Client
+    from google.cloud.spanner import TransactionPingingPool
+
+    client = Client()
+    instance = client.instance(INSTANCE_NAME)
+    pool = TransactionPingingPool(size=10, default_timeout=5, ping_interval=300)
+    database = instance.database(DATABASE_NAME, pool=pool)
+
+Set up a background thread to ping the pool's sessions, keeping them
+from becoming stale, and ensuring that each session has a new transaction
+started before it is used:
+
+.. code-block:: python
+
+    import threading
+
+    background = threading.Thread(target=pool.ping, name='ping-pool')
+    background.daemon = True
+    background.start()
+
+``database.run_in_transaction()`` is a shortcut: it checks out a session
+and uses it to perform a set of read and write operations inside the context
+of a transaction, retrying if aborted. The application must supply a callback
+function, which is passed a transaction (plus any additional parameters
+passed), and does its work using that transaction.
+
+.. code-block:: python
+
+    import datetime
+
+    QUERY = """\
+    SELECT employee_id, sum(hours) FROM daily_hours
+    WHERE start_date >= "%s" AND end_date < "%s"
+    GROUP BY employee_id ORDER BY employee_id"""
+
+    def unit_of_work(transaction, month_start, month_end):
+        """Compute rolled-up hours for a given month."""
+        query = QUERY % (month_start.isoformat(),
+                         (month_end + datetime.timedelta(1)).isoformat())
+        row_iter = transaction.execute_sql(query)
+
+        for emp_id, hours, pay in _compute_pay(row_iter):
+            transaction.insert_or_update(
+                table='monthly_hours',
+                columns=['employee_id', 'month', 'hours', 'pay'],
+                values=[[emp_id, month_start, hours, pay]])
+
+    database.run_in_transaction(
+        unit_of_work,
+        month_start=datetime.date(2016, 12, 1),
+        month_end=datetime.date(2016, 12, 31))
diff --git a/docs/spanner-snapshot-api.rst b/docs/spanner-snapshot-api.rst
new file mode 100644
index 000000000000..b28d55a19feb
--- /dev/null
+++ b/docs/spanner-snapshot-api.rst
@@ -0,0 +1,8 @@
+Snapshot API
+============
+
+.. automodule:: google.cloud.spanner.snapshot
+    :members:
+    :show-inheritance:
+
+
diff --git a/docs/spanner-snapshot-usage.rst b/docs/spanner-snapshot-usage.rst
new file mode 100644
index 000000000000..888bbd027b43
--- /dev/null
+++ b/docs/spanner-snapshot-usage.rst
@@ -0,0 +1,122 @@
+Read-only Transactions via Snapshots
+####################################
+
+A :class:`~google.cloud.spanner.snapshot.Snapshot` represents a read-only
+transaction: when multiple read operations are performed via a Snapshot,
+the results are consistent as of a particular point in time.
+
+
+Beginning a Snapshot
+--------------------
+
+To begin using a snapshot with the default "bound" (which is "strong"),
+meaning all reads are performed at a timestamp where all previously-committed
+transactions are visible:
+
+.. code:: python
+
+    snapshot = session.snapshot()
+
+You can also specify a weaker bound, which can either be to perform all
+reads as of a given timestamp:
+
+.. code:: python
+
+    import datetime
+    from pytz import UTC
+    TIMESTAMP = datetime.datetime.utcnow().replace(tzinfo=UTC)
+    snapshot = session.snapshot(read_timestamp=TIMESTAMP)
+
+or as of a given duration in the past:
+
+.. code:: python
+
+    import datetime
+    DURATION = datetime.timedelta(seconds=5)
+    snapshot = session.snapshot(exact_staleness=DURATION)
+
+
+Read Table Data
+---------------
+
+Read data for selected rows from a table in the session's database.
Calls
+the ``Read`` API, which returns all rows specified in ``key_set``, or else
+fails if the result set is too large.
+
+.. code:: python
+
+    result = snapshot.read(
+        table='table-name', columns=['first_name', 'last_name', 'age'],
+        key_set=['phred@example.com', 'bharney@example.com'])
+
+    for row in result.rows:
+        print(row)
+
+
+Read Streaming Table Data
+-------------------------
+
+Read data for selected rows from a table in the session's database. Calls
+the ``StreamingRead`` API, which returns partial result sets.
+:meth:`Snapshot.read_streaming` coalesces these partial result sets as its
+result object's rows are iterated.
+
+.. code:: python
+
+    result = snapshot.read_streaming(
+        table='table-name', columns=['first_name', 'last_name', 'age'],
+        key_set=VERY_LONG_LIST_OF_KEYS)
+
+    for row in result.rows:
+        print(row)
+
+.. note::
+
+    If streaming a chunk fails due to a "resumable" error,
+    :meth:`Snapshot.read_streaming` retries the ``StreamingRead`` API request,
+    passing the ``resume_token`` from the last partial result streamed.
+
+
+Execute a SQL Select Statement
+------------------------------
+
+Read data from a query against tables in the session's database. Calls
+the ``ExecuteSql`` API, which returns all rows matching the query, or else
+fails if the result set is too large.
+
+.. code:: python
+
+    QUERY = (
+        'SELECT e.first_name, e.last_name, p.telephone '
+        'FROM employees as e, phones as p '
+        'WHERE p.employee_id = e.employee_id')
+    result = snapshot.execute_sql(QUERY)
+
+    for row in result.rows:
+        print(row)
+
+
+Execute a Streaming SQL Select Statement
+----------------------------------------
+
+Read data from a query against tables in the session's database. Calls
+the ``ExecuteStreamingSql`` API, which returns partial result sets.
+:meth:`Snapshot.execute_streaming_sql` coalesces these partial result sets as
+its result object's rows are iterated.
+
+.. code:: python
+
+    QUERY = (
+        'SELECT e.first_name, e.last_name, p.telephone '
+        'FROM employees as e, phones as p '
+        'WHERE p.employee_id = e.employee_id')
+    result = snapshot.execute_streaming_sql(QUERY)
+
+    for row in result.rows:
+        print(row)
+
+
+Next Step
+---------
+
+Next, learn about :doc:`spanner-transaction-usage`.
diff --git a/docs/spanner-streamed-api.rst b/docs/spanner-streamed-api.rst
new file mode 100644
index 000000000000..e17180accf8b
--- /dev/null
+++ b/docs/spanner-streamed-api.rst
@@ -0,0 +1,8 @@
+StreamedResultSet API
+=====================
+
+.. automodule:: google.cloud.spanner.streamed
+    :members:
+    :show-inheritance:
+
+
diff --git a/docs/spanner-transaction-api.rst b/docs/spanner-transaction-api.rst
new file mode 100644
index 000000000000..c16213de54bd
--- /dev/null
+++ b/docs/spanner-transaction-api.rst
@@ -0,0 +1,8 @@
+Transaction API
+===============
+
+.. automodule:: google.cloud.spanner.transaction
+    :members:
+    :show-inheritance:
+
+
diff --git a/docs/spanner-transaction-usage.rst b/docs/spanner-transaction-usage.rst
new file mode 100644
index 000000000000..adb8ffe320a9
--- /dev/null
+++ b/docs/spanner-transaction-usage.rst
@@ -0,0 +1,271 @@
+Read-write Transactions
+#######################
+
+A :class:`~google.cloud.spanner.transaction.Transaction` represents a
+read-write transaction: when the transaction commits, it sends any
+accumulated mutations to the server.
+
+
+Begin a Transaction
+-------------------
+
+To begin using a transaction:
+
+.. code:: python
+
+    transaction = session.transaction()
+
+
+Read Table Data
+---------------
+
+Read data for selected rows from a table in the session's database. Calls
+the ``Read`` API, which returns all rows specified in ``key_set``, or else
+fails if the result set is too large.
+
+.. code:: python
+
+    result = transaction.read(
+        table='table-name', columns=['first_name', 'last_name', 'age'],
+        key_set=['phred@example.com', 'bharney@example.com'])
+
+    for row in result.rows:
+        print(row)
+
+
+Read Streaming Table Data
+-------------------------
+
+Read data for selected rows from a table in the session's database. Calls
+the ``StreamingRead`` API, which returns partial result sets.
+:meth:`Transaction.read_streaming` coalesces these partial result sets as its
+result object's rows are iterated.
+
+.. code:: python
+
+    result = transaction.read_streaming(
+        table='table-name', columns=['first_name', 'last_name', 'age'],
+        key_set=VERY_LONG_LIST_OF_KEYS)
+
+    for row in result.rows:
+        print(row)
+
+.. note::
+
+    If streaming a chunk fails due to a "resumable" error,
+    :meth:`Transaction.read_streaming` retries the ``StreamingRead`` API
+    request, passing the ``resume_token`` from the last partial result
+    streamed.
+
+
+Execute a SQL Select Statement
+------------------------------
+
+Read data from a query against tables in the session's database. Calls
+the ``ExecuteSql`` API, which returns all rows matching the query, or else
+fails if the result set is too large.
+
+.. code:: python
+
+    QUERY = (
+        'SELECT e.first_name, e.last_name, p.telephone '
+        'FROM employees as e, phones as p '
+        'WHERE p.employee_id = e.employee_id')
+    result = transaction.execute_sql(QUERY)
+
+    for row in result.rows:
+        print(row)
+
+
+Execute a Streaming SQL Select Statement
+----------------------------------------
+
+Read data from a query against tables in the session's database. Calls
+the ``ExecuteStreamingSql`` API, which returns partial result sets.
+:meth:`Transaction.execute_streaming_sql` coalesces these partial result sets
+as its result object's rows are iterated.
+
+.. code:: python
+
+    QUERY = (
+        'SELECT e.first_name, e.last_name, p.telephone '
+        'FROM employees as e, phones as p '
+        'WHERE p.employee_id = e.employee_id')
+    result = transaction.execute_streaming_sql(QUERY)
+
+    for row in result.rows:
+        print(row)
+
+
+Insert records using a Transaction
+----------------------------------
+
+:meth:`Transaction.insert` adds one or more new records to a table. Fails if
+any of the records already exists.
+
+.. code:: python
+
+    transaction.insert(
+        'citizens', columns=['email', 'first_name', 'last_name', 'age'],
+        values=[
+            ['phred@example.com', 'Phred', 'Phlyntstone', 32],
+            ['bharney@example.com', 'Bharney', 'Rhubble', 31],
+        ])
+
+.. note::
+
+    Ensure that data being sent for ``STRING`` columns uses a text string
+    (``str`` in Python 3; ``unicode`` in Python 2).
+
+    Additionally, if you are writing data intended for a ``BYTES`` column, you
+    must base64 encode it.
+
+
+Update records using a Transaction
+----------------------------------
+
+:meth:`Transaction.update` updates one or more existing records in a table.
+Fails if any of the records does not already exist.
+
+.. code:: python
+
+    transaction.update(
+        'citizens', columns=['email', 'age'],
+        values=[
+            ['phred@example.com', 33],
+            ['bharney@example.com', 32],
+        ])
+
+.. note::
+
+    Ensure that data being sent for ``STRING`` columns uses a text string
+    (``str`` in Python 3; ``unicode`` in Python 2).
+
+    Additionally, if you are writing data intended for a ``BYTES`` column, you
+    must base64 encode it.
+
+
+Insert or update records using a Transaction
+--------------------------------------------
+
+:meth:`Transaction.insert_or_update` inserts *or* updates one or more records
+in a table. Existing rows have values for the supplied columns overwritten;
+other column values are preserved.
+
+.. code:: python
+
+    transaction.insert_or_update(
+        'citizens', columns=['email', 'first_name', 'last_name', 'age'],
+        values=[
+            ['phred@example.com', 'Phred', 'Phlyntstone', 31],
+            ['wylma@example.com', 'Wylma', 'Phlyntstone', 29],
+        ])
+
+.. note::
+
+    Ensure that data being sent for ``STRING`` columns uses a text string
+    (``str`` in Python 3; ``unicode`` in Python 2).
+
+    Additionally, if you are writing data intended for a ``BYTES`` column, you
+    must base64 encode it.
+
+
+Replace records using a Transaction
+-----------------------------------
+
+:meth:`Transaction.replace` inserts *or* updates one or more records in a
+table. Existing rows have values for the supplied columns overwritten; other
+column values are set to null.
+
+.. code:: python
+
+    transaction.replace(
+        'citizens', columns=['email', 'first_name', 'last_name', 'age'],
+        values=[
+            ['bharney@example.com', 'Bharney', 'Rhubble', 30],
+            ['bhettye@example.com', 'Bhettye', 'Rhubble', 30],
+        ])
+
+.. note::
+
+    Ensure that data being sent for ``STRING`` columns uses a text string
+    (``str`` in Python 3; ``unicode`` in Python 2).
+
+    Additionally, if you are writing data intended for a ``BYTES`` column, you
+    must base64 encode it.
+
+
+Delete records using a Transaction
+----------------------------------
+
+:meth:`Transaction.delete` removes one or more records from a table.
+Non-existent rows do not cause errors.
+
+.. code:: python
+
+    transaction.delete(
+        'citizens', keyset=['bharney@example.com', 'nonesuch@example.com'])
+
+
+Commit changes for a Transaction
+--------------------------------
+
+After describing the modifications to be made to table data via the
+:meth:`Transaction.insert`, :meth:`Transaction.update`,
+:meth:`Transaction.insert_or_update`, :meth:`Transaction.replace`, and
+:meth:`Transaction.delete` methods above, send them to
+the back-end by calling :meth:`Transaction.commit`, which makes the ``Commit``
+API call.
+
+.. code:: python
+
+    transaction.commit()
+
+
+Roll back changes for a Transaction
+-----------------------------------
+
+After describing the modifications to be made to table data via the
+:meth:`Transaction.insert`, :meth:`Transaction.update`,
+:meth:`Transaction.insert_or_update`, :meth:`Transaction.replace`, and
+:meth:`Transaction.delete` methods above, cancel the transaction on the
+back-end by calling :meth:`Transaction.rollback`, which makes the
+``Rollback`` API call.
+
+.. code:: python
+
+    transaction.rollback()
+
+
+Use a Transaction as a Context Manager
+--------------------------------------
+
+Rather than calling :meth:`Transaction.commit` or :meth:`Transaction.rollback`
+manually, you can use the :class:`Transaction` instance as a context manager:
+in that case, the transaction's :meth:`~Transaction.commit` method will be
+called automatically if the ``with`` block exits without raising an exception.
+
+If an exception is raised inside the ``with`` block, the transaction's
+:meth:`~Transaction.rollback` method will be called instead.
+
+.. code:: python
+
+    with session.transaction() as transaction:
+
+        transaction.insert(
+            'citizens', columns=['email', 'first_name', 'last_name', 'age'],
+            values=[
+                ['phred@example.com', 'Phred', 'Phlyntstone', 32],
+                ['bharney@example.com', 'Bharney', 'Rhubble', 31],
+            ])
+
+        transaction.update(
+            'citizens', columns=['email', 'age'],
+            values=[
+                ['phred@example.com', 33],
+                ['bharney@example.com', 32],
+            ])
+
+        ...
+
+        transaction.delete(
+            'citizens', keyset=['bharney@example.com', 'nonesuch@example.com'])
diff --git a/docs/spanner-usage.rst b/docs/spanner-usage.rst
new file mode 100644
index 000000000000..f308201b674f
--- /dev/null
+++ b/docs/spanner-usage.rst
@@ -0,0 +1,20 @@
+Using the API
+=============
+
+API requests are sent to the `Cloud Spanner`_ API via RPC over
+HTTP/2. In order to support this, we'll rely on `gRPC`_.
+
+Get started by learning about the :class:`~google.cloud.spanner.client.Client`
+on the :doc:`spanner-client-usage` page.
+
+In the hierarchy of API concepts
+
+* a :class:`~google.cloud.spanner.client.Client` owns an
+  :class:`~google.cloud.spanner.instance.Instance`
+* an :class:`~google.cloud.spanner.instance.Instance` owns a
+  :class:`~google.cloud.spanner.database.Database`
+
+.. _Cloud Spanner: https://cloud.google.com/spanner/docs/
+.. _gRPC: http://www.grpc.io/
+.. _grpcio: https://pypi.python.org/pypi/grpcio
+
diff --git a/scripts/verify_included_modules.py b/scripts/verify_included_modules.py
index 2e300e196a06..5e01643713d4 100644
--- a/scripts/verify_included_modules.py
+++ b/scripts/verify_included_modules.py
@@ -42,6 +42,7 @@
    'google.cloud.monitoring',
    'google.cloud.pubsub',
    'google.cloud.resource_manager',
+    'google.cloud.spanner',
    'google.cloud.speech',
    'google.cloud.storage',
    'google.cloud.streaming',
@@ -68,6 +69,7 @@
    'pubsub',
    'resource_manager',
    'runtimeconfig',
+    'spanner',
    'speech',
    'storage',
    'translate',
diff --git a/spanner/.coveragerc b/spanner/.coveragerc
new file mode 100644
index 000000000000..a54b99aa14b7
--- /dev/null
+++ b/spanner/.coveragerc
@@ -0,0 +1,11 @@
+[run]
+branch = True
+
+[report]
+fail_under = 100
+show_missing = True
+exclude_lines =
+    # Re-enable the standard pragma
+    pragma: NO COVER
+    # Ignore debug-only repr
+    def __repr__
diff --git a/spanner/MANIFEST.in b/spanner/MANIFEST.in
new file mode 100644
index 000000000000..cb3a2b9ef4fa
--- /dev/null
+++ b/spanner/MANIFEST.in
@@ -0,0 +1,4 @@
+include README.rst
+graft google
+graft unit_tests
+global-exclude *.pyc
diff --git a/spanner/README.rst b/spanner/README.rst
new file mode 100644
index 000000000000..bc2000c3daa4
--- /dev/null
+++ b/spanner/README.rst
@@ -0,0 +1,13 @@
+Python Client for Cloud Spanner
+===============================
+
+    Python idiomatic client for `Cloud Spanner`_
+
+.. _Cloud Spanner: https://cloud.google.com/spanner/
+
+Quick Start
+-----------
+
+::
+
+    $ pip install --upgrade google-cloud-spanner
diff --git a/spanner/google/__init__.py b/spanner/google/__init__.py
new file mode 100644
index 000000000000..b2b833373882
--- /dev/null
+++ b/spanner/google/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2016 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import pkg_resources + pkg_resources.declare_namespace(__name__) +except ImportError: + import pkgutil + __path__ = pkgutil.extend_path(__path__, __name__) diff --git a/spanner/google/cloud/__init__.py b/spanner/google/cloud/__init__.py new file mode 100644 index 000000000000..b2b833373882 --- /dev/null +++ b/spanner/google/cloud/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2016 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import pkg_resources + pkg_resources.declare_namespace(__name__) +except ImportError: + import pkgutil + __path__ = pkgutil.extend_path(__path__, __name__) diff --git a/spanner/google/cloud/spanner/__init__.py b/spanner/google/cloud/spanner/__init__.py new file mode 100644 index 000000000000..6c5f790366b9 --- /dev/null +++ b/spanner/google/cloud/spanner/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cloud Spanner API package.""" + + +from google.cloud.spanner.client import Client + +from google.cloud.spanner.keyset import KeyRange +from google.cloud.spanner.keyset import KeySet + +from google.cloud.spanner.pool import AbstractSessionPool +from google.cloud.spanner.pool import BurstyPool +from google.cloud.spanner.pool import FixedSizePool diff --git a/spanner/google/cloud/spanner/_fixtures.py b/spanner/google/cloud/spanner/_fixtures.py new file mode 100644 index 000000000000..c63d942f9883 --- /dev/null +++ b/spanner/google/cloud/spanner/_fixtures.py @@ -0,0 +1,33 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test fixtures.""" + + +DDL = """\ +CREATE TABLE contacts ( + contact_id INT64, + first_name STRING(1024), + last_name STRING(1024), + email STRING(1024) ) + PRIMARY KEY (contact_id); +CREATE TABLE contact_phones ( + contact_id INT64, + phone_type STRING(1024), + phone_number STRING(1024) ) + PRIMARY KEY (contact_id, phone_type), + INTERLEAVE IN PARENT contacts ON DELETE CASCADE; +""" + +DDL_STATEMENTS = [stmt.strip() for stmt in DDL.split(';') if stmt.strip()] diff --git a/spanner/google/cloud/spanner/_helpers.py b/spanner/google/cloud/spanner/_helpers.py new file mode 100644 index 000000000000..8d64106ba4fc --- /dev/null +++ b/spanner/google/cloud/spanner/_helpers.py @@ -0,0 +1,268 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for Cloud Spanner.""" + +import datetime +import math + +import six + +from google.gax import CallOptions +from google.protobuf.struct_pb2 import ListValue +from google.protobuf.struct_pb2 import Value +from google.cloud.proto.spanner.v1 import type_pb2 + +from google.cloud._helpers import _date_from_iso8601_date +from google.cloud._helpers import _datetime_to_rfc3339 +from google.cloud._helpers import _RFC3339_NANOS +from google.cloud._helpers import _RFC3339_NO_FRACTION +from google.cloud._helpers import UTC + + +class TimestampWithNanoseconds(datetime.datetime): + """Track nanosecond in addition to normal datetime attrs. + + nanosecond can be passed only as a keyword argument. + """ + __slots__ = ('_nanosecond',) + + def __new__(cls, *args, **kw): + nanos = kw.pop('nanosecond', 0) + if nanos > 0: + if 'microsecond' in kw: + raise TypeError( + "Specify only one of 'microsecond' or 'nanosecond'") + kw['microsecond'] = nanos // 1000 + inst = datetime.datetime.__new__(cls, *args, **kw) + inst._nanosecond = nanos or 0 + return inst + + @property + def nanosecond(self): + """Read-only: nanosecond precision.""" + return self._nanosecond + + def rfc3339(self): + """RFC 3339-compliant timestamp. + + :rtype: str + :returns: Timestamp string according to RFC 3339 spec. + """ + if self._nanosecond == 0: + return _datetime_to_rfc3339(self) + nanos = str(self._nanosecond).rstrip('0') + return '%s.%sZ' % (self.strftime(_RFC3339_NO_FRACTION), nanos) + + @classmethod + def from_rfc3339(cls, stamp): + """Parse RFC 3339-compliant timestamp, preserving nanoseconds. 
+ + :type stamp: str + :param stamp: RFC 3339 stamp, with up to nanosecond precision + + :rtype: :class:`TimestampWithNanoseconds` + :returns: an instance matching the timestamp string + """ + with_nanos = _RFC3339_NANOS.match(stamp) + if with_nanos is None: + raise ValueError( + 'Timestamp: %r, does not match pattern: %r' % ( + stamp, _RFC3339_NANOS.pattern)) + bare = datetime.datetime.strptime( + with_nanos.group('no_fraction'), _RFC3339_NO_FRACTION) + fraction = with_nanos.group('nanos') + if fraction is None: + nanos = 0 + else: + scale = 9 - len(fraction) + nanos = int(fraction) * (10 ** scale) + return cls(bare.year, bare.month, bare.day, + bare.hour, bare.minute, bare.second, + nanosecond=nanos, tzinfo=UTC) + + +def _try_to_coerce_bytes(bytestring): + """Try to coerce a byte string into the right thing based on Python + version and whether or not it is base64 encoded. + + Return a text string or raise :exc:`ValueError`. + """ + # Attempt to coerce using google.protobuf.Value, which will expect + # something that is utf-8 (and base64 consistently is). + try: + Value(string_value=bytestring) + return bytestring + except ValueError: + raise ValueError('Received a bytes value that is not base64 encoded. ' + 'Ensure that you either send a Unicode string or ' + 'base64-encoded bytes.') + + +# pylint: disable=too-many-return-statements +def _make_value_pb(value): + """Helper for :func:`_make_list_value_pbs`. + + :type value: scalar value + :param value: value to convert + + :rtype: :class:`~google.protobuf.struct_pb2.Value` + :returns: value protobuf + :raises: :exc:`ValueError` if value is not of a known scalar type. + """ + if value is None: + return Value(null_value='NULL_VALUE') + if isinstance(value, list): + return Value(list_value=_make_list_value_pb(value)) + if isinstance(value, bool): + return Value(bool_value=value) + if isinstance(value, six.integer_types): + return Value(string_value=str(value)) + if isinstance(value, float): + if math.isnan(value): + return Value(string_value='NaN') + if math.isinf(value): + return Value(string_value=str(value)) + return Value(number_value=value) + if isinstance(value, TimestampWithNanoseconds): + return Value(string_value=value.rfc3339()) + if isinstance(value, datetime.datetime): + return Value(string_value=_datetime_to_rfc3339(value)) + if isinstance(value, datetime.date): + return Value(string_value=value.isoformat()) + if isinstance(value, six.binary_type): + value = _try_to_coerce_bytes(value) + return Value(string_value=value) + if isinstance(value, six.text_type): + return Value(string_value=value) + raise ValueError("Unknown type: %s" % (value,)) +# pylint: enable=too-many-return-statements + + +def _make_list_value_pb(values): + """Construct a ListValue protobuf. + + :type values: list of scalar + :param values: Row data + + :rtype: :class:`~google.protobuf.struct_pb2.ListValue` + :returns: protobuf + """ + return ListValue(values=[_make_value_pb(value) for value in values]) + + +def _make_list_value_pbs(values): + """Construct a sequence of ListValue protobufs. + + :type values: list of list of scalar + :param values: Row data + + :rtype: list of :class:`~google.protobuf.struct_pb2.ListValue` + :returns: sequence of protobufs + """ + return [_make_list_value_pb(row) for row in values] + + +# pylint: disable=too-many-branches +def _parse_value_pb(value_pb, field_type): + """Convert a Value protobuf to cell data.
+ + :type value_pb: :class:`~google.protobuf.struct_pb2.Value` + :param value_pb: protobuf to convert + + :type field_type: :class:`~google.cloud.proto.spanner.v1.type_pb2.Type` + :param field_type: type code for the value + + :rtype: varies, based on ``field_type`` + :returns: value extracted from value_pb + :raises: :exc:`ValueError` if an unknown type is passed + """ + if value_pb.HasField('null_value'): + return None + if field_type.code == type_pb2.STRING: + result = value_pb.string_value + elif field_type.code == type_pb2.BYTES: + result = value_pb.string_value.encode('utf8') + elif field_type.code == type_pb2.BOOL: + result = value_pb.bool_value + elif field_type.code == type_pb2.INT64: + result = int(value_pb.string_value) + elif field_type.code == type_pb2.FLOAT64: + if value_pb.HasField('string_value'): + result = float(value_pb.string_value) + else: + result = value_pb.number_value + elif field_type.code == type_pb2.DATE: + result = _date_from_iso8601_date(value_pb.string_value) + elif field_type.code == type_pb2.TIMESTAMP: + result = TimestampWithNanoseconds.from_rfc3339(value_pb.string_value) + elif field_type.code == type_pb2.ARRAY: + result = [ + _parse_value_pb(item_pb, field_type.array_element_type) + for item_pb in value_pb.list_value.values] + elif field_type.code == type_pb2.STRUCT: + result = [ + _parse_value_pb(item_pb, field_type.struct_type.fields[i].type) + for (i, item_pb) in enumerate(value_pb.list_value.values)] + else: + raise ValueError("Unknown type: %s" % (field_type,)) + return result +# pylint: enable=too-many-branches + + +def _parse_list_value_pbs(rows, row_type): + """Convert a list of ListValue protobufs into a list of list of cell data. + + :type rows: list of :class:`~google.protobuf.struct_pb2.ListValue` + :param rows: row data returned from a read/query + + :type row_type: :class:`~google.cloud.proto.spanner.v1.type_pb2.StructType` + :param row_type: row schema specification + + :rtype: list of list of cell data + :returns: data for the rows, coerced into appropriate types + """ + result = [] + for row in rows: + row_data = [] + for value_pb, field in zip(row.values, row_type.fields): + row_data.append(_parse_value_pb(value_pb, field.type)) + result.append(row_data) + return result + + +class _SessionWrapper(object): + """Base class for objects wrapping a session. + + :type session: :class:`~google.cloud.spanner.session.Session` + :param session: the session used to perform API requests + """ + def __init__(self, session): + self._session = session + + +def _options_with_prefix(prefix, **kw): + """Create GAPIC options w/ prefix. + + :type prefix: str + :param prefix: appropriate resource path + + :type kw: dict + :param kw: other keyword arguments passed to the constructor + + :rtype: :class:`~google.gax.CallOptions` + :returns: GAPIC call options with supplied prefix + """ + return CallOptions( + metadata=[('google-cloud-resource-prefix', prefix)], **kw) diff --git a/spanner/google/cloud/spanner/batch.py b/spanner/google/cloud/spanner/batch.py new file mode 100644 index 000000000000..552d7960b1ab --- /dev/null +++ b/spanner/google/cloud/spanner/batch.py @@ -0,0 +1,192 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Context manager for Cloud Spanner batched writes.""" + +from google.cloud.proto.spanner.v1.mutation_pb2 import Mutation +from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionOptions + +# pylint: disable=ungrouped-imports +from google.cloud._helpers import _pb_timestamp_to_datetime +from google.cloud.spanner._helpers import _SessionWrapper +from google.cloud.spanner._helpers import _make_list_value_pbs +from google.cloud.spanner._helpers import _options_with_prefix +# pylint: enable=ungrouped-imports + + +class _BatchBase(_SessionWrapper): + """Accumulate mutations for transmission during :meth:`commit`. + + :type session: :class:`~google.cloud.spanner.session.Session` + :param session: the session used to perform the commit + """ + def __init__(self, session): + super(_BatchBase, self).__init__(session) + self._mutations = [] + + def _check_state(self): + """Helper for :meth:`commit` et al. + + Subclasses must override. + + :raises: :exc:`ValueError` if the object's state is invalid for making + API requests. + """ + raise NotImplementedError + + def insert(self, table, columns, values): + """Insert one or more new table rows. + + :type table: str + :param table: Name of the table to be modified. + + :type columns: list of str + :param columns: Name of the table columns to be modified. + + :type values: list of lists + :param values: Values to be modified. + """ + self._mutations.append(Mutation( + insert=_make_write_pb(table, columns, values))) + + def update(self, table, columns, values): + """Update one or more existing table rows. + + :type table: str + :param table: Name of the table to be modified. + + :type columns: list of str + :param columns: Name of the table columns to be modified. + + :type values: list of lists + :param values: Values to be modified. + """ + self._mutations.append(Mutation( + update=_make_write_pb(table, columns, values))) + + def insert_or_update(self, table, columns, values): + """Insert/update one or more table rows. + + :type table: str + :param table: Name of the table to be modified. + + :type columns: list of str + :param columns: Name of the table columns to be modified. + + :type values: list of lists + :param values: Values to be modified. + """ + self._mutations.append(Mutation( + insert_or_update=_make_write_pb(table, columns, values))) + + def replace(self, table, columns, values): + """Replace one or more table rows. + + :type table: str + :param table: Name of the table to be modified. + + :type columns: list of str + :param columns: Name of the table columns to be modified. + + :type values: list of lists + :param values: Values to be modified. + """ + self._mutations.append(Mutation( + replace=_make_write_pb(table, columns, values))) + + def delete(self, table, keyset): + """Delete one or more table rows. + + :type table: str + :param table: Name of the table to be modified. + + :type keyset: :class:`~google.cloud.spanner.keyset.KeySet` + :param keyset: Keys/ranges identifying rows to delete.
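+ + Example (an illustrative sketch; the table name and key values + are hypothetical): + + .. code:: python + + from google.cloud.spanner import KeySet + + batch.delete('contacts', KeySet(keys=[[1], [2]]))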
+ """ + delete = Mutation.Delete( + table=table, + key_set=keyset.to_pb(), + ) + self._mutations.append(Mutation( + delete=delete)) + + +class Batch(_BatchBase): + """Accumulate mutations for transmission during :meth:`commit`. + """ + committed = None + """Timestamp at which the batch was successfully committed.""" + + def _check_state(self): + """Helper for :meth:`commit` et al. + + Subclasses must override + + :raises: :exc:`ValueError` if the object's state is invalid for making + API requests. + """ + if self.committed is not None: + raise ValueError("Batch already committed") + + def commit(self): + """Commit mutations to the database. + + :rtype: datetime + :returns: timestamp of the committed changes. + """ + self._check_state() + database = self._session._database + api = database.spanner_api + options = _options_with_prefix(database.name) + txn_options = TransactionOptions( + read_write=TransactionOptions.ReadWrite()) + response = api.commit(self._session.name, self._mutations, + single_use_transaction=txn_options, + options=options) + self.committed = _pb_timestamp_to_datetime( + response.commit_timestamp) + return self.committed + + def __enter__(self): + """Begin ``with`` block.""" + self._check_state() + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """End ``with`` block.""" + if exc_type is None: + self.commit() + + +def _make_write_pb(table, columns, values): + """Helper for :meth:`Batch.insert` et aliae. + + :type table: str + :param table: Name of the table to be modified. + + :type columns: list of str + :param columns: Name of the table columns to be modified. + + :type values: list of lists + :param values: Values to be modified. + + :rtype: :class:`google.cloud.proto.spanner.v1.mutation_pb2.Mutation.Write` + :returns: Write protobuf + """ + return Mutation.Write( + table=table, + columns=columns, + values=_make_list_value_pbs(values), + ) diff --git a/spanner/google/cloud/spanner/client.py b/spanner/google/cloud/spanner/client.py new file mode 100644 index 000000000000..678ac5551588 --- /dev/null +++ b/spanner/google/cloud/spanner/client.py @@ -0,0 +1,326 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parent client for calling the Cloud Spanner API. + +This is the base from which all interactions with the API occur. 
+ +In the hierarchy of API concepts + +* a :class:`~google.cloud.spanner.client.Client` owns an + :class:`~google.cloud.spanner.instance.Instance` +* a :class:`~google.cloud.spanner.instance.Instance` owns a + :class:`~google.cloud.spanner.database.Database` +""" + +import google.auth.credentials +from google.gax import INITIAL_PAGE +from google.longrunning import operations_grpc +# pylint: disable=line-too-long +from google.cloud.gapic.spanner_admin_database.v1.database_admin_client import ( # noqa + DatabaseAdminClient) +from google.cloud.gapic.spanner_admin_instance.v1.instance_admin_client import ( # noqa + InstanceAdminClient) +# pylint: enable=line-too-long + +from google.cloud._helpers import make_secure_stub +from google.cloud._http import DEFAULT_USER_AGENT +from google.cloud.client import _ClientFactoryMixin +from google.cloud.client import _ClientProjectMixin +from google.cloud.credentials import get_credentials +from google.cloud.iterator import GAXIterator +from google.cloud.spanner._helpers import _options_with_prefix +from google.cloud.spanner.instance import DEFAULT_NODE_COUNT +from google.cloud.spanner.instance import Instance + +SPANNER_ADMIN_SCOPE = 'https://www.googleapis.com/auth/spanner.admin' + +OPERATIONS_API_HOST = 'spanner.googleapis.com' + + +class InstanceConfig(object): + """Named configurations for Spanner instances. + + :type name: str + :param name: ID of the instance configuration + + :type display_name: str + :param display_name: Name of the instance configuration + """ + def __init__(self, name, display_name): + self.name = name + self.display_name = display_name + + @classmethod + def from_pb(cls, config_pb): + """Construct an instance from the equivalent protobuf. + + :type config_pb: + :class:`~google.cloud.proto.spanner.admin.instance.v1.spanner_instance_admin_pb2.InstanceConfig` + :param config_pb: the protobuf to parse + + :rtype: :class:`InstanceConfig` + :returns: an instance of this class + """ + return cls(config_pb.name, config_pb.display_name) + + +def _make_operations_stub(client): + """Helper for :meth:`Client._operations_stub`.""" + return make_secure_stub(client.credentials, client.user_agent, + operations_grpc.OperationsStub, + OPERATIONS_API_HOST) + + +class Client(_ClientFactoryMixin, _ClientProjectMixin): + """Client for interacting with Cloud Spanner API. + + .. note:: + + Since the Cloud Spanner API requires the gRPC transport, no + ``http`` argument is accepted by this class. + + :type project: :class:`str` or :func:`unicode <unicode>` + :param project: (Optional) The ID of the project which owns the + instances, tables and data. If not provided, will + attempt to determine from the environment. + + :type credentials: + :class:`OAuth2Credentials <oauth2client.client.OAuth2Credentials>` or + :data:`NoneType <types.NoneType>` + :param credentials: (Optional) The OAuth2 Credentials to use for this + client. If not provided, defaults to the Google + Application Default Credentials. + + :type user_agent: str + :param user_agent: (Optional) The user agent to be used with API requests. + Defaults to :const:`DEFAULT_USER_AGENT`.
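+ + Example (a minimal sketch; the project ID is hypothetical and + credentials are resolved from the environment): + + .. code:: python + + from google.cloud import spanner + + client = spanner.Client(project='my-project')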
+ """ + _instance_admin_api = None + _database_admin_api = None + _operations_stub_internal = None + + def __init__(self, project=None, credentials=None, + user_agent=DEFAULT_USER_AGENT): + + _ClientProjectMixin.__init__(self, project=project) + if credentials is None: + credentials = get_credentials() + + scopes = [ + SPANNER_ADMIN_SCOPE, + ] + + credentials = google.auth.credentials.with_scopes_if_required( + credentials, scopes) + + self._credentials = credentials + self.user_agent = user_agent + + @property + def credentials(self): + """Getter for client's credentials. + + :rtype: + :class:`OAuth2Credentials <oauth2client.client.OAuth2Credentials>` + :returns: The credentials stored on the client. + """ + return self._credentials + + @property + def project_name(self): + """Project name to be used with Spanner APIs. + + .. note:: + + This property will not change if ``project`` does not, but the + return value is not cached. + + The project name is of the form + + ``"projects/{project}"`` + + :rtype: str + :returns: The project name to be used with the Cloud Spanner Admin + API RPC service. + """ + return 'projects/' + self.project + + @property + def instance_admin_api(self): + """Helper for instance-related API calls.""" + if self._instance_admin_api is None: + self._instance_admin_api = InstanceAdminClient() + return self._instance_admin_api + + @property + def database_admin_api(self): + """Helper for database-related API calls.""" + if self._database_admin_api is None: + self._database_admin_api = DatabaseAdminClient() + return self._database_admin_api + + @property + def _operations_stub(self): + """Stub for google.longrunning.operations calls. + + .. note: + + Will be replaced by a GAX API helper once that library is + released. + """ + if self._operations_stub_internal is None: + self._operations_stub_internal = _make_operations_stub(self) + return self._operations_stub_internal + + def copy(self): + """Make a copy of this client. + + Copies the local data stored as simple types but does not copy the + current state of any open connections with the Cloud Spanner API. + + :rtype: :class:`.Client` + :returns: A copy of the current client. + """ + credentials = self._credentials + copied_creds = credentials.create_scoped(credentials.scopes) + return self.__class__( + self.project, + copied_creds, + self.user_agent, + ) + + def list_instance_configs(self, page_size=None, page_token=None): + """List available instance configurations for the client's project. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.instance.v1#google.spanner.admin.instance.v1.InstanceAdmin.ListInstanceConfigs + + :type page_size: int + :param page_size: (Optional) Maximum number of results to return. + + :type page_token: str + :param page_token: (Optional) Token for fetching next page of results. + + :rtype: :class:`~google.cloud.iterator.Iterator` + :returns: + Iterator of + :class:`~google.cloud.spanner.client.InstanceConfig` + resources within the client's project.
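+ + Example (illustrative; assumes a ``client`` constructed as above): + + .. code:: python + + for config in client.list_instance_configs(): + print(config.name)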
+ """ + if page_token is None: + page_token = INITIAL_PAGE + options = _options_with_prefix(self.project_name, + page_token=page_token) + path = 'projects/%s' % (self.project,) + page_iter = self.instance_admin_api.list_instance_configs( + path, page_size=page_size, options=options) + return GAXIterator(self, page_iter, _item_to_instance_config) + + def instance(self, instance_id, + configuration_name=None, + display_name=None, + node_count=DEFAULT_NODE_COUNT): + """Factory to create a instance associated with this client. + + :type instance_id: str + :param instance_id: The ID of the instance. + + :type configuration_name: string + :param configuration_name: + (Optional) Name of the instance configuration used to set up the + instance's cluster, in the form: + ``projects//instanceConfigs/``. + **Required** for instances which do not yet exist. + + :type display_name: str + :param display_name: (Optional) The display name for the instance in + the Cloud Console UI. (Must be between 4 and 30 + characters.) If this value is not set in the + constructor, will fall back to the instance ID. + + :type node_count: int + :param node_count: (Optional) The number of nodes in the instance's + cluster; used to set up the instance's cluster. + + :rtype: :class:`~google.cloud.spanner.instance.Instance` + :returns: an instance owned by this client. + """ + return Instance( + instance_id, self, configuration_name, node_count, display_name) + + def list_instances(self, filter_='', page_size=None, page_token=None): + """List instances for the client's project. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.database.v1#google.spanner.admin.database.v1.InstanceAdmin.ListInstances + + :type filter_: string + :param filter_: (Optional) Filter to select instances listed. See: + the ``ListInstancesRequest`` docs above for examples. + + :type page_size: int + :param page_size: (Optional) Maximum number of results to return. + + :type page_token: str + :param page_token: (Optional) Token for fetching next page of results. + + :rtype: :class:`~google.cloud.iterator.Iterator` + :returns: + Iterator of :class:`~google.cloud.spanner.instance.Instance` + resources within the client's project. + """ + if page_token is None: + page_token = INITIAL_PAGE + options = _options_with_prefix(self.project_name, + page_token=page_token) + path = 'projects/%s' % (self.project,) + page_iter = self.instance_admin_api.list_instances( + path, filter_=filter_, page_size=page_size, options=options) + return GAXIterator(self, page_iter, _item_to_instance) + + +def _item_to_instance_config( + iterator, config_pb): # pylint: disable=unused-argument + """Convert an instance config protobuf to the native object. + + :type iterator: :class:`~google.cloud.iterator.Iterator` + :param iterator: The iterator that is currently in use. + + :type config_pb: + :class:`~google.spanner.admin.instance.v1.InstanceConfig` + :param config_pb: An instance config returned from the API. + + :rtype: :class:`~google.cloud.spanner.instance.InstanceConfig` + :returns: The next instance config in the page. + """ + return InstanceConfig.from_pb(config_pb) + + +def _item_to_instance(iterator, instance_pb): + """Convert an instance protobuf to the native object. + + :type iterator: :class:`~google.cloud.iterator.Iterator` + :param iterator: The iterator that is currently in use. + + :type instance_pb: :class:`~google.spanner.admin.instance.v1.Instance` + :param instance_pb: An instance returned from the API. 
+ + :rtype: :class:`~google.cloud.spanner.instance.Instance` + :returns: The next instance in the page. + """ + return Instance.from_pb(instance_pb, iterator.client) diff --git a/spanner/google/cloud/spanner/database.py b/spanner/google/cloud/spanner/database.py new file mode 100644 index 000000000000..16864b4b0c78 --- /dev/null +++ b/spanner/google/cloud/spanner/database.py @@ -0,0 +1,554 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""User friendly container for Cloud Spanner Database.""" + +import functools +import re + +from google.gax.errors import GaxError +from google.gax.grpc import exc_to_code +from google.gax import _OperationFuture +from google.cloud.proto.spanner.admin.database.v1 import ( + spanner_database_admin_pb2 as admin_v1_pb2) +from google.cloud.gapic.spanner.v1.spanner_client import SpannerClient +from grpc import StatusCode +import six + +# pylint: disable=ungrouped-imports +from google.cloud.exceptions import Conflict +from google.cloud.exceptions import NotFound +from google.cloud.operation import register_type +from google.cloud.spanner._helpers import _options_with_prefix +from google.cloud.spanner.batch import Batch +from google.cloud.spanner.session import Session +from google.cloud.spanner.pool import BurstyPool +from google.cloud.spanner.snapshot import Snapshot +from google.cloud.spanner.pool import SessionCheckout +# pylint: enable=ungrouped-imports + + +_DATABASE_NAME_RE = re.compile( + r'^projects/(?P[^/]+)/' + r'instances/(?P[a-z][-a-z0-9]*)/' + r'databases/(?P[a-z][a-z0-9_\-]*[a-z0-9])$' + ) + +register_type(admin_v1_pb2.Database) +register_type(admin_v1_pb2.CreateDatabaseMetadata) +register_type(admin_v1_pb2.UpdateDatabaseDdlMetadata) + + +class _BrokenResultFuture(_OperationFuture): + """An _OperationFuture subclass that is permissive about type mismatches + in results, and simply returns an empty-ish object if they happen. + + This class exists to get past a contra-spec result on + `update_database_ddl`; since the result is empty there is no + critical loss. + """ + @functools.wraps(_OperationFuture.result) + def result(self, *args, **kwargs): + try: + return super(_BrokenResultFuture, self).result(*args, **kwargs) + except TypeError: + return self._result_type() + + +class Database(object): + """Representation of a Cloud Spanner Database. + + We can use a :class:`Database` to: + + * :meth:`create` the database + * :meth:`reload` the database + * :meth:`update` the database + * :meth:`drop` the database + + :type database_id: str + :param database_id: The ID of the database. + + :type instance: :class:`~google.cloud.spanner.instance.Instance` + :param instance: The instance that owns the database. + + :type ddl_statements: list of string + :param ddl_statements: (Optional) DDL statements, excluding the + CREATE DATABASE statement. + + :type pool: concrete subclass of + :class:`~google.cloud.spanner.pool.AbstractSessionPool`. + :param pool: (Optional) session pool to be used by database. 
If not + passed, the database will construct an instance of + :class:`~google.cloud.spanner.pool.BurstyPool`. + """ + + _spanner_api = None + + def __init__(self, database_id, instance, ddl_statements=(), pool=None): + self.database_id = database_id + self._instance = instance + self._ddl_statements = _check_ddl_statements(ddl_statements) + + if pool is None: + pool = BurstyPool() + + self._pool = pool + pool.bind(self) + + @classmethod + def from_pb(cls, database_pb, instance, pool=None): + """Creates an instance of this class from a protobuf. + + :type database_pb: + :class:`~google.cloud.proto.spanner.admin.database.v1.spanner_database_admin_pb2.Database` + :param database_pb: A database protobuf object. + + :type instance: :class:`~google.cloud.spanner.instance.Instance` + :param instance: The instance that owns the database. + + :type pool: concrete subclass of + :class:`~google.cloud.spanner.pool.AbstractSessionPool`. + :param pool: (Optional) session pool to be used by database. + + :rtype: :class:`Database` + :returns: The database parsed from the protobuf response. + :raises: + :class:`ValueError <exceptions.ValueError>` if the database + name does not match the expected format + or if the parsed project ID does not match the project ID + on the instance's client, or if the parsed instance ID does + not match the instance's ID. + """ + match = _DATABASE_NAME_RE.match(database_pb.name) + if match is None: + raise ValueError('Database protobuf name was not in the ' + 'expected format.', database_pb.name) + if match.group('project') != instance._client.project: + raise ValueError('Project ID on database does not match the ' + 'project ID on the instance\'s client') + instance_id = match.group('instance_id') + if instance_id != instance.instance_id: + raise ValueError('Instance ID on database does not match the ' + 'Instance ID on the instance') + database_id = match.group('database_id') + + return cls(database_id, instance, pool=pool) + + @property + def name(self): + """Database name used in requests. + + .. note:: + + This property will not change if ``database_id`` does not, but the + return value is not cached. + + The database name is of the form + + ``"projects/../instances/../databases/{database_id}"`` + + :rtype: str + :returns: The database name. + """ + return self._instance.name + '/databases/' + self.database_id + + @property + def ddl_statements(self): + """DDL Statements used to define database schema. + + See: + https://cloud.google.com/spanner/docs/data-definition-language + + :rtype: sequence of string + :returns: the statements + """ + return self._ddl_statements + + @property + def spanner_api(self): + """Helper for session-related API calls.""" + if self._spanner_api is None: + self._spanner_api = SpannerClient() + return self._spanner_api + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return (other.database_id == self.database_id and + other._instance == self._instance) + + def __ne__(self, other): + return not self.__eq__(other) + + def create(self): + """Create this database within its instance + + Includes any configured schema assigned to :attr:`ddl_statements`.
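+ + Example (an illustrative sketch; assumes an existing ``instance`` + and a hypothetical database ID): + + .. code:: python + + database = instance.database('example-db') + operation = database.create() # a long-running operation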
+ + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.database.v1#google.spanner.admin.database.v1.DatabaseAdmin.CreateDatabase + """ + api = self._instance._client.database_admin_api + options = _options_with_prefix(self.name) + db_name = self.database_id + if '-' in db_name: + db_name = '`%s`' % (db_name,) + + try: + future = api.create_database( + parent=self._instance.name, + create_statement='CREATE DATABASE %s' % (db_name,), + extra_statements=list(self._ddl_statements), + options=options, + ) + except GaxError as exc: + if exc_to_code(exc.cause) == StatusCode.ALREADY_EXISTS: + raise Conflict(self.name) + elif exc_to_code(exc.cause) == StatusCode.NOT_FOUND: + raise NotFound('Instance not found: {name}'.format( + name=self._instance.name, + )) + raise + + future.caller_metadata = {'request_type': 'CreateDatabase'} + return future + + def exists(self): + """Test whether this database exists. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.database.v1#google.spanner.admin.database.v1.DatabaseAdmin.GetDatabaseDdl + """ + api = self._instance._client.database_admin_api + options = _options_with_prefix(self.name) + + try: + api.get_database_ddl(self.name, options=options) + except GaxError as exc: + if exc_to_code(exc.cause) == StatusCode.NOT_FOUND: + return False + raise + return True + + def reload(self): + """Reload this database. + + Refresh any configured schema into :attr:`ddl_statements`. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.database.v1#google.spanner.admin.database.v1.DatabaseAdmin.GetDatabaseDdl + """ + api = self._instance._client.database_admin_api + options = _options_with_prefix(self.name) + + try: + response = api.get_database_ddl(self.name, options=options) + except GaxError as exc: + if exc_to_code(exc.cause) == StatusCode.NOT_FOUND: + raise NotFound(self.name) + raise + self._ddl_statements = tuple(response.statements) + + def update_ddl(self, ddl_statements): + """Update DDL for this database. + + Apply the given DDL statements to this database. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.database.v1#google.spanner.admin.database.v1.DatabaseAdmin.UpdateDatabaseDdl + + :type ddl_statements: list of string + :param ddl_statements: DDL statements to apply, excluding the + 'CREATE DATABASE' statement. + + :rtype: :class:`google.cloud.operation.Operation` + :returns: an operation instance + """ + client = self._instance._client + api = client.database_admin_api + options = _options_with_prefix(self.name) + + try: + future = api.update_database_ddl( + self.name, ddl_statements, '', options=options) + future.__class__ = _BrokenResultFuture + except GaxError as exc: + if exc_to_code(exc.cause) == StatusCode.NOT_FOUND: + raise NotFound(self.name) + raise + + future.caller_metadata = {'request_type': 'UpdateDatabaseDdl'} + return future + + def drop(self): + """Drop this database. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.database.v1#google.spanner.admin.database.v1.DatabaseAdmin.DropDatabase + """ + api = self._instance._client.database_admin_api + options = _options_with_prefix(self.name) + + try: + api.drop_database(self.name, options=options) + except GaxError as exc: + if exc_to_code(exc.cause) == StatusCode.NOT_FOUND: + raise NotFound(self.name) + raise + + def session(self): + """Factory to create a session for this database. + + :rtype: :class:`~google.cloud.spanner.session.Session` + :returns: a session bound to this database.
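+ + Example (illustrative; most callers check sessions out of the + database's pool rather than creating them directly): + + .. code:: python + + session = database.session() + session.create()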
+ """ + return Session(self) + + def read(self, table, columns, keyset, index='', limit=0, + resume_token=b''): + """Perform a ``StreamingRead`` API request for rows in a table. + + :type table: str + :param table: name of the table from which to fetch data + + :type columns: list of str + :param columns: names of columns to be retrieved + + :type keyset: :class:`~google.cloud.spanner.keyset.KeySet` + :param keyset: keys / ranges identifying rows to be retrieved + + :type index: str + :param index: (Optional) name of index to use, rather than the + table's primary key + + :type limit: int + :param limit: (Optional) maxiumn number of rows to return + + :type resume_token: bytes + :param resume_token: token for resuming previously-interrupted read + + :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` + :returns: a result set instance which can be used to consume rows. + """ + with SessionCheckout(self._pool) as session: + return session.read( + table, columns, keyset, index, limit, resume_token) + + def execute_sql(self, sql, params=None, param_types=None, query_mode=None, + resume_token=b''): + """Perform an ``ExecuteStreamingSql`` API request. + + :type sql: str + :param sql: SQL query statement + + :type params: dict, {str -> column value} + :param params: values for parameter replacement. Keys must match + the names used in ``sql``. + + :type param_types: + dict, {str -> :class:`google.spanner.v1.type_pb2.TypeCode`} + :param param_types: (Optional) explicit types for one or more param + values; overrides default type detection on the + back-end. + + :type query_mode: + :class:`google.spanner.v1.spanner_pb2.ExecuteSqlRequest.QueryMode` + :param query_mode: Mode governing return of results / query plan. See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 + + :type resume_token: bytes + :param resume_token: token for resuming previously-interrupted query + + :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` + :returns: a result set instance which can be used to consume rows. + """ + with SessionCheckout(self._pool) as session: + return session.execute_sql( + sql, params, param_types, query_mode, resume_token) + + def run_in_transaction(self, func, *args, **kw): + """Perform a unit of work in a transaction, retrying on abort. + + :type func: callable + :param func: takes a required positional argument, the transaction, + and additional positional / keyword arguments as supplied + by the caller. + + :type args: tuple + :param args: additional positional arguments to be passed to ``func``. + + :type kw: dict + :param kw: optional keyword arguments to be passed to ``func``. + If passed, "timeout_secs" will be removed and used to + override the default timeout. + + :rtype: :class:`datetime.datetime` + :returns: timestamp of committed transaction + """ + with SessionCheckout(self._pool) as session: + return session.run_in_transaction(func, *args, **kw) + + def batch(self): + """Return an object which wraps a batch. + + The wrapper *must* be used as a context manager, with the batch + as the value returned by the wrapper. + + :rtype: :class:`~google.cloud.spanner.database.BatchCheckout` + :returns: new wrapper + """ + return BatchCheckout(self) + + def snapshot(self, read_timestamp=None, min_read_timestamp=None, + max_staleness=None, exact_staleness=None): + """Return an object which wraps a snapshot. 
+ + The wrapper *must* be used as a context manager, with the snapshot + as the value returned by the wrapper. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.TransactionOptions.ReadOnly + + If no options are passed, reads will use the ``strong`` model, reading + at a timestamp where all previously committed transactions are visible. + + :type read_timestamp: :class:`datetime.datetime` + :param read_timestamp: Execute all reads at the given timestamp. + + :type min_read_timestamp: :class:`datetime.datetime` + :param min_read_timestamp: Execute all reads at a + timestamp >= ``min_read_timestamp``. + + :type max_staleness: :class:`datetime.timedelta` + :param max_staleness: Read data at a + timestamp >= NOW - ``max_staleness`` seconds. + + :type exact_staleness: :class:`datetime.timedelta` + :param exact_staleness: Execute all reads at a timestamp that is + ``exact_staleness`` old. + + :rtype: :class:`~google.cloud.spanner.database.SnapshotCheckout` + :returns: new wrapper + """ + return SnapshotCheckout( + self, + read_timestamp=read_timestamp, + min_read_timestamp=min_read_timestamp, + max_staleness=max_staleness, + exact_staleness=exact_staleness, + ) + + +class BatchCheckout(object): + """Context manager for using a batch from a database. + + Inside the context manager, checks out a session from the database, + creates a batch from it, making the batch available. + + Caller must *not* use the batch to perform API requests outside the scope + of the context manager. + + :type database: :class:`~google.cloud.spanner.database.Database` + :param database: database to use + """ + def __init__(self, database): + self._database = database + self._session = self._batch = None + + def __enter__(self): + """Begin ``with`` block.""" + session = self._session = self._database._pool.get() + batch = self._batch = Batch(session) + return batch + + def __exit__(self, exc_type, exc_val, exc_tb): + """End ``with`` block.""" + try: + if exc_type is None: + self._batch.commit() + finally: + self._database._pool.put(self._session) + + +class SnapshotCheckout(object): + """Context manager for using a snapshot from a database. + + Inside the context manager, checks out a session from the database, + creates a snapshot from it, making the snapshot available. + + Caller must *not* use the snapshot to perform API requests outside the + scope of the context manager. + + :type database: :class:`~google.cloud.spanner.database.Database` + :param database: database to use + + :type read_timestamp: :class:`datetime.datetime` + :param read_timestamp: Execute all reads at the given timestamp. + + :type min_read_timestamp: :class:`datetime.datetime` + :param min_read_timestamp: Execute all reads at a + timestamp >= ``min_read_timestamp``. + + :type max_staleness: :class:`datetime.timedelta` + :param max_staleness: Read data at a + timestamp >= NOW - ``max_staleness`` seconds. + + :type exact_staleness: :class:`datetime.timedelta` + :param exact_staleness: Execute all reads at a timestamp that is + ``exact_staleness`` old.
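+ + Example (an illustrative sketch; the table and column names are + hypothetical): + + .. code:: python + + from google.cloud.spanner import KeySet + + with database.snapshot() as snapshot: + rows = list(snapshot.read( + 'contacts', ('contact_id', 'email'), KeySet(all_=True)))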
+ """ + def __init__(self, database, read_timestamp=None, min_read_timestamp=None, + max_staleness=None, exact_staleness=None): + self._database = database + self._session = None + self._read_timestamp = read_timestamp + self._min_read_timestamp = min_read_timestamp + self._max_staleness = max_staleness + self._exact_staleness = exact_staleness + + def __enter__(self): + """Begin ``with`` block.""" + session = self._session = self._database._pool.get() + return Snapshot( + session, + read_timestamp=self._read_timestamp, + min_read_timestamp=self._min_read_timestamp, + max_staleness=self._max_staleness, + exact_staleness=self._exact_staleness, + ) + + def __exit__(self, exc_type, exc_val, exc_tb): + """End ``with`` block.""" + self._database._pool.put(self._session) + + +def _check_ddl_statements(value): + """Validate DDL Statements used to define database schema. + + See: + https://cloud.google.com/spanner/docs/data-definition-language + + :type value: list of string + :param value: DDL statements, excluding the 'CREATE DATABSE' statement + + :rtype: tuple + :returns: tuple of validated DDL statement strings. + """ + if not all(isinstance(line, six.string_types) for line in value): + raise ValueError("Pass a list of strings") + + if any('create database' in line.lower() for line in value): + raise ValueError("Do not pass a 'CREATE DATABASE' statement") + + return tuple(value) diff --git a/spanner/google/cloud/spanner/instance.py b/spanner/google/cloud/spanner/instance.py new file mode 100644 index 000000000000..2935fc2ad57f --- /dev/null +++ b/spanner/google/cloud/spanner/instance.py @@ -0,0 +1,399 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""User friendly container for Cloud Spanner Instance.""" + +import re + +from google.gax import INITIAL_PAGE +from google.gax.errors import GaxError +from google.gax.grpc import exc_to_code +from google.cloud.proto.spanner.admin.instance.v1 import ( + spanner_instance_admin_pb2 as admin_v1_pb2) +from google.protobuf.field_mask_pb2 import FieldMask +from grpc import StatusCode + +# pylint: disable=ungrouped-imports +from google.cloud.exceptions import Conflict +from google.cloud.exceptions import NotFound +from google.cloud.iterator import GAXIterator +from google.cloud.operation import register_type +from google.cloud.spanner._helpers import _options_with_prefix +from google.cloud.spanner.database import Database +from google.cloud.spanner.pool import BurstyPool +# pylint: enable=ungrouped-imports + + +_INSTANCE_NAME_RE = re.compile( + r'^projects/(?P[^/]+)/' + r'instances/(?P[a-z][-a-z0-9]*)$') + +DEFAULT_NODE_COUNT = 1 + +register_type(admin_v1_pb2.Instance) +register_type(admin_v1_pb2.CreateInstanceMetadata) +register_type(admin_v1_pb2.UpdateInstanceMetadata) + + +class Instance(object): + """Representation of a Cloud Spanner Instance. 
+ + We can use an :class:`Instance` to: + + * :meth:`reload` itself + * :meth:`create` itself + * :meth:`update` itself + * :meth:`delete` itself + + :type instance_id: str + :param instance_id: The ID of the instance. + + :type client: :class:`~google.cloud.spanner.client.Client` + :param client: The client that owns the instance. Provides + authorization and a project ID. + + :type configuration_name: str + :param configuration_name: Name of the instance configuration defining + how the instance will be created. + Required for instances which do not yet exist. + + :type node_count: int + :param node_count: (Optional) Number of nodes allocated to the instance. + + :type display_name: str + :param display_name: (Optional) The display name for the instance in the + Cloud Console UI. (Must be between 4 and 30 + characters.) If this value is not set in the + constructor, will fall back to the instance ID. + """ + + def __init__(self, + instance_id, + client, + configuration_name=None, + node_count=DEFAULT_NODE_COUNT, + display_name=None): + self.instance_id = instance_id + self._client = client + self.configuration_name = configuration_name + self.node_count = node_count + self.display_name = display_name or instance_id + + def _update_from_pb(self, instance_pb): + """Refresh self from the server-provided protobuf. + + Helper for :meth:`from_pb` and :meth:`reload`. + """ + if not instance_pb.display_name: # Simple field (string) + raise ValueError('Instance protobuf does not contain display_name') + self.display_name = instance_pb.display_name + self.configuration_name = instance_pb.config + self.node_count = instance_pb.node_count + + @classmethod + def from_pb(cls, instance_pb, client): + """Creates an instance from a protobuf. + + :type instance_pb: + :class:`~google.cloud.proto.spanner.admin.instance.v1.spanner_instance_admin_pb2.Instance` + :param instance_pb: An instance protobuf object. + + :type client: :class:`~google.cloud.spanner.client.Client` + :param client: The client that owns the instance. + + :rtype: :class:`Instance` + :returns: The instance parsed from the protobuf response. + :raises: :class:`ValueError <exceptions.ValueError>` if the instance + name does not match + ``projects/{project}/instances/{instance_id}`` + or if the parsed project ID does not match the project ID + on the client. + """ + match = _INSTANCE_NAME_RE.match(instance_pb.name) + if match is None: + raise ValueError('Instance protobuf name was not in the ' + 'expected format.', instance_pb.name) + if match.group('project') != client.project: + raise ValueError('Project ID on instance does not match the ' + 'project ID on the client') + instance_id = match.group('instance_id') + configuration_name = instance_pb.config + + result = cls(instance_id, client, configuration_name) + result._update_from_pb(instance_pb) + return result + + @property + def name(self): + """Instance name used in requests. + + .. note:: + + This property will not change if ``instance_id`` does not, + but the return value is not cached. + + The instance name is of the form + + ``"projects/{project}/instances/{instance_id}"`` + + :rtype: str + :returns: The instance name. + """ + return self._client.project_name + '/instances/' + self.instance_id + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + # NOTE: This does not compare the configuration values, such as + # the display_name. Instead, it only compares + # the identifying values: instance ID and client.
This is + # intentional, since the same instance can be in different states + # if not synchronized. Instances with similar instance + # settings but different clients can't be used in the same way. + return (other.instance_id == self.instance_id and + other._client == self._client) + + def __ne__(self, other): + return not self.__eq__(other) + + def copy(self): + """Make a copy of this instance. + + Copies the local data stored as simple types and copies the client + attached to this instance. + + :rtype: :class:`~google.cloud.spanner.instance.Instance` + :returns: A copy of the current instance. + """ + new_client = self._client.copy() + return self.__class__( + self.instance_id, + new_client, + self.configuration_name, + node_count=self.node_count, + display_name=self.display_name, + ) + + def create(self): + """Create this instance. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.instance.v1#google.spanner.admin.instance.v1.InstanceAdmin.CreateInstance + + .. note:: + + Uses the ``project`` and ``instance_id`` on the current + :class:`Instance` in addition to the ``display_name``. + To change them before creating, reset the values via + + .. code:: python + + instance.display_name = 'New display name' + instance.instance_id = 'i-changed-my-mind' + + before calling :meth:`create`. + + :rtype: :class:`google.cloud.operation.Operation` + :returns: an operation instance + """ + api = self._client.instance_admin_api + instance_pb = admin_v1_pb2.Instance( + name=self.name, + config=self.configuration_name, + display_name=self.display_name, + node_count=self.node_count, + ) + options = _options_with_prefix(self.name) + + try: + future = api.create_instance( + parent=self._client.project_name, + instance_id=self.instance_id, + instance=instance_pb, + options=options, + ) + except GaxError as exc: + if exc_to_code(exc.cause) == StatusCode.ALREADY_EXISTS: + raise Conflict(self.name) + raise + + future.caller_metadata = {'request_type': 'CreateInstance'} + return future + + def exists(self): + """Test whether this instance exists. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.instance.v1#google.spanner.admin.instance.v1.InstanceAdmin.GetInstanceConfig + """ + api = self._client.instance_admin_api + options = _options_with_prefix(self.name) + + try: + api.get_instance(self.name, options=options) + except GaxError as exc: + if exc_to_code(exc.cause) == StatusCode.NOT_FOUND: + return False + raise + + return True + + def reload(self): + """Reload the metadata for this instance. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.instance.v1#google.spanner.admin.instance.v1.InstanceAdmin.GetInstanceConfig + """ + api = self._client.instance_admin_api + options = _options_with_prefix(self.name) + + try: + instance_pb = api.get_instance(self.name, options=options) + except GaxError as exc: + if exc_to_code(exc.cause) == StatusCode.NOT_FOUND: + raise NotFound(self.name) + raise + + self._update_from_pb(instance_pb) + + def update(self): + """Update this instance. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.instance.v1#google.spanner.admin.instance.v1.InstanceAdmin.UpdateInstance + + .. note:: + + Updates the ``display_name`` and ``node_count``. To change those + values before updating, set them via + + .. code:: python + + instance.display_name = 'New display name' + instance.node_count = 5 + + before calling :meth:`update`. 
+ + :rtype: :class:`google.cloud.operation.Operation` + :returns: an operation instance + """ + api = self._client.instance_admin_api + instance_pb = admin_v1_pb2.Instance( + name=self.name, + config=self.configuration_name, + display_name=self.display_name, + node_count=self.node_count, + ) + field_mask = FieldMask(paths=['config', 'display_name', 'node_count']) + options = _options_with_prefix(self.name) + + try: + future = api.update_instance( + instance=instance_pb, + field_mask=field_mask, + options=options, + ) + except GaxError as exc: + if exc_to_code(exc.cause) == StatusCode.NOT_FOUND: + raise NotFound(self.name) + raise + + future.caller_metadata = {'request_type': 'UpdateInstance'} + return future + + def delete(self): + """Mark an instance and all of its databases for permanent deletion. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.instance.v1#google.spanner.admin.instance.v1.InstanceAdmin.DeleteInstance + + Immediately upon completion of the request: + + * Billing will cease for all of the instance's reserved resources. + + Soon afterward: + + * The instance and all databases within the instance will be deleted. + All data in the databases will be permanently deleted. + """ + api = self._client.instance_admin_api + options = _options_with_prefix(self.name) + + try: + api.delete_instance(self.name, options=options) + except GaxError as exc: + if exc_to_code(exc.cause) == StatusCode.NOT_FOUND: + raise NotFound(self.name) + raise + + def database(self, database_id, ddl_statements=(), pool=None): + """Factory to create a database within this instance. + + :type database_id: str + :param database_id: The ID of the database. + + :type ddl_statements: list of string + :param ddl_statements: (Optional) DDL statements, excluding the + 'CREATE DATABASE' statement. + + :type pool: concrete subclass of + :class:`~google.cloud.spanner.pool.AbstractSessionPool`. + :param pool: (Optional) session pool to be used by database. + + :rtype: :class:`~google.cloud.spanner.database.Database` + :returns: a database owned by this instance. + """ + return Database( + database_id, self, ddl_statements=ddl_statements, pool=pool) + + def list_databases(self, page_size=None, page_token=None): + """List databases for the instance. + + See: + https://cloud.google.com/spanner/reference/rpc/google.spanner.admin.database.v1#google.spanner.admin.database.v1.DatabaseAdmin.ListDatabases + + :type page_size: int + :param page_size: (Optional) Maximum number of results to return. + + :type page_token: str + :param page_token: (Optional) Token for fetching next page of results. + + :rtype: :class:`~google.cloud.iterator.Iterator` + :returns: + Iterator of :class:`~google.cloud.spanner.database.Database` + resources within the current instance. + """ + if page_token is None: + page_token = INITIAL_PAGE + options = _options_with_prefix(self.name, page_token=page_token) + page_iter = self._client.database_admin_api.list_databases( + self.name, page_size=page_size, options=options) + iterator = GAXIterator(self._client, page_iter, _item_to_database) + iterator.instance = self + return iterator + + +def _item_to_database(iterator, database_pb): + """Convert a database protobuf to the native object. + + :type iterator: :class:`~google.cloud.iterator.Iterator` + :param iterator: The iterator that is currently in use. + + :type database_pb: :class:`~google.spanner.admin.database.v1.Database` + :param database_pb: A database returned from the API.
+ + :rtype: :class:`~google.cloud.spanner.database.Database` + :returns: The next database in the page. + """ + return Database.from_pb(database_pb, iterator.instance, pool=BurstyPool()) diff --git a/spanner/google/cloud/spanner/keyset.py b/spanner/google/cloud/spanner/keyset.py new file mode 100644 index 000000000000..fe0d5cd1485d --- /dev/null +++ b/spanner/google/cloud/spanner/keyset.py @@ -0,0 +1,113 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrap representation of Spanner keys / ranges.""" + +from google.cloud.proto.spanner.v1.keys_pb2 import KeyRange as KeyRangePB +from google.cloud.proto.spanner.v1.keys_pb2 import KeySet as KeySetPB + +from google.cloud.spanner._helpers import _make_list_value_pb +from google.cloud.spanner._helpers import _make_list_value_pbs + + +class KeyRange(object): + """Identify range of table rows via start / end points. + + :type start_open: list of scalars + :param start_open: keys identifying start of range (this key excluded) + + :type start_closed: list of scalars + :param start_closed: keys identifying start of range (this key included) + + :type end_open: list of scalars + :param end_open: keys identifying end of range (this key excluded) + + :type end_closed: list of scalars + :param end_closed: keys identifying end of range (this key included) + """ + def __init__(self, start_open=None, start_closed=None, + end_open=None, end_closed=None): + if not any([start_open, start_closed, end_open, end_closed]): + raise ValueError("Must specify at least a start or end row.") + + if start_open and start_closed: + raise ValueError("Specify one of 'start_open' / 'start_closed'.") + + if end_open and end_closed: + raise ValueError("Specify one of 'end_open' / 'end_closed'.") + + self.start_open = start_open + self.start_closed = start_closed + self.end_open = end_open + self.end_closed = end_closed + + def to_pb(self): + """Construct a KeyRange protobuf. + + :rtype: :class:`~google.cloud.proto.spanner.v1.keys_pb2.KeyRange` + :returns: protobuf corresponding to this instance. + """ + kwargs = {} + + if self.start_open: + kwargs['start_open'] = _make_list_value_pb(self.start_open) + + if self.start_closed: + kwargs['start_closed'] = _make_list_value_pb(self.start_closed) + + if self.end_open: + kwargs['end_open'] = _make_list_value_pb(self.end_open) + + if self.end_closed: + kwargs['end_closed'] = _make_list_value_pb(self.end_closed) + + return KeyRangePB(**kwargs) + + +class KeySet(object): + """Identify table rows via keys / ranges. + + :type keys: list of list of scalars + :param keys: keys identifying individual rows within a table. + + :type ranges: list of :class:`KeyRange` + :param ranges: ranges identifying rows within a table. 
+ + :type all_: bool + :param all_: if True, identify all rows within a table + """ + def __init__(self, keys=(), ranges=(), all_=False): + if all_ and (keys or ranges): + raise ValueError("'all_' is exclusive of 'keys' / 'ranges'.") + self.keys = list(keys) + self.ranges = list(ranges) + self.all_ = all_ + + def to_pb(self): + """Construct a KeySet protobuf. + + :rtype: :class:`~google.cloud.proto.spanner.v1.keys_pb2.KeySet` + :returns: protobuf corresponding to this instance. + """ + if self.all_: + return KeySetPB(all=True) + kwargs = {} + + if self.keys: + kwargs['keys'] = _make_list_value_pbs(self.keys) + + if self.ranges: + kwargs['ranges'] = [krange.to_pb() for krange in self.ranges] + + return KeySetPB(**kwargs) diff --git a/spanner/google/cloud/spanner/pool.py b/spanner/google/cloud/spanner/pool.py new file mode 100644 index 000000000000..e88f635573f9 --- /dev/null +++ b/spanner/google/cloud/spanner/pool.py @@ -0,0 +1,464 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pools managing shared Session objects.""" + +import datetime + +from six.moves import queue +from six.moves import xrange + +from google.cloud.exceptions import NotFound + + +_NOW = datetime.datetime.utcnow # unit tests may replace + + +class AbstractSessionPool(object): + """Specifies required API for concrete session pool implementations.""" + + _database = None + + def bind(self, database): + """Associate the pool with a database. + + :type database: :class:`~google.cloud.spanner.database.Database` + :param database: database used by the pool: used to create sessions + when needed. + + Concrete implementations of this method may pre-fill the pool + using the database. + """ + raise NotImplementedError() + + def get(self): + """Check a session out from the pool. + + Concrete implementations of this method are allowed to raise an + error to signal that the pool is exhausted, or to block until a + session is available. + """ + raise NotImplementedError() + + def put(self, session): + """Return a session to the pool. + + :type session: :class:`~google.cloud.spanner.session.Session` + :param session: the session being returned. + + Concrete implementations of this method are allowed to raise an + error to signal that the pool is full, or to block until it is + not full. + """ + raise NotImplementedError() + + def clear(self): + """Delete all sessions in the pool. + + Concrete implementations of this method must delete all sessions + currently held by the pool. + """ + raise NotImplementedError() + + def session(self, **kwargs): + """Check out a session from the pool. + + :type kwargs: dict + :param kwargs: (optional) keyword arguments, passed through to + the returned checkout. + + :rtype: :class:`~google.cloud.spanner.pool.SessionCheckout` + :returns: a checkout instance, to be used as a context manager for + accessing the session and returning it to the pool.
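+ + Example (illustrative; assumes a pool already bound to a database, + and ``keyset`` is a previously-built + :class:`~google.cloud.spanner.keyset.KeySet`): + + .. code:: python + + with pool.session() as session: + rows = list(session.read('contacts', ('email',), keyset))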
+ """ + return SessionCheckout(self, **kwargs) + + +class FixedSizePool(AbstractSessionPool): + """Concrete session pool implementation: + + - Pre-allocates / creates a fixed number of sessions. + + - "Pings" existing sessions via :meth:`session.exists` before returning + them, and replaces expired sessions. + + - Blocks, with a timeout, when :meth:`get` is called on an empty pool. + Raises after timing out. + + - Raises when :meth:`put` is called on a full pool. That error is + never expected in normal practice, as users should be calling + :meth:`get` followed by :meth:`put` whenever in need of a session. + + :type size: int + :param size: fixed pool size + + :type default_timeout: int + :param default_timeout: default timeout, in seconds, to wait for + a returned session. + """ + DEFAULT_SIZE = 10 + DEFAULT_TIMEOUT = 10 + + def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT): + self.size = size + self.default_timeout = default_timeout + self._sessions = queue.Queue(size) + + def bind(self, database): + """Associate the pool with a database. + + :type database: :class:`~google.cloud.spanner.database.Database` + :param database: database used by the pool: used to create sessions + when needed. + """ + self._database = database + + while not self._sessions.full(): + session = database.session() + session.create() + self._sessions.put(session) + + def get(self, timeout=None): # pylint: disable=arguments-differ + """Check a session out from the pool. + + :type timeout: int + :param timeout: seconds to block waiting for an available session + + :rtype: :class:`~google.cloud.spanner.session.Session` + :returns: an existing session from the pool, or a newly-created + session. + :raises: :exc:`six.moves.queue.Empty` if the queue is empty. + """ + if timeout is None: + timeout = self.default_timeout + + session = self._sessions.get(block=True, timeout=timeout) + + if not session.exists(): + session = self._database.session() + session.create() + + return session + + def put(self, session): + """Return a session to the pool. + + Never blocks: if the pool is full, raises. + + :type session: :class:`~google.cloud.spanner.session.Session` + :param session: the session being returned. + + :raises: :exc:`six.moves.queue.Full` if the queue is full. + """ + self._sessions.put_nowait(session) + + def clear(self): + """Delete all sessions in the pool.""" + + while True: + try: + session = self._sessions.get(block=False) + except queue.Empty: + break + else: + session.delete() + + +class BurstyPool(AbstractSessionPool): + """Concrete session pool implementation: + + - "Pings" existing sessions via :meth:`session.exists` before returning + them. + + - Creates a new session, rather than blocking, when :meth:`get` is called + on an empty pool. + + - Discards the returned session, rather than blocking, when :meth:`put` + is called on a full pool. + + :type target_size: int + :param target_size: max pool size + """ + + def __init__(self, target_size=10): + self.target_size = target_size + self._database = None + self._sessions = queue.Queue(target_size) + + def bind(self, database): + """Associate the pool with a database. + + :type database: :class:`~google.cloud.spanner.database.Database` + :param database: database used by the pool: used to create sessions + when needed. + """ + self._database = database + + def get(self): + """Check a session out from the pool. 
+ + :rtype: :class:`~google.cloud.spanner.session.Session` + :returns: an existing session from the pool, or a newly-created + session. + """ + try: + session = self._sessions.get_nowait() + except queue.Empty: + session = self._database.session() + session.create() + else: + if not session.exists(): + session = self._database.session() + session.create() + return session + + def put(self, session): + """Return a session to the pool. + + Never blocks: if the pool is full, the returned session is + discarded. + + :type session: :class:`~google.cloud.spanner.session.Session` + :param session: the session being returned. + """ + try: + self._sessions.put_nowait(session) + except queue.Full: + try: + session.delete() + except NotFound: + pass + + def clear(self): + """Delete all sessions in the pool.""" + + while True: + try: + session = self._sessions.get(block=False) + except queue.Empty: + break + else: + session.delete() + + +class PingingPool(AbstractSessionPool): + """Concrete session pool implementation: + + - Pre-allocates / creates a fixed number of sessions. + + - Sessions are used in "round-robin" order (LRU first). + + - "Pings" existing sessions in the background after a specified interval + via an API call (``session.exists()``). + + - Blocks, with a timeout, when :meth:`get` is called on an empty pool. + Raises after timing out. + + - Raises when :meth:`put` is called on a full pool. That error is + never expected in normal practice, as users should be calling + :meth:`get` followed by :meth:`put` whenever in need of a session. + + The application is responsible for calling :meth:`ping` at appropriate + times, e.g. from a background thread. + + :type size: int + :param size: fixed pool size + + :type default_timeout: int + :param default_timeout: default timeout, in seconds, to wait for + a returned session. + + :type ping_interval: int + :param ping_interval: interval at which to ping sessions. + """ + + def __init__(self, size=10, default_timeout=10, ping_interval=3000): + self.size = size + self.default_timeout = default_timeout + self._delta = datetime.timedelta(seconds=ping_interval) + self._sessions = queue.PriorityQueue(size) + + def bind(self, database): + """Associate the pool with a database. + + :type database: :class:`~google.cloud.spanner.database.Database` + :param database: database used by the pool: used to create sessions + when needed. + """ + self._database = database + + for _ in xrange(self.size): + session = database.session() + session.create() + self.put(session) + + def get(self, timeout=None): # pylint: disable=arguments-differ + """Check a session out from the pool. + + :type timeout: int + :param timeout: seconds to block waiting for an available session + + :rtype: :class:`~google.cloud.spanner.session.Session` + :returns: an existing session from the pool, or a newly-created + session. + :raises: :exc:`six.moves.queue.Empty` if the queue is empty. + """ + if timeout is None: + timeout = self.default_timeout + + ping_after, session = self._sessions.get(block=True, timeout=timeout) + + if _NOW() > ping_after: + if not session.exists(): + session = self._database.session() + session.create() + + return session + + def put(self, session): + """Return a session to the pool. + + Never blocks: if the pool is full, raises. + + :type session: :class:`~google.cloud.spanner.session.Session` + :param session: the session being returned. + + :raises: :exc:`six.moves.queue.Full` if the queue is full. 
+        """
+        self._sessions.put_nowait((_NOW() + self._delta, session))
+
+    def clear(self):
+        """Delete all sessions in the pool."""
+        while True:
+            try:
+                _, session = self._sessions.get(block=False)
+            except queue.Empty:
+                break
+            else:
+                session.delete()
+
+    def ping(self):
+        """Refresh maybe-expired sessions in the pool.
+
+        This method is designed to be called from a background thread,
+        or during the "idle" phase of an event loop.
+        """
+        while True:
+            try:
+                ping_after, session = self._sessions.get(block=False)
+            except queue.Empty:  # all sessions in use
+                break
+            if ping_after > _NOW():  # oldest session is fresh
+                # Re-add to queue with existing expiration
+                self._sessions.put((ping_after, session))
+                break
+            if not session.exists():  # stale
+                session = self._database.session()
+                session.create()
+            # Re-add to queue with new expiration
+            self.put(session)
+
+
+class TransactionPingingPool(PingingPool):
+    """Concrete session pool implementation:
+
+    In addition to the features of :class:`PingingPool`, this class
+    creates and begins a transaction for each of its sessions at startup.
+
+    When a session is returned to the pool, if its transaction has been
+    committed or rolled back, the pool creates a new transaction for the
+    session and pushes the transaction onto a separate queue of "transactions
+    to begin." The application is responsible for flushing this queue
+    as appropriate via the pool's :meth:`begin_pending_transactions` method.
+
+    :type size: int
+    :param size: fixed pool size
+
+    :type default_timeout: int
+    :param default_timeout: default timeout, in seconds, to wait for
+                            a returned session.
+
+    :type ping_interval: int
+    :param ping_interval: interval at which to ping sessions.
+    """
+
+    def __init__(self, size=10, default_timeout=10, ping_interval=3000):
+        self._pending_sessions = queue.Queue()
+
+        super(TransactionPingingPool, self).__init__(
+            size, default_timeout, ping_interval)
+
+        self.begin_pending_transactions()
+
+    def bind(self, database):
+        """Associate the pool with a database.
+
+        :type database: :class:`~google.cloud.spanner.database.Database`
+        :param database: database used by the pool: used to create sessions
+                         when needed.
+        """
+        super(TransactionPingingPool, self).bind(database)
+        self.begin_pending_transactions()
+
+    def put(self, session):
+        """Return a session to the pool.
+
+        Never blocks: if the pool is full, raises.
+
+        :type session: :class:`~google.cloud.spanner.session.Session`
+        :param session: the session being returned.
+
+        :raises: :exc:`six.moves.queue.Full` if the queue is full.
+        """
+        if self._sessions.full():
+            raise queue.Full
+
+        txn = session._transaction
+        if txn is None or txn.committed or txn._rolled_back:
+            session.transaction()
+            self._pending_sessions.put(session)
+        else:
+            super(TransactionPingingPool, self).put(session)
+
+    def begin_pending_transactions(self):
+        """Begin all transactions for sessions added to the pool."""
+        while not self._pending_sessions.empty():
+            session = self._pending_sessions.get()
+            session._transaction.begin()
+            super(TransactionPingingPool, self).put(session)
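Taken together, the pool classes above are used roughly as follows — a
minimal sketch, assuming ``database`` is a bound
:class:`~google.cloud.spanner.database.Database` obtained elsewhere; only
the pool calls themselves come from this patch::

    from google.cloud.spanner.pool import PingingPool

    pool = PingingPool(size=10, default_timeout=5, ping_interval=300)
    pool.bind(database)         # pre-creates ``size`` sessions

    # Check a session out, and return it when the block exits:
    with pool.session() as session:
        print(session.name)

    # Call periodically, e.g. from a background thread, to replace
    # sessions that have passed their ping interval:
    pool.ping()

+
+
+class SessionCheckout(object):
+    """Context manager: hold session checked out from a pool.
+
+    :type pool: concrete subclass of
+                :class:`~google.cloud.spanner.pool.AbstractSessionPool`
+    :param pool: Pool from which to check out a session.
+
+    :type kwargs: dict
+    :param kwargs: extra keyword arguments to be passed to :meth:`pool.get`.
+    """
+    _session = None  # Not checked out until '__enter__'.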
+
+    def __init__(self, pool, **kwargs):
+        self._pool = pool
+        self._kwargs = kwargs.copy()
+
+    def __enter__(self):
+        self._session = self._pool.get(**self._kwargs)
+        return self._session
+
+    def __exit__(self, *ignored):
+        self._pool.put(self._session)
diff --git a/spanner/google/cloud/spanner/session.py b/spanner/google/cloud/spanner/session.py
new file mode 100644
index 000000000000..ecf0995938ef
--- /dev/null
+++ b/spanner/google/cloud/spanner/session.py
@@ -0,0 +1,360 @@
+# Copyright 2016 Google Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Wrapper for Cloud Spanner Session objects."""
+
+import time
+
+from google.gax.errors import GaxError
+from google.gax.grpc import exc_to_code
+from google.rpc.error_details_pb2 import RetryInfo
+from grpc import StatusCode
+
+# pylint: disable=ungrouped-imports
+from google.cloud.exceptions import NotFound
+from google.cloud.spanner._helpers import _options_with_prefix
+from google.cloud.spanner.batch import Batch
+from google.cloud.spanner.snapshot import Snapshot
+from google.cloud.spanner.transaction import Transaction
+# pylint: enable=ungrouped-imports
+
+
+DEFAULT_RETRY_TIMEOUT_SECS = 30
+"""Default timeout used by :meth:`Session.run_in_transaction`."""
+
+
+class Session(object):
+    """Representation of a Cloud Spanner Session.
+
+    We can use a :class:`Session` to:
+
+    * :meth:`create` the session
+    * Use :meth:`exists` to check for the existence of the session
+    * :meth:`delete` the session
+
+    :type database: :class:`~google.cloud.spanner.database.Database`
+    :param database: The database to which the session is bound.
+    """
+
+    _session_id = None
+    _transaction = None
+
+    def __init__(self, database):
+        self._database = database
+
+    @property
+    def session_id(self):
+        """Read-only ID, set by the back-end during :meth:`create`."""
+        return self._session_id
+
+    @property
+    def name(self):
+        """Session name used in requests.
+
+        .. note::
+
+           This property will not change if ``session_id`` does not, but the
+           return value is not cached.
+
+        The session name is of the form
+
+        ``"projects/../instances/../databases/../sessions/{session_id}"``
+
+        :rtype: str
+        :returns: The session name.
+        """
+        if self._session_id is None:
+            raise ValueError('No session ID set by back-end')
+        return self._database.name + '/sessions/' + self._session_id
+
+    def create(self):
+        """Create this session, bound to its database.
+
+        See:
+        https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.Spanner.CreateSession
+
+        :raises: :exc:`ValueError` if :attr:`session_id` is already set.
+        """
+        if self._session_id is not None:
+            raise ValueError('Session ID already set by back-end')
+        api = self._database.spanner_api
+        options = _options_with_prefix(self._database.name)
+        session_pb = api.create_session(self._database.name, options=options)
+        self._session_id = session_pb.name.split('/')[-1]
+
+    def exists(self):
+        """Test for the existence of this session.
+
+        See:
+        https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.Spanner.GetSession
+
+        :rtype: bool
+        :returns: True if the session exists on the back-end, else False.
+        """
+        if self._session_id is None:
+            return False
+        api = self._database.spanner_api
+        options = _options_with_prefix(self._database.name)
+        try:
+            api.get_session(self.name, options=options)
+        except GaxError as exc:
+            if exc_to_code(exc.cause) == StatusCode.NOT_FOUND:
+                return False
+            raise
+        else:
+            return True
+
+    def delete(self):
+        """Delete this session.
+
+        See:
+        https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.Spanner.DeleteSession
+
+        :raises: :exc:`ValueError` if :attr:`session_id` is not already set.
+        """
+        if self._session_id is None:
+            raise ValueError('Session ID not set by back-end')
+        api = self._database.spanner_api
+        options = _options_with_prefix(self._database.name)
+        try:
+            api.delete_session(self.name, options=options)
+        except GaxError as exc:
+            if exc_to_code(exc.cause) == StatusCode.NOT_FOUND:
+                raise NotFound(self.name)
+            raise
+
+    def snapshot(self, read_timestamp=None, min_read_timestamp=None,
+                 max_staleness=None, exact_staleness=None):
+        """Create a snapshot to perform a set of reads with shared staleness.
+
+        See:
+        https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.TransactionOptions.ReadOnly
+
+        If no options are passed, reads will use the ``strong`` model, reading
+        at a timestamp where all previously committed transactions are
+        visible.
+
+        :type read_timestamp: :class:`datetime.datetime`
+        :param read_timestamp: Execute all reads at the given timestamp.
+
+        :type min_read_timestamp: :class:`datetime.datetime`
+        :param min_read_timestamp: Execute all reads at a
+                                   timestamp >= ``min_read_timestamp``.
+
+        :type max_staleness: :class:`datetime.timedelta`
+        :param max_staleness: Read data at a
+                              timestamp >= NOW - ``max_staleness`` seconds.
+
+        :type exact_staleness: :class:`datetime.timedelta`
+        :param exact_staleness: Execute all reads at a timestamp that is
+                                ``exact_staleness`` old.
+
+        :rtype: :class:`~google.cloud.spanner.snapshot.Snapshot`
+        :returns: a snapshot bound to this session
+        :raises: :exc:`ValueError` if the session has not yet been created.
+        """
+        if self._session_id is None:
+            raise ValueError("Session has not been created.")
+
+        return Snapshot(self,
+                        read_timestamp=read_timestamp,
+                        min_read_timestamp=min_read_timestamp,
+                        max_staleness=max_staleness,
+                        exact_staleness=exact_staleness)
+
+    def read(self, table, columns, keyset, index='', limit=0,
+             resume_token=b''):
+        """Perform a ``StreamingRead`` API request for rows in a table.
+
+        :type table: str
+        :param table: name of the table from which to fetch data
+
+        :type columns: list of str
+        :param columns: names of columns to be retrieved
+
+        :type keyset: :class:`~google.cloud.spanner.keyset.KeySet`
+        :param keyset: keys / ranges identifying rows to be retrieved
+
+        :type index: str
+        :param index: (Optional) name of index to use, rather than the
+                      table's primary key
+
+        :type limit: int
+        :param limit: (Optional) maximum number of rows to return
+
+        :type resume_token: bytes
+        :param resume_token: token for resuming previously-interrupted read
+
+        :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
+        :returns: a result set instance which can be used to consume rows.
+        """
+        return self.snapshot().read(
+            table, columns, keyset, index, limit, resume_token)
+
+    def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
+                    resume_token=b''):
+        """Perform an ``ExecuteStreamingSql`` API request.
+
+        :type sql: str
+        :param sql: SQL query statement
+
+        :type params: dict, {str -> column value}
+        :param params: values for parameter replacement. Keys must match
+                       the names used in ``sql``.
+
+        :type param_types:
+            dict, {str ->
+            :class:`google.cloud.proto.spanner.v1.type_pb2.TypeCode`}
+        :param param_types: (Optional) explicit types for one or more param
+                            values; overrides default type detection on the
+                            back-end.
+
+        :type query_mode:
+            :class:`google.cloud.proto.spanner.v1.spanner_pb2.ExecuteSqlRequest.QueryMode`
+        :param query_mode: Mode governing return of results / query plan. See:
+            https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1
+
+        :type resume_token: bytes
+        :param resume_token: token for resuming previously-interrupted query
+
+        :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
+        :returns: a result set instance which can be used to consume rows.
+        """
+        return self.snapshot().execute_sql(
+            sql, params, param_types, query_mode, resume_token)
+
+    def batch(self):
+        """Factory to create a batch for this session.
+
+        :rtype: :class:`~google.cloud.spanner.batch.Batch`
+        :returns: a batch bound to this session
+        :raises: :exc:`ValueError` if the session has not yet been created.
+        """
+        if self._session_id is None:
+            raise ValueError("Session has not been created.")
+
+        return Batch(self)
+
+    def transaction(self):
+        """Create a read-write transaction bound to this session.
+
+        :rtype: :class:`~google.cloud.spanner.transaction.Transaction`
+        :returns: a transaction bound to this session
+        :raises: :exc:`ValueError` if the session has not yet been created.
+        """
+        if self._session_id is None:
+            raise ValueError("Session has not been created.")
+
+        if self._transaction is not None:
+            self._transaction._rolled_back = True
+
+        txn = self._transaction = Transaction(self)
+        return txn
+
+    def run_in_transaction(self, func, *args, **kw):
+        """Perform a unit of work in a transaction, retrying on abort.
+
+        :type func: callable
+        :param func: takes a required positional argument, the transaction,
+                     and additional positional / keyword arguments as supplied
+                     by the caller.
+
+        :type args: tuple
+        :param args: additional positional arguments to be passed to ``func``.
+
+        :type kw: dict
+        :param kw: optional keyword arguments to be passed to ``func``.
+                   If passed, "timeout_secs" will be removed and used to
+                   override the default timeout.
+
+        :rtype: :class:`datetime.datetime`
+        :returns: timestamp of committed transaction
+        """
+        deadline = time.time() + kw.pop(
+            'timeout_secs', DEFAULT_RETRY_TIMEOUT_SECS)
+
+        while True:
+            if self._transaction is None:
+                txn = self.transaction()
+            else:
+                txn = self._transaction
+            if txn._id is None:
+                txn.begin()
+            try:
+                func(txn, *args, **kw)
+            except GaxError as exc:
+                _delay_until_retry(exc, deadline)
+                del self._transaction
+                continue
+            except Exception:
+                txn.rollback()
+                del self._transaction
+                raise
+
+            try:
+                txn.commit()
+            except GaxError as exc:
+                _delay_until_retry(exc, deadline)
+                del self._transaction
+            else:
+                return txn.committed
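In practice the session API above composes like this — a minimal sketch,
assuming ``database`` is a bound database; the table, columns, and the
``delete`` mutation on the transaction are illustrative assumptions::

    from google.cloud.spanner.keyset import KeySet
    from google.cloud.spanner.session import Session

    session = Session(database)
    session.create()

    # Read via an implicit single-use strong snapshot:
    everything = KeySet(all_=True)
    for row in session.read('citizens', ['email', 'first_name'], everything):
        print(row)

    # Retryable read-write unit of work; 'run_in_transaction' re-runs
    # 'unit_of_work' if the back-end aborts the transaction:
    def unit_of_work(transaction):
        transaction.delete('citizens', everything)  # assumed mutation API

    committed = session.run_in_transaction(unit_of_work)
    session.delete()

+
+
+# pylint: disable=misplaced-bare-raise
+#
+# Rationale: this function factors out complex shared deadline / retry
+# handling from two `except:` clauses.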
+def _delay_until_retry(exc, deadline):
+    """Helper for :meth:`Session.run_in_transaction`.
+
+    Detect retryable abort, and impose server-supplied delay.
+
+    :type exc: :class:`google.gax.errors.GaxError`
+    :param exc: exception for aborted transaction
+
+    :type deadline: float
+    :param deadline: maximum timestamp to continue retrying the transaction.
+    """
+    if exc_to_code(exc.cause) != StatusCode.ABORTED:
+        raise
+
+    now = time.time()
+
+    if now >= deadline:
+        raise
+
+    delay = _get_retry_delay(exc)
+    if delay is not None:
+        if now + delay > deadline:
+            raise
+
+        time.sleep(delay)
+# pylint: enable=misplaced-bare-raise
+
+
+def _get_retry_delay(exc):
+    """Helper for :func:`_delay_until_retry`.
+
+    :type exc: :class:`google.gax.errors.GaxError`
+    :param exc: exception for aborted transaction
+
+    :rtype: float or None
+    :returns: seconds to wait before retrying the transaction, or None if
+              the exception carries no retry delay.
+    """
+    metadata = dict(exc.cause.trailing_metadata())
+    retry_info_pb = metadata.get('google.rpc.retryinfo-bin')
+    if retry_info_pb is not None:
+        retry_info = RetryInfo()
+        retry_info.ParseFromString(retry_info_pb)
+        nanos = retry_info.retry_delay.nanos
+        return retry_info.retry_delay.seconds + nanos / 1.0e9
diff --git a/spanner/google/cloud/spanner/snapshot.py b/spanner/google/cloud/spanner/snapshot.py
new file mode 100644
index 000000000000..22b39dbc813d
--- /dev/null
+++ b/spanner/google/cloud/spanner/snapshot.py
@@ -0,0 +1,197 @@
+# Copyright 2016 Google Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model a set of read-only queries to a database as a snapshot."""
+
+from google.protobuf.struct_pb2 import Struct
+from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionOptions
+from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionSelector
+
+from google.cloud._helpers import _datetime_to_pb_timestamp
+from google.cloud._helpers import _timedelta_to_duration_pb
+from google.cloud.spanner._helpers import _make_value_pb
+from google.cloud.spanner._helpers import _options_with_prefix
+from google.cloud.spanner._helpers import _SessionWrapper
+from google.cloud.spanner.streamed import StreamedResultSet
+
+
+class _SnapshotBase(_SessionWrapper):
+    """Base class for Snapshot.
+
+    Allows reuse of API request methods with different transaction selector.
+
+    :type session: :class:`~google.cloud.spanner.session.Session`
+    :param session: the session used to perform API requests
+    """
+    def _make_txn_selector(self):  # pylint: disable=redundant-returns-doc
+        """Helper for :meth:`read` / :meth:`execute_sql`.
+
+        Subclasses must override, returning an instance of
+        :class:`transaction_pb2.TransactionSelector`
+        appropriate for making ``read`` / ``execute_sql`` requests
+
+        :raises: NotImplementedError, always
+        """
+        raise NotImplementedError
+
+    def read(self, table, columns, keyset, index='', limit=0,
+             resume_token=b''):
+        """Perform a ``StreamingRead`` API request for rows in a table.
+
+        :type table: str
+        :param table: name of the table from which to fetch data
+
+        :type columns: list of str
+        :param columns: names of columns to be retrieved
+
+        :type keyset: :class:`~google.cloud.spanner.keyset.KeySet`
+        :param keyset: keys / ranges identifying rows to be retrieved
+
+        :type index: str
+        :param index: (Optional) name of index to use, rather than the
+                      table's primary key
+
+        :type limit: int
+        :param limit: (Optional) maximum number of rows to return
+
+        :type resume_token: bytes
+        :param resume_token: token for resuming previously-interrupted read
+
+        :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
+        :returns: a result set instance which can be used to consume rows.
+        """
+        database = self._session._database
+        api = database.spanner_api
+        options = _options_with_prefix(database.name)
+        transaction = self._make_txn_selector()
+
+        iterator = api.streaming_read(
+            self._session.name, table, columns, keyset.to_pb(),
+            transaction=transaction, index=index, limit=limit,
+            resume_token=resume_token, options=options)
+
+        return StreamedResultSet(iterator)
+
+    def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
+                    resume_token=b''):
+        """Perform an ``ExecuteStreamingSql`` API request for rows in a table.
+
+        :type sql: str
+        :param sql: SQL query statement
+
+        :type params: dict, {str -> column value}
+        :param params: values for parameter replacement. Keys must match
+                       the names used in ``sql``.
+
+        :type param_types: dict
+        :param param_types:
+            (Optional) maps explicit types for one or more param values;
+            required if parameters are passed.
+
+        :type query_mode:
+            :class:`google.cloud.proto.spanner.v1.spanner_pb2.ExecuteSqlRequest.QueryMode`
+        :param query_mode: Mode governing return of results / query plan. See:
+            https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1
+
+        :type resume_token: bytes
+        :param resume_token: token for resuming previously-interrupted query
+
+        :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
+        :returns: a result set instance which can be used to consume rows.
+        """
+        if params is not None:
+            if param_types is None:
+                raise ValueError(
+                    "Specify 'param_types' when passing 'params'.")
+            params_pb = Struct(fields={
+                key: _make_value_pb(value) for key, value in params.items()})
+        else:
+            params_pb = None
+
+        database = self._session._database
+        options = _options_with_prefix(database.name)
+        transaction = self._make_txn_selector()
+        api = database.spanner_api
+        iterator = api.execute_streaming_sql(
+            self._session.name, sql,
+            transaction=transaction, params=params_pb,
+            param_types=param_types, query_mode=query_mode,
+            resume_token=resume_token, options=options)
+
+        return StreamedResultSet(iterator)
+
+
+class Snapshot(_SnapshotBase):
+    """Allow a set of reads / SQL statements with shared staleness.
+
+    See:
+    https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.TransactionOptions.ReadOnly
+
+    If no options are passed, reads will use the ``strong`` model, reading
+    at a timestamp where all previously committed transactions are visible.
+
+    :type session: :class:`~google.cloud.spanner.session.Session`
+    :param session: the session used to perform the reads.
+
+    :type read_timestamp: :class:`datetime.datetime`
+    :param read_timestamp: Execute all reads at the given timestamp.
+ + :type min_read_timestamp: :class:`datetime.datetime` + :param min_read_timestamp: Execute all reads at a + timestamp >= ``min_read_timestamp``. + + :type max_staleness: :class:`datetime.timedelta` + :param max_staleness: Read data at a + timestamp >= NOW - ``max_staleness`` seconds. + + :type exact_staleness: :class:`datetime.timedelta` + :param exact_staleness: Execute all reads at a timestamp that is + ``exact_staleness`` old. + """ + def __init__(self, session, read_timestamp=None, min_read_timestamp=None, + max_staleness=None, exact_staleness=None): + super(Snapshot, self).__init__(session) + opts = [ + read_timestamp, min_read_timestamp, max_staleness, exact_staleness] + flagged = [opt for opt in opts if opt is not None] + + if len(flagged) > 1: + raise ValueError("Supply zero or one options.") + + self._strong = len(flagged) == 0 + self._read_timestamp = read_timestamp + self._min_read_timestamp = min_read_timestamp + self._max_staleness = max_staleness + self._exact_staleness = exact_staleness + + def _make_txn_selector(self): + """Helper for :meth:`read`.""" + if self._read_timestamp: + key = 'read_timestamp' + value = _datetime_to_pb_timestamp(self._read_timestamp) + elif self._min_read_timestamp: + key = 'min_read_timestamp' + value = _datetime_to_pb_timestamp(self._min_read_timestamp) + elif self._max_staleness: + key = 'max_staleness' + value = _timedelta_to_duration_pb(self._max_staleness) + elif self._exact_staleness: + key = 'exact_staleness' + value = _timedelta_to_duration_pb(self._exact_staleness) + else: + key = 'strong' + value = True + + options = TransactionOptions( + read_only=TransactionOptions.ReadOnly(**{key: value})) + return TransactionSelector(single_use=options) diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py new file mode 100644 index 000000000000..74c7e8754334 --- /dev/null +++ b/spanner/google/cloud/spanner/streamed.py @@ -0,0 +1,262 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapper for streaming results.""" + +from google.protobuf.struct_pb2 import ListValue +from google.protobuf.struct_pb2 import Value +from google.cloud.proto.spanner.v1 import type_pb2 +import six + +# pylint: disable=ungrouped-imports +from google.cloud.spanner._helpers import _parse_value_pb +# pylint: enable=ungrouped-imports + + +class StreamedResultSet(object): + """Process a sequence of partial result sets into a single set of row data. + + :type response_iterator: + :param response_iterator: + Iterator yielding + :class:`google.cloud.proto.spanner.v1.result_set_pb2.PartialResultSet` + instances. 
+    """
+    def __init__(self, response_iterator):
+        self._response_iterator = response_iterator
+        self._rows = []             # Fully-processed rows
+        self._counter = 0           # Counter for processed responses
+        self._metadata = None       # Until set from first PRS
+        self._stats = None          # Until set from last PRS
+        self._resume_token = None   # To resume from last received PRS
+        self._current_row = []      # Accumulated values for incomplete row
+        self._pending_chunk = None  # Incomplete value
+
+    @property
+    def rows(self):
+        """Fully-processed rows.
+
+        :rtype: list of row-data lists.
+        :returns: list of completed row data, from processed PRS responses.
+        """
+        return self._rows
+
+    @property
+    def fields(self):
+        """Field descriptors for result set columns.
+
+        :rtype: list of :class:`~google.cloud.proto.spanner.v1.type_pb2.Field`
+        :returns: list of fields describing column names / types.
+        """
+        return self._metadata.row_type.fields
+
+    @property
+    def metadata(self):
+        """Result set metadata.
+
+        :rtype: :class:`~.result_set_pb2.ResultSetMetadata`
+        :returns: structure describing the results
+        """
+        return self._metadata
+
+    @property
+    def stats(self):
+        """Result set statistics.
+
+        :rtype:
+            :class:`~google.cloud.proto.spanner.v1.result_set_pb2.ResultSetStats`
+        :returns: structure describing statistics for the response
+        """
+        return self._stats
+
+    @property
+    def resume_token(self):
+        """Token for resuming interrupted read / query.
+
+        :rtype: bytes
+        :returns: token from last chunk of results.
+        """
+        return self._resume_token
+
+    def _merge_chunk(self, value):
+        """Merge pending chunk with next value.
+
+        :type value: :class:`~google.protobuf.struct_pb2.Value`
+        :param value: continuation of chunked value from previous
+                      partial result set.
+
+        :rtype: :class:`~google.protobuf.struct_pb2.Value`
+        :returns: the merged value
+        """
+        current_column = len(self._current_row)
+        field = self.fields[current_column]
+        merged = _merge_by_type(self._pending_chunk, value, field.type)
+        self._pending_chunk = None
+        return merged
+
+    def _merge_values(self, values):
+        """Merge values into rows.
+
+        :type values: list of :class:`~google.protobuf.struct_pb2.Value`
+        :param values: non-chunked values from partial result set.
+        """
+        width = len(self.fields)
+        for value in values:
+            index = len(self._current_row)
+            field = self.fields[index]
+            self._current_row.append(_parse_value_pb(value, field.type))
+            if len(self._current_row) == width:
+                self._rows.append(self._current_row)
+                self._current_row = []
+
+    def consume_next(self):
+        """Consume the next partial result set from the stream.
+ + Parse the result set into new/existing rows in :attr:`_rows` + """ + response = six.next(self._response_iterator) + self._counter += 1 + self._resume_token = response.resume_token + + if self._metadata is None: # first response + self._metadata = response.metadata + + if response.HasField('stats'): # last response + self._stats = response.stats + + values = list(response.values) + if self._pending_chunk is not None: + values[0] = self._merge_chunk(values[0]) + + if response.chunked_value: + self._pending_chunk = values.pop() + + self._merge_values(values) + + def consume_all(self): + """Consume the streamed responses until there are no more.""" + while True: + try: + self.consume_next() + except StopIteration: + break + + def __iter__(self): + iter_rows, self._rows[:] = self._rows[:], () + while True: + if len(iter_rows) == 0: + self.consume_next() # raises StopIteration + iter_rows, self._rows[:] = self._rows[:], () + while iter_rows: + yield iter_rows.pop(0) + + +class Unmergeable(ValueError): + """Unable to merge two values. + + :type lhs: :class:`google.protobuf.struct_pb2.Value` + :param lhs: pending value to be merged + + :type rhs: :class:`google.protobuf.struct_pb2.Value` + :param rhs: remaining value to be merged + + :type type_: :class:`google.cloud.proto.spanner.v1.type_pb2.Type` + :param type_: field type of values being merged + """ + def __init__(self, lhs, rhs, type_): + message = "Cannot merge %s values: %s %s" % ( + type_pb2.TypeCode.Name(type_.code), lhs, rhs) + super(Unmergeable, self).__init__(message) + + +def _unmergeable(lhs, rhs, type_): + """Helper for '_merge_by_type'.""" + raise Unmergeable(lhs, rhs, type_) + + +def _merge_float64(lhs, rhs, type_): # pylint: disable=unused-argument + """Helper for '_merge_by_type'.""" + lhs_kind = lhs.WhichOneof('kind') + if lhs_kind == 'string_value': + return Value(string_value=lhs.string_value + rhs.string_value) + rhs_kind = rhs.WhichOneof('kind') + array_continuation = ( + lhs_kind == 'number_value' and + rhs_kind == 'string_value' and + rhs.string_value == '') + if array_continuation: + return lhs + raise Unmergeable(lhs, rhs, type_) + + +def _merge_string(lhs, rhs, type_): # pylint: disable=unused-argument + """Helper for '_merge_by_type'.""" + return Value(string_value=lhs.string_value + rhs.string_value) + + +_UNMERGEABLE_TYPES = (type_pb2.BOOL,) + + +def _merge_array(lhs, rhs, type_): + """Helper for '_merge_by_type'.""" + element_type = type_.array_element_type + if element_type.code in _UNMERGEABLE_TYPES: + # Individual values cannot be merged, just concatenate + lhs.list_value.values.extend(rhs.list_value.values) + return lhs + lhs, rhs = list(lhs.list_value.values), list(rhs.list_value.values) + first = rhs.pop(0) + if first.HasField('null_value'): # can't merge + lhs.append(first) + else: + last = lhs.pop() + try: + merged = _merge_by_type(last, first, element_type) + except Unmergeable: + lhs.append(last) + lhs.append(first) + else: + lhs.append(merged) + return Value(list_value=ListValue(values=(lhs + rhs))) + + +def _merge_struct(lhs, rhs, type_): + """Helper for '_merge_by_type'.""" + fields = type_.struct_type.fields + lhs, rhs = list(lhs.list_value.values), list(rhs.list_value.values) + candidate_type = fields[len(lhs) - 1].type + first = rhs.pop(0) + if (first.HasField('null_value') or + candidate_type.code in _UNMERGEABLE_TYPES): + lhs.append(first) + else: + last = lhs.pop() + lhs.append(_merge_by_type(last, first, candidate_type)) + return Value(list_value=ListValue(values=lhs + rhs)) + + 
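As a concrete illustration of the dispatch table that follows, two halves of
a chunked ``STRING`` value recombine by simple concatenation — a small sketch
using the merge helpers in this module::

    from google.protobuf.struct_pb2 import Value
    from google.cloud.proto.spanner.v1.type_pb2 import Type, STRING

    lhs = Value(string_value='abc')     # tail of one PartialResultSet
    rhs = Value(string_value='def')     # head of the next one
    merged = _merge_by_type(lhs, rhs, Type(code=STRING))
    assert merged.string_value == 'abcdef'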
+_MERGE_BY_TYPE = { + type_pb2.BOOL: _unmergeable, + type_pb2.INT64: _merge_string, + type_pb2.FLOAT64: _merge_float64, + type_pb2.STRING: _merge_string, + type_pb2.ARRAY: _merge_array, + type_pb2.STRUCT: _merge_struct, +} + + +def _merge_by_type(lhs, rhs, type_): + """Helper for '_merge_chunk'.""" + merger = _MERGE_BY_TYPE[type_.code] + return merger(lhs, rhs, type_) diff --git a/spanner/google/cloud/spanner/transaction.py b/spanner/google/cloud/spanner/transaction.py new file mode 100644 index 000000000000..af2140896830 --- /dev/null +++ b/spanner/google/cloud/spanner/transaction.py @@ -0,0 +1,129 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Spanner read-write transaction support.""" + +from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionSelector +from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionOptions + +from google.cloud._helpers import _pb_timestamp_to_datetime +from google.cloud.spanner._helpers import _options_with_prefix +from google.cloud.spanner.snapshot import _SnapshotBase +from google.cloud.spanner.batch import _BatchBase + + +class Transaction(_SnapshotBase, _BatchBase): + """Implement read-write transaction semantics for a session.""" + committed = None + """Timestamp at which the transaction was successfully committed.""" + + def __init__(self, session): + super(Transaction, self).__init__(session) + self._id = None + self._rolled_back = False + + def _check_state(self): + """Helper for :meth:`commit` et al. + + :raises: :exc:`ValueError` if the object's state is invalid for making + API requests. + """ + if self._id is None: + raise ValueError("Transaction is not begun") + + if self.committed is not None: + raise ValueError("Transaction is already committed") + + if self._rolled_back: + raise ValueError("Transaction is already rolled back") + + def _make_txn_selector(self): + """Helper for :meth:`read`. + + :rtype: + :class:`~.transaction_pb2.TransactionSelector` + :returns: a selector configured for read-write transaction semantics. + """ + self._check_state() + return TransactionSelector(id=self._id) + + def begin(self): + """Begin a transaction on the database. + + :rtype: bytes + :returns: the ID for the newly-begun transaction. + :raises: ValueError if the transaction is already begun, committed, + or rolled back. 
+ """ + if self._id is not None: + raise ValueError("Transaction already begun") + + if self.committed is not None: + raise ValueError("Transaction already committed") + + if self._rolled_back: + raise ValueError("Transaction is already rolled back") + + database = self._session._database + api = database.spanner_api + options = _options_with_prefix(database.name) + txn_options = TransactionOptions( + read_write=TransactionOptions.ReadWrite()) + response = api.begin_transaction( + self._session.name, txn_options, options=options) + self._id = response.id + return self._id + + def rollback(self): + """Roll back a transaction on the database.""" + self._check_state() + database = self._session._database + api = database.spanner_api + options = _options_with_prefix(database.name) + api.rollback(self._session.name, self._id, options=options) + self._rolled_back = True + + def commit(self): + """Commit mutations to the database. + + :rtype: datetime + :returns: timestamp of the committed changes. + :raises: :exc:`ValueError` if there are no mutations to commit. + """ + self._check_state() + + if len(self._mutations) == 0: + raise ValueError("No mutations to commit") + + database = self._session._database + api = database.spanner_api + options = _options_with_prefix(database.name) + response = api.commit( + self._session.name, self._mutations, + transaction_id=self._id, options=options) + self.committed = _pb_timestamp_to_datetime( + response.commit_timestamp) + return self.committed + + def __enter__(self): + """Begin ``with`` block.""" + self.begin() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """End ``with`` block.""" + if exc_type is None: + self.commit() + else: + self.rollback() diff --git a/spanner/setup.cfg b/spanner/setup.cfg new file mode 100644 index 000000000000..2a9acf13daa9 --- /dev/null +++ b/spanner/setup.cfg @@ -0,0 +1,2 @@ +[bdist_wheel] +universal = 1 diff --git a/spanner/setup.py b/spanner/setup.py new file mode 100644 index 000000000000..aa35996b2a50 --- /dev/null +++ b/spanner/setup.py @@ -0,0 +1,72 @@ +# Copyright 2016 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from setuptools import find_packages +from setuptools import setup + + +PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__)) + +with open(os.path.join(PACKAGE_ROOT, 'README.rst')) as file_obj: + README = file_obj.read() + +# NOTE: This is duplicated throughout and we should try to +# consolidate. 
+SETUP_BASE = { + 'author': 'Google Cloud Platform', + 'author_email': 'jjg+google-cloud-python@google.com', + 'scripts': [], + 'url': 'https://github.com/GoogleCloudPlatform/google-cloud-python', + 'license': 'Apache 2.0', + 'platforms': 'Posix; MacOS X; Windows', + 'include_package_data': True, + 'zip_safe': False, + 'classifiers': [ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Topic :: Internet', + ], +} + + +REQUIREMENTS = [ + 'google-cloud-core >= 0.23.0, < 0.24dev', + 'grpcio >= 1.0.2, < 2.0dev', + 'gapic-google-cloud-spanner-v1 >= 0.15.0, < 0.16dev', + 'gapic-google-cloud-spanner-admin-database-v1 >= 0.15.0, < 0.16dev', + 'gapic-google-cloud-spanner-admin-instance-v1 >= 0.15.0, < 0.16dev', +] + +setup( + name='google-cloud-spanner', + version='0.23.0', + description='Python Client for Cloud Spanner', + long_description=README, + namespace_packages=[ + 'google', + 'google.cloud', + ], + packages=find_packages(), + install_requires=REQUIREMENTS, + **SETUP_BASE +) diff --git a/spanner/tox.ini b/spanner/tox.ini new file mode 100644 index 000000000000..9e509cc9b05e --- /dev/null +++ b/spanner/tox.ini @@ -0,0 +1,31 @@ +[tox] +envlist = + py27,py34,py35,cover + +[testing] +deps = + {toxinidir}/../core + pytest + mock +covercmd = + py.test --quiet \ + --cov=google.cloud.spanner \ + --cov=unit_tests \ + --cov-config {toxinidir}/.coveragerc \ + unit_tests + +[testenv] +commands = + py.test --quiet {posargs} unit_tests +deps = + {[testing]deps} + +[testenv:cover] +basepython = + python2.7 +commands = + {[testing]covercmd} +deps = + {[testenv]deps} + coverage + pytest-cov diff --git a/spanner/unit_tests/__init__.py b/spanner/unit_tests/__init__.py new file mode 100644 index 000000000000..58e0d9153632 --- /dev/null +++ b/spanner/unit_tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
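The JSON fixture below drives chunk-merging acceptance tests: each case lists
the serialized ``PartialResultSet`` chunks received, plus the row values they
should yield once merged. A sketch of how a harness might replay one case —
``json_format`` is the standard protobuf helper; everything else comes from
this patch::

    import json

    from google.protobuf.json_format import Parse
    from google.cloud.proto.spanner.v1.result_set_pb2 import PartialResultSet
    from google.cloud.spanner.streamed import StreamedResultSet

    with open('spanner/unit_tests/streaming-read-acceptance-test.json') as fh:
        cases = json.load(fh)['tests']

    case = cases[0]
    chunks = [Parse(chunk, PartialResultSet()) for chunk in case['chunks']]
    streamed = StreamedResultSet(iter(chunks))
    streamed.consume_all()
    # 'streamed.rows' now holds the merged rows, decoded per the field
    # types in the metadata (compare against case['result']['value']).
    print(streamed.rows)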
diff --git a/spanner/unit_tests/streaming-read-acceptance-test.json b/spanner/unit_tests/streaming-read-acceptance-test.json new file mode 100644 index 000000000000..9b44b4077812 --- /dev/null +++ b/spanner/unit_tests/streaming-read-acceptance-test.json @@ -0,0 +1,217 @@ +{"tests": [ + { + "result": {"value": [[ + true, + "abc", + "100", + 1.1, + "YWJj", + [ + "abc", + "def", + null, + "ghi" + ], + [ + ["abc"], + ["def"], + ["ghi"] + ] + ]]}, + "chunks": ["{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"BOOL\"\n }\n }, {\n \"name\": \"f2\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }, {\n \"name\": \"f3\",\n \"type\": {\n \"code\": \"INT64\"\n }\n }, {\n \"name\": \"f4\",\n \"type\": {\n \"code\": \"FLOAT64\"\n }\n }, {\n \"name\": \"f5\",\n \"type\": {\n \"code\": \"BYTES\"\n }\n }, {\n \"name\": \"f6\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"STRING\"\n }\n }\n }, {\n \"name\": \"f7\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"STRUCT\",\n \"structType\": {\n \"fields\": [{\n \"name\": \"f71\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }]\n }\n }\n }\n }]\n }\n },\n \"values\": [true, \"abc\", \"100\", 1.1, \"YWJj\", [\"abc\", \"def\", null, \"ghi\"], [[\"abc\"], [\"def\"], [\"ghi\"]]]\n}"], + "name": "Basic Test" + }, + { + "result": {"value": [["abcdefghi"]]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }]\n }\n },\n \"values\": [\"abc\"],\n \"chunkedValue\": true\n}", + "{\n \"values\": [\"def\"],\n \"chunkedValue\": true\n}", + "{\n \"values\": [\"ghi\"]\n}" + ], + "name": "String Chunking Test" + }, + { + "result": {"value": [[[ + "abc", + "def", + "ghi", + "jkl" + ]]]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"STRING\"\n }\n }\n }]\n }\n },\n \"values\": [[\"abc\", \"d\"]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[\"ef\", \"gh\"]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[\"i\", \"jkl\"]]\n}" + ], + "name": "String Array Chunking Test" + }, + { + "result": {"value": [[[ + "abc", + "def", + null, + "ghi", + null, + "jkl" + ]]]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"STRING\"\n }\n }\n }]\n }\n },\n \"values\": [[\"abc\", \"def\"]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[null, \"ghi\"]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[null, \"jkl\"]]\n}" + ], + "name": "String Array Chunking Test With Nulls" + }, + { + "result": {"value": [[[ + "abc", + "def", + "ghi", + "jkl" + ]]]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"STRING\"\n }\n }\n }]\n }\n },\n \"values\": [[\"abc\", \"def\"]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[\"\", \"ghi\"]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[\"\", \"jkl\"]]\n}" + ], + "name": "String Array Chunking Test With Empty Strings" + }, + { + "result": {"value": [[["abcdefghi"]]]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"STRING\"\n }\n }\n }]\n }\n },\n \"values\": 
[[\"abc\"]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[\"def\"]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[\"ghi\"]]\n}" + ], + "name": "String Array Chunking Test With One Large String" + }, + { + "result": {"value": [[[ + "1", + "23", + "4", + null, + 5 + ]]]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"INT64\"\n }\n }\n }]\n }\n },\n \"values\": [[\"1\", \"2\"]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[\"3\", \"4\"]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[\"\", null, \"5\"]]\n}" + ], + "name": "INT64 Array Chunking Test" + }, + { + "result": {"value": [[[ + 1, + 2, + "Infinity", + "-Infinity", + "NaN", + null, + 3 + ]]]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"FLOAT64\"\n }\n }\n }]\n }\n },\n \"values\": [[1.0, 2.0]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[\"Infinity\", \"-Infinity\", \"NaN\"]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[\"\", null, 3.0]]\n}" + ], + "name": "FLOAT64 Array Chunking Test" + }, + { + "result": {"value": [[[ + [ + "abc", + "defghi" + ], + [ + "123", + "456" + ] + ]]]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"STRUCT\",\n \"structType\": {\n \"fields\": [{\n \"name\": \"f11\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }, {\n \"name\": \"f12\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }]\n }\n }\n }\n }]\n }\n },\n \"values\": [[[\"abc\", \"def\"]]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[[\"ghi\"], [\"123\", \"456\"]]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[[\"\"]]]\n}" + ], + "name": "Struct Array Chunking Test" + }, + { + "result": {"value": [[[[[["abc"]]]]]]}, + "chunks": ["{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"STRUCT\",\n \"structType\": {\n \"fields\": [{\n \"name\": \"f11\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"STRUCT\",\n \"structType\": {\n \"fields\": [{\n \"name\": \"f12\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }]\n }\n }\n }\n }]\n }\n }\n }\n }]\n }\n },\n \"values\": [[[[[\"abc\"]]]]]\n}"], + "name": "Nested Struct Array Test" + }, + { + "result": {"value": [[[[[ + ["abc"], + ["def"] + ]]]]]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"STRUCT\",\n \"structType\": {\n \"fields\": [{\n \"name\": \"f11\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n \"code\": \"STRUCT\",\n \"structType\": {\n \"fields\": [{\n \"name\": \"f12\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }]\n }\n }\n }\n }]\n }\n }\n }\n }]\n }\n },\n \"values\": [[[[[\"ab\"]]]]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[[[[\"c\"], [\"def\"]]]]]\n}" + ], + "name": "Nested Struct Array Chunking Test" + }, + { + "result": {"value": [ + [ + "1", + [["ab"]] + ], + [ + "2", + [["c"]] + ] + ]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }, {\n \"name\": \"f2\",\n \"type\": {\n \"code\": \"ARRAY\",\n \"arrayElementType\": {\n 
\"code\": \"STRUCT\",\n \"structType\": {\n \"fields\": [{\n \"name\": \"f21\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }]\n }\n }\n }\n }]\n }\n },\n \"values\": [\"1\", [[\"a\"]]],\n \"chunkedValue\": true\n}", + "{\n \"values\": [[[\"b\"]], \"2\"],\n \"chunkedValue\": true\n}", + "{\n \"values\": [\"\", [[\"c\"]]]\n}" + ], + "name": "Struct Array And String Chunking Test" + }, + { + "result": {"value": [ + [ + "abc", + "1" + ], + [ + "def", + "2" + ] + ]}, + "chunks": ["{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }, {\n \"name\": \"f2\",\n \"type\": {\n \"code\": \"INT64\"\n }\n }]\n }\n },\n \"values\": [\"abc\", \"1\", \"def\", \"2\"]\n}"], + "name": "Multiple Row Single Chunk" + }, + { + "result": {"value": [ + [ + "abc", + "1" + ], + [ + "def", + "2" + ] + ]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }, {\n \"name\": \"f2\",\n \"type\": {\n \"code\": \"INT64\"\n }\n }]\n }\n },\n \"values\": [\"ab\"],\n \"chunkedValue\": true\n}", + "{\n \"values\": [\"c\", \"1\", \"de\"],\n \"chunkedValue\": true\n}", + "{\n \"values\": [\"f\", \"2\"]\n}" + ], + "name": "Multiple Row Multiple Chunks" + }, + { + "result": {"value": [ + ["ab"], + ["c"], + ["d"], + ["ef"] + ]}, + "chunks": [ + "{\n \"metadata\": {\n \"rowType\": {\n \"fields\": [{\n \"name\": \"f1\",\n \"type\": {\n \"code\": \"STRING\"\n }\n }]\n }\n },\n \"values\": [\"a\"],\n \"chunkedValue\": true\n}", + "{\n \"values\": [\"b\", \"c\"]\n}", + "{\n \"values\": [\"d\", \"e\"],\n \"chunkedValue\": true\n}", + "{\n \"values\": [\"f\"]\n}" + ], + "name": "Multiple Row Chunks/Non Chunks Interleaved" + } +]} diff --git a/spanner/unit_tests/test__helpers.py b/spanner/unit_tests/test__helpers.py new file mode 100644 index 000000000000..2b432d446ab0 --- /dev/null +++ b/spanner/unit_tests/test__helpers.py @@ -0,0 +1,498 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import unittest + + +class TestTimestampWithNanoseconds(unittest.TestCase): + + def _get_target_class(self): + from google.cloud.spanner._helpers import TimestampWithNanoseconds + return TimestampWithNanoseconds + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def test_ctor_wo_nanos(self): + stamp = self._make_one(2016, 12, 20, 21, 13, 47, 123456) + self.assertEqual(stamp.year, 2016) + self.assertEqual(stamp.month, 12) + self.assertEqual(stamp.day, 20) + self.assertEqual(stamp.hour, 21) + self.assertEqual(stamp.minute, 13) + self.assertEqual(stamp.second, 47) + self.assertEqual(stamp.microsecond, 123456) + self.assertEqual(stamp.nanosecond, 0) + + def test_ctor_w_nanos(self): + stamp = self._make_one( + 2016, 12, 20, 21, 13, 47, nanosecond=123456789) + self.assertEqual(stamp.year, 2016) + self.assertEqual(stamp.month, 12) + self.assertEqual(stamp.day, 20) + self.assertEqual(stamp.hour, 21) + self.assertEqual(stamp.minute, 13) + self.assertEqual(stamp.second, 47) + self.assertEqual(stamp.microsecond, 123456) + self.assertEqual(stamp.nanosecond, 123456789) + + def test_ctor_w_micros_positional_and_nanos(self): + with self.assertRaises(TypeError): + self._make_one( + 2016, 12, 20, 21, 13, 47, 123456, nanosecond=123456789) + + def test_ctor_w_micros_keyword_and_nanos(self): + with self.assertRaises(TypeError): + self._make_one( + 2016, 12, 20, 21, 13, 47, + microsecond=123456, nanosecond=123456789) + + def test_rfc339_wo_nanos(self): + stamp = self._make_one(2016, 12, 20, 21, 13, 47, 123456) + self.assertEqual(stamp.rfc3339(), + '2016-12-20T21:13:47.123456Z') + + def test_rfc339_w_nanos(self): + stamp = self._make_one(2016, 12, 20, 21, 13, 47, nanosecond=123456789) + self.assertEqual(stamp.rfc3339(), + '2016-12-20T21:13:47.123456789Z') + + def test_rfc339_w_nanos_no_trailing_zeroes(self): + stamp = self._make_one(2016, 12, 20, 21, 13, 47, nanosecond=100000000) + self.assertEqual(stamp.rfc3339(), + '2016-12-20T21:13:47.1Z') + + def test_from_rfc3339_w_invalid(self): + klass = self._get_target_class() + STAMP = '2016-12-20T21:13:47' + with self.assertRaises(ValueError): + klass.from_rfc3339(STAMP) + + def test_from_rfc3339_wo_fraction(self): + from google.cloud._helpers import UTC + klass = self._get_target_class() + STAMP = '2016-12-20T21:13:47Z' + expected = self._make_one(2016, 12, 20, 21, 13, 47, tzinfo=UTC) + stamp = klass.from_rfc3339(STAMP) + self.assertEqual(stamp, expected) + + def test_from_rfc3339_w_partial_precision(self): + from google.cloud._helpers import UTC + klass = self._get_target_class() + STAMP = '2016-12-20T21:13:47.1Z' + expected = self._make_one(2016, 12, 20, 21, 13, 47, + microsecond=100000, tzinfo=UTC) + stamp = klass.from_rfc3339(STAMP) + self.assertEqual(stamp, expected) + + def test_from_rfc3339_w_full_precision(self): + from google.cloud._helpers import UTC + klass = self._get_target_class() + STAMP = '2016-12-20T21:13:47.123456789Z' + expected = self._make_one(2016, 12, 20, 21, 13, 47, + nanosecond=123456789, tzinfo=UTC) + stamp = klass.from_rfc3339(STAMP) + self.assertEqual(stamp, expected) + + +class Test_make_value_pb(unittest.TestCase): + + def _callFUT(self, *args, **kw): + from google.cloud.spanner._helpers import _make_value_pb + return _make_value_pb(*args, **kw) + + def test_w_None(self): + value_pb = self._callFUT(None) + self.assertTrue(value_pb.HasField('null_value')) + + def test_w_bytes(self): + from google.protobuf.struct_pb2 import Value + BYTES = b'BYTES' + expected = Value(string_value=BYTES) + 
value_pb = self._callFUT(BYTES) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb, expected) + + def test_w_invalid_bytes(self): + BYTES = b'\xff\xfe\x03&' + with self.assertRaises(ValueError): + self._callFUT(BYTES) + + def test_w_explicit_unicode(self): + from google.protobuf.struct_pb2 import Value + TEXT = u'TEXT' + value_pb = self._callFUT(TEXT) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.string_value, TEXT) + + def test_w_list(self): + from google.protobuf.struct_pb2 import Value + from google.protobuf.struct_pb2 import ListValue + value_pb = self._callFUT([u'a', u'b', u'c']) + self.assertIsInstance(value_pb, Value) + self.assertIsInstance(value_pb.list_value, ListValue) + values = value_pb.list_value.values + self.assertEqual([value.string_value for value in values], + [u'a', u'b', u'c']) + + def test_w_bool(self): + from google.protobuf.struct_pb2 import Value + value_pb = self._callFUT(True) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.bool_value, True) + + def test_w_int(self): + import six + from google.protobuf.struct_pb2 import Value + for int_type in six.integer_types: # include 'long' on Python 2 + value_pb = self._callFUT(int_type(42)) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.string_value, '42') + + def test_w_float(self): + from google.protobuf.struct_pb2 import Value + value_pb = self._callFUT(3.14159) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.number_value, 3.14159) + + def test_w_float_nan(self): + from google.protobuf.struct_pb2 import Value + value_pb = self._callFUT(float('nan')) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.string_value, 'NaN') + + def test_w_float_neg_inf(self): + from google.protobuf.struct_pb2 import Value + value_pb = self._callFUT(float('-inf')) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.string_value, '-inf') + + def test_w_float_pos_inf(self): + from google.protobuf.struct_pb2 import Value + value_pb = self._callFUT(float('inf')) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.string_value, 'inf') + + def test_w_date(self): + import datetime + from google.protobuf.struct_pb2 import Value + today = datetime.date.today() + value_pb = self._callFUT(today) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.string_value, today.isoformat()) + + def test_w_timestamp_w_nanos(self): + from google.protobuf.struct_pb2 import Value + from google.cloud._helpers import UTC + from google.cloud.spanner._helpers import TimestampWithNanoseconds + when = TimestampWithNanoseconds( + 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=UTC) + value_pb = self._callFUT(when) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.string_value, when.rfc3339()) + + def test_w_datetime(self): + import datetime + from google.protobuf.struct_pb2 import Value + from google.cloud._helpers import UTC, _datetime_to_rfc3339 + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + value_pb = self._callFUT(now) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.string_value, _datetime_to_rfc3339(now)) + + def test_w_unknown_type(self): + with self.assertRaises(ValueError): + self._callFUT(object()) + + +class Test_make_list_value_pb(unittest.TestCase): + + def _callFUT(self, *args, **kw): + from google.cloud.spanner._helpers import _make_list_value_pb + return _make_list_value_pb(*args, **kw) + + def test_empty(self): + from 
google.protobuf.struct_pb2 import ListValue + result = self._callFUT(values=[]) + self.assertIsInstance(result, ListValue) + self.assertEqual(len(result.values), 0) + + def test_w_single_value(self): + from google.protobuf.struct_pb2 import ListValue + VALUE = u'value' + result = self._callFUT(values=[VALUE]) + self.assertIsInstance(result, ListValue) + self.assertEqual(len(result.values), 1) + self.assertEqual(result.values[0].string_value, VALUE) + + def test_w_multiple_values(self): + from google.protobuf.struct_pb2 import ListValue + VALUE_1 = u'value' + VALUE_2 = 42 + result = self._callFUT(values=[VALUE_1, VALUE_2]) + self.assertIsInstance(result, ListValue) + self.assertEqual(len(result.values), 2) + self.assertEqual(result.values[0].string_value, VALUE_1) + self.assertEqual(result.values[1].string_value, str(VALUE_2)) + + +class Test_make_list_value_pbs(unittest.TestCase): + + def _callFUT(self, *args, **kw): + from google.cloud.spanner._helpers import _make_list_value_pbs + return _make_list_value_pbs(*args, **kw) + + def test_empty(self): + result = self._callFUT(values=[]) + self.assertEqual(result, []) + + def test_w_single_values(self): + from google.protobuf.struct_pb2 import ListValue + values = [[0], [1]] + result = self._callFUT(values=values) + self.assertEqual(len(result), len(values)) + for found, expected in zip(result, values): + self.assertIsInstance(found, ListValue) + self.assertEqual(len(found.values), 1) + self.assertEqual(found.values[0].string_value, str(expected[0])) + + def test_w_multiple_values(self): + from google.protobuf.struct_pb2 import ListValue + values = [[0, u'A'], [1, u'B']] + result = self._callFUT(values=values) + self.assertEqual(len(result), len(values)) + for found, expected in zip(result, values): + self.assertIsInstance(found, ListValue) + self.assertEqual(len(found.values), 2) + self.assertEqual(found.values[0].string_value, str(expected[0])) + self.assertEqual(found.values[1].string_value, expected[1]) + + +class Test_parse_value_pb(unittest.TestCase): + + def _callFUT(self, *args, **kw): + from google.cloud.spanner._helpers import _parse_value_pb + return _parse_value_pb(*args, **kw) + + def test_w_null(self): + from google.protobuf.struct_pb2 import Value, NULL_VALUE + from google.cloud.proto.spanner.v1.type_pb2 import Type, STRING + field_type = Type(code=STRING) + value_pb = Value(null_value=NULL_VALUE) + + self.assertEqual(self._callFUT(value_pb, field_type), None) + + def test_w_string(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.proto.spanner.v1.type_pb2 import Type, STRING + VALUE = u'Value' + field_type = Type(code=STRING) + value_pb = Value(string_value=VALUE) + + self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + + def test_w_bytes(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.proto.spanner.v1.type_pb2 import Type, BYTES + VALUE = b'Value' + field_type = Type(code=BYTES) + value_pb = Value(string_value=VALUE) + + self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + + def test_w_bool(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.proto.spanner.v1.type_pb2 import Type, BOOL + VALUE = True + field_type = Type(code=BOOL) + value_pb = Value(bool_value=VALUE) + + self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + + def test_w_int(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.proto.spanner.v1.type_pb2 import Type, INT64 + VALUE = 12345 + field_type = Type(code=INT64) + value_pb = 
Value(string_value=str(VALUE)) + + self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + + def test_w_float(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.proto.spanner.v1.type_pb2 import Type, FLOAT64 + VALUE = 3.14159 + field_type = Type(code=FLOAT64) + value_pb = Value(number_value=VALUE) + + self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + + def test_w_date(self): + import datetime + from google.protobuf.struct_pb2 import Value + from google.cloud.proto.spanner.v1.type_pb2 import Type, DATE + VALUE = datetime.date.today() + field_type = Type(code=DATE) + value_pb = Value(string_value=VALUE.isoformat()) + + self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + + def test_w_timestamp_wo_nanos(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.proto.spanner.v1.type_pb2 import Type, TIMESTAMP + from google.cloud._helpers import UTC, _datetime_to_rfc3339 + from google.cloud.spanner._helpers import TimestampWithNanoseconds + VALUE = TimestampWithNanoseconds( + 2016, 12, 20, 21, 13, 47, microsecond=123456, tzinfo=UTC) + field_type = Type(code=TIMESTAMP) + value_pb = Value(string_value=_datetime_to_rfc3339(VALUE)) + + parsed = self._callFUT(value_pb, field_type) + self.assertIsInstance(parsed, TimestampWithNanoseconds) + self.assertEqual(parsed, VALUE) + + def test_w_timestamp_w_nanos(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.proto.spanner.v1.type_pb2 import Type, TIMESTAMP + from google.cloud._helpers import UTC, _datetime_to_rfc3339 + from google.cloud.spanner._helpers import TimestampWithNanoseconds + VALUE = TimestampWithNanoseconds( + 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=UTC) + field_type = Type(code=TIMESTAMP) + value_pb = Value(string_value=_datetime_to_rfc3339(VALUE)) + + parsed = self._callFUT(value_pb, field_type) + self.assertIsInstance(parsed, TimestampWithNanoseconds) + self.assertEqual(parsed, VALUE) + + def test_w_array_empty(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.proto.spanner.v1.type_pb2 import Type, ARRAY, INT64 + field_type = Type(code=ARRAY, array_element_type=Type(code=INT64)) + value_pb = Value() + + self.assertEqual(self._callFUT(value_pb, field_type), []) + + def test_w_array_non_empty(self): + from google.protobuf.struct_pb2 import Value, ListValue + from google.cloud.proto.spanner.v1.type_pb2 import Type, ARRAY, INT64 + field_type = Type(code=ARRAY, array_element_type=Type(code=INT64)) + VALUES = [32, 19, 5] + values_pb = ListValue( + values=[Value(string_value=str(value)) for value in VALUES]) + value_pb = Value(list_value=values_pb) + + self.assertEqual(self._callFUT(value_pb, field_type), VALUES) + + def test_w_struct(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.proto.spanner.v1.type_pb2 import Type, StructType + from google.cloud.proto.spanner.v1.type_pb2 import ( + STRUCT, STRING, INT64) + from google.cloud.spanner._helpers import _make_list_value_pb + VALUES = [u'phred', 32] + struct_type_pb = StructType(fields=[ + StructType.Field(name='name', type=Type(code=STRING)), + StructType.Field(name='age', type=Type(code=INT64)), + ]) + field_type = Type(code=STRUCT, struct_type=struct_type_pb) + value_pb = Value(list_value=_make_list_value_pb(VALUES)) + + self.assertEqual(self._callFUT(value_pb, field_type), VALUES) + + def test_w_unknown_type(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.proto.spanner.v1.type_pb2 import Type + from 
google.cloud.proto.spanner.v1.type_pb2 import ( + TYPE_CODE_UNSPECIFIED) + field_type = Type(code=TYPE_CODE_UNSPECIFIED) + value_pb = Value(string_value='Borked') + + with self.assertRaises(ValueError): + self._callFUT(value_pb, field_type) + + +class Test_parse_list_value_pbs(unittest.TestCase): + + def _callFUT(self, *args, **kw): + from google.cloud.spanner._helpers import _parse_list_value_pbs + return _parse_list_value_pbs(*args, **kw) + + def test_empty(self): + from google.cloud.proto.spanner.v1.type_pb2 import Type, StructType + from google.cloud.proto.spanner.v1.type_pb2 import STRING, INT64 + struct_type_pb = StructType(fields=[ + StructType.Field(name='name', type=Type(code=STRING)), + StructType.Field(name='age', type=Type(code=INT64)), + ]) + + self.assertEqual(self._callFUT(rows=[], row_type=struct_type_pb), []) + + def test_non_empty(self): + from google.cloud.proto.spanner.v1.type_pb2 import Type, StructType + from google.cloud.proto.spanner.v1.type_pb2 import STRING, INT64 + from google.cloud.spanner._helpers import _make_list_value_pbs + VALUES = [ + [u'phred', 32], + [u'bharney', 31], + ] + struct_type_pb = StructType(fields=[ + StructType.Field(name='name', type=Type(code=STRING)), + StructType.Field(name='age', type=Type(code=INT64)), + ]) + values_pbs = _make_list_value_pbs(VALUES) + + self.assertEqual( + self._callFUT(rows=values_pbs, row_type=struct_type_pb), VALUES) + + +class Test_SessionWrapper(unittest.TestCase): + + def _getTargetClass(self): + from google.cloud.spanner._helpers import _SessionWrapper + return _SessionWrapper + + def _makeOne(self, session): + return self._getTargetClass()(session) + + def test_ctor(self): + session = object() + base = self._makeOne(session) + self.assertTrue(base._session is session) + + +class Test_options_with_prefix(unittest.TestCase): + + def _call_fut(self, *args, **kw): + from google.cloud.spanner._helpers import _options_with_prefix + return _options_with_prefix(*args, **kw) + + def test_wo_kwargs(self): + from google.gax import CallOptions + PREFIX = 'prefix' + options = self._call_fut(PREFIX) + self.assertIsInstance(options, CallOptions) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', PREFIX)]) + + def test_w_kwargs(self): + from google.gax import CallOptions + PREFIX = 'prefix' + TOKEN = 'token' + options = self._call_fut('prefix', page_token=TOKEN) + self.assertIsInstance(options, CallOptions) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', PREFIX)]) + self.assertEqual(options.page_token, TOKEN) diff --git a/spanner/unit_tests/test_batch.py b/spanner/unit_tests/test_batch.py new file mode 100644 index 000000000000..5ac6aa3fcaec --- /dev/null +++ b/spanner/unit_tests/test_batch.py @@ -0,0 +1,351 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
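The batch tests below drive the mutation API against a fake Spanner stub. For orientation, a rough usage sketch of the same API against a real database (it assumes a `database` object bound elsewhere, and is not runnable without a live Cloud Spanner backend):

    from google.cloud.spanner.keyset import KeySet

    with database.batch() as batch:
        batch.insert(
            'citizens',
            columns=['email', 'first_name', 'last_name', 'age'],
            values=[[u'phred@example.com', u'Phred', u'Phlyntstone', 32]])
        batch.delete('citizens', keyset=KeySet(keys=[['bharney@example.com']]))

    # A clean exit from the 'with' block commits all queued mutations in a
    # single read-write transaction; 'batch.committed' then holds the commit
    # timestamp, mirroring 'test_context_mgr_success' below.

Mutations are only buffered locally until commit, which is why the failure test below can assert that nothing reached the API after an exception.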
+
+
+import unittest
+
+from google.cloud._testing import _GAXBaseAPI
+
+
+TABLE_NAME = 'citizens'
+COLUMNS = ['email', 'first_name', 'last_name', 'age']
+VALUES = [
+    [u'phred@example.com', u'Phred', u'Phlyntstone', 32],
+    [u'bharney@example.com', u'Bharney', u'Rhubble', 31],
+]
+
+
+class _BaseTest(unittest.TestCase):
+
+    PROJECT_ID = 'project-id'
+    INSTANCE_ID = 'instance-id'
+    INSTANCE_NAME = 'projects/' + PROJECT_ID + '/instances/' + INSTANCE_ID
+    DATABASE_ID = 'database-id'
+    DATABASE_NAME = INSTANCE_NAME + '/databases/' + DATABASE_ID
+    SESSION_ID = 'session-id'
+    SESSION_NAME = DATABASE_NAME + '/sessions/' + SESSION_ID
+
+    def _makeOne(self, *args, **kwargs):
+        return self._getTargetClass()(*args, **kwargs)
+
+
+class Test_BatchBase(_BaseTest):
+
+    def _getTargetClass(self):
+        from google.cloud.spanner.batch import _BatchBase
+        return _BatchBase
+
+    def _compare_values(self, result, source):
+        from google.protobuf.struct_pb2 import ListValue
+        from google.protobuf.struct_pb2 import Value
+        for found, expected in zip(result, source):
+            self.assertIsInstance(found, ListValue)
+            self.assertEqual(len(found.values), len(expected))
+            for found_cell, expected_cell in zip(found.values, expected):
+                self.assertIsInstance(found_cell, Value)
+                if isinstance(expected_cell, int):
+                    self.assertEqual(
+                        int(found_cell.string_value), expected_cell)
+                else:
+                    self.assertEqual(found_cell.string_value, expected_cell)
+
+    def test_ctor(self):
+        session = _Session()
+        base = self._makeOne(session)
+        self.assertTrue(base._session is session)
+        self.assertEqual(len(base._mutations), 0)
+
+    def test__check_state_virtual(self):
+        session = _Session()
+        base = self._makeOne(session)
+        with self.assertRaises(NotImplementedError):
+            base._check_state()
+
+    def test_insert(self):
+        from google.cloud.proto.spanner.v1.mutation_pb2 import Mutation
+        session = _Session()
+        base = self._makeOne(session)
+
+        base.insert(TABLE_NAME, columns=COLUMNS, values=VALUES)
+
+        self.assertEqual(len(base._mutations), 1)
+        mutation = base._mutations[0]
+        self.assertIsInstance(mutation, Mutation)
+        write = mutation.insert
+        self.assertIsInstance(write, Mutation.Write)
+        self.assertEqual(write.table, TABLE_NAME)
+        self.assertEqual(write.columns, COLUMNS)
+        self._compare_values(write.values, VALUES)
+
+    def test_update(self):
+        from google.cloud.proto.spanner.v1.mutation_pb2 import Mutation
+        session = _Session()
+        base = self._makeOne(session)
+
+        base.update(TABLE_NAME, columns=COLUMNS, values=VALUES)
+
+        self.assertEqual(len(base._mutations), 1)
+        mutation = base._mutations[0]
+        self.assertIsInstance(mutation, Mutation)
+        write = mutation.update
+        self.assertIsInstance(write, Mutation.Write)
+        self.assertEqual(write.table, TABLE_NAME)
+        self.assertEqual(write.columns, COLUMNS)
+        self._compare_values(write.values, VALUES)
+
+    def test_insert_or_update(self):
+        from google.cloud.proto.spanner.v1.mutation_pb2 import Mutation
+        session = _Session()
+        base = self._makeOne(session)
+
+        base.insert_or_update(TABLE_NAME, columns=COLUMNS, values=VALUES)
+
+        self.assertEqual(len(base._mutations), 1)
+        mutation = base._mutations[0]
+        self.assertIsInstance(mutation, Mutation)
+        write = mutation.insert_or_update
+        self.assertIsInstance(write, Mutation.Write)
+        self.assertEqual(write.table, TABLE_NAME)
+        self.assertEqual(write.columns, COLUMNS)
+        self._compare_values(write.values, VALUES)
+
+    def test_replace(self):
+        from google.cloud.proto.spanner.v1.mutation_pb2 import Mutation
+        session = _Session()
+        base = self._makeOne(session)
+
+        base.replace(TABLE_NAME, columns=COLUMNS, values=VALUES)
+
+        self.assertEqual(len(base._mutations), 1)
+        mutation = base._mutations[0]
+        self.assertIsInstance(mutation, Mutation)
+        write = mutation.replace
+        self.assertIsInstance(write, Mutation.Write)
+        self.assertEqual(write.table, TABLE_NAME)
+        self.assertEqual(write.columns, COLUMNS)
+        self._compare_values(write.values, VALUES)
+
+    def test_delete(self):
+        from google.cloud.proto.spanner.v1.mutation_pb2 import Mutation
+        from google.cloud.spanner.keyset import KeySet
+        keys = [[0], [1], [2]]
+        keyset = KeySet(keys=keys)
+        session = _Session()
+        base = self._makeOne(session)
+
+        base.delete(TABLE_NAME, keyset=keyset)
+
+        self.assertEqual(len(base._mutations), 1)
+        mutation = base._mutations[0]
+        self.assertIsInstance(mutation, Mutation)
+        delete = mutation.delete
+        self.assertIsInstance(delete, Mutation.Delete)
+        self.assertEqual(delete.table, TABLE_NAME)
+        key_set_pb = delete.key_set
+        self.assertEqual(len(key_set_pb.ranges), 0)
+        self.assertEqual(len(key_set_pb.keys), len(keys))
+        for found, expected in zip(key_set_pb.keys, keys):
+            self.assertEqual(
+                [int(value.string_value) for value in found.values], expected)
+
+
+class TestBatch(_BaseTest):
+
+    def _getTargetClass(self):
+        from google.cloud.spanner.batch import Batch
+        return Batch
+
+    def test_ctor(self):
+        session = _Session()
+        batch = self._makeOne(session)
+        self.assertTrue(batch._session is session)
+
+    def test_commit_already_committed(self):
+        from google.cloud.spanner.keyset import KeySet
+        keys = [[0], [1], [2]]
+        keyset = KeySet(keys=keys)
+        database = _Database()
+        session = _Session(database)
+        batch = self._makeOne(session)
+        batch.committed = object()
+        batch.delete(TABLE_NAME, keyset=keyset)
+
+        with self.assertRaises(ValueError):
+            batch.commit()
+
+    def test_commit_grpc_error(self):
+        from google.gax.errors import GaxError
+        from google.cloud.proto.spanner.v1.transaction_pb2 import (
+            TransactionOptions)
+        from google.cloud.proto.spanner.v1.mutation_pb2 import (
+            Mutation as MutationPB)
+        from google.cloud.spanner.keyset import KeySet
+        keys = [[0], [1], [2]]
+        keyset = KeySet(keys=keys)
+        database = _Database()
+        api = database.spanner_api = _FauxSpannerAPI(
+            _random_gax_error=True)
+        session = _Session(database)
+        batch = self._makeOne(session)
+        batch.delete(TABLE_NAME, keyset=keyset)
+
+        with self.assertRaises(GaxError):
+            batch.commit()
+
+        (session, mutations, single_use_txn, options) = api._committed
+        self.assertEqual(session, self.SESSION_NAME)
+        self.assertEqual(len(mutations), 1)
+        mutation = mutations[0]
+        self.assertIsInstance(mutation, MutationPB)
+        self.assertTrue(mutation.HasField('delete'))
+        delete = mutation.delete
+        self.assertEqual(delete.table, TABLE_NAME)
+        keyset_pb = delete.key_set
+        self.assertEqual(len(keyset_pb.ranges), 0)
+        self.assertEqual(len(keyset_pb.keys), len(keys))
+        for found, expected in zip(keyset_pb.keys, keys):
+            self.assertEqual(
+                [int(value.string_value) for value in found.values], expected)
+        self.assertIsInstance(single_use_txn, TransactionOptions)
+        self.assertTrue(single_use_txn.HasField('read_write'))
+        self.assertEqual(options.kwargs['metadata'],
+                         [('google-cloud-resource-prefix', database.name)])
+
+    def test_commit_ok(self):
+        import datetime
+        from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse
+        from google.cloud.proto.spanner.v1.transaction_pb2 import (
+            TransactionOptions)
+        from google.cloud._helpers import UTC
+        from google.cloud._helpers import _datetime_to_pb_timestamp
+        now = 
datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + response = CommitResponse(commit_timestamp=now_pb) + database = _Database() + api = database.spanner_api = _FauxSpannerAPI( + _commit_response=response) + session = _Session(database) + batch = self._makeOne(session) + batch.insert(TABLE_NAME, COLUMNS, VALUES) + + committed = batch.commit() + + self.assertEqual(committed, now) + self.assertEqual(batch.committed, committed) + + (session, mutations, single_use_txn, options) = api._committed + self.assertEqual(session, self.SESSION_NAME) + self.assertEqual(mutations, batch._mutations) + self.assertIsInstance(single_use_txn, TransactionOptions) + self.assertTrue(single_use_txn.HasField('read_write')) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_context_mgr_already_committed(self): + import datetime + from google.cloud._helpers import UTC + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + database = _Database() + api = database.spanner_api = _FauxSpannerAPI() + session = _Session(database) + batch = self._makeOne(session) + batch.committed = now + + with self.assertRaises(ValueError): + with batch: + pass # pragma: NO COVER + + self.assertEqual(api._committed, None) + + def test_context_mgr_success(self): + import datetime + from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + TransactionOptions) + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + response = CommitResponse(commit_timestamp=now_pb) + database = _Database() + api = database.spanner_api = _FauxSpannerAPI( + _commit_response=response) + session = _Session(database) + batch = self._makeOne(session) + + with batch: + batch.insert(TABLE_NAME, COLUMNS, VALUES) + + self.assertEqual(batch.committed, now) + + (session, mutations, single_use_txn, options) = api._committed + self.assertEqual(session, self.SESSION_NAME) + self.assertEqual(mutations, batch._mutations) + self.assertIsInstance(single_use_txn, TransactionOptions) + self.assertTrue(single_use_txn.HasField('read_write')) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_context_mgr_failure(self): + import datetime + from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + response = CommitResponse(commit_timestamp=now_pb) + database = _Database() + api = database.spanner_api = _FauxSpannerAPI( + _commit_response=response) + session = _Session(database) + batch = self._makeOne(session) + + class _BailOut(Exception): + pass + + with self.assertRaises(_BailOut): + with batch: + batch.insert(TABLE_NAME, COLUMNS, VALUES) + raise _BailOut() + + self.assertEqual(batch.committed, None) + self.assertEqual(api._committed, None) + self.assertEqual(len(batch._mutations), 1) + + +class _Session(object): + + def __init__(self, database=None, name=TestBatch.SESSION_NAME): + self._database = database + self.name = name + + +class _Database(object): + name = 'testing' + + +class _FauxSpannerAPI(_GAXBaseAPI): + + _create_instance_conflict = False + _instance_not_found = False + _committed = 
None + + def commit(self, session, mutations, + transaction_id='', single_use_transaction=None, options=None): + from google.gax.errors import GaxError + assert transaction_id == '' + self._committed = (session, mutations, single_use_transaction, options) + if self._random_gax_error: + raise GaxError('error') + return self._commit_response diff --git a/spanner/unit_tests/test_client.py b/spanner/unit_tests/test_client.py new file mode 100644 index 000000000000..722733f71819 --- /dev/null +++ b/spanner/unit_tests/test_client.py @@ -0,0 +1,436 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import mock + + +def _make_credentials(): + import google.auth.credentials + + class _CredentialsWithScopes( + google.auth.credentials.Credentials, + google.auth.credentials.Scoped): + pass + + return mock.Mock(spec=_CredentialsWithScopes) + + +class Test__make_operations_stub(unittest.TestCase): + + def _callFUT(self, client): + from google.cloud.spanner.client import _make_operations_stub + return _make_operations_stub(client) + + def test_it(self): + from google.cloud._testing import _Monkey + from google.cloud.spanner import client as MUT + + credentials = _Credentials() + user_agent = 'you-sir-age-int' + client = _Client(credentials, user_agent) + + fake_stub = object() + make_secure_stub_args = [] + + def mock_make_secure_stub(*args): + make_secure_stub_args.append(args) + return fake_stub + + with _Monkey(MUT, make_secure_stub=mock_make_secure_stub): + result = self._callFUT(client) + + self.assertIs(result, fake_stub) + self.assertEqual(make_secure_stub_args, [ + ( + client.credentials, + client.user_agent, + MUT.operations_grpc.OperationsStub, + MUT.OPERATIONS_API_HOST, + ), + ]) + + +class TestClient(unittest.TestCase): + + PROJECT = 'PROJECT' + PATH = 'projects/%s' % (PROJECT,) + CONFIGURATION_NAME = 'config-name' + INSTANCE_ID = 'instance-id' + INSTANCE_NAME = '%s/instances/%s' % (PATH, INSTANCE_ID) + DISPLAY_NAME = 'display-name' + NODE_COUNT = 5 + TIMEOUT_SECONDS = 80 + USER_AGENT = 'you-sir-age-int' + + def _getTargetClass(self): + from google.cloud.spanner.client import Client + return Client + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def _constructor_test_helper(self, expected_scopes, creds, + user_agent=None, + expected_creds=None): + from google.cloud.spanner import client as MUT + + user_agent = user_agent or MUT.DEFAULT_USER_AGENT + client = self._makeOne(project=self.PROJECT, credentials=creds, + user_agent=user_agent) + + expected_creds = expected_creds or creds.with_scopes.return_value + self.assertIs(client._credentials, expected_creds) + + self.assertTrue(client._credentials is expected_creds) + if expected_scopes is not None: + creds.with_scopes.assert_called_once_with(expected_scopes) + + self.assertEqual(client.project, self.PROJECT) + self.assertEqual(client.user_agent, user_agent) + + def test_constructor_default_scopes(self): + from google.cloud.spanner 
import client as MUT + + expected_scopes = [ + MUT.SPANNER_ADMIN_SCOPE, + ] + creds = _make_credentials() + self._constructor_test_helper(expected_scopes, creds) + + def test_constructor_custom_user_agent_and_timeout(self): + from google.cloud.spanner import client as MUT + + CUSTOM_USER_AGENT = 'custom-application' + expected_scopes = [ + MUT.SPANNER_ADMIN_SCOPE, + ] + creds = _make_credentials() + self._constructor_test_helper(expected_scopes, creds, + user_agent=CUSTOM_USER_AGENT) + + def test_constructor_implicit_credentials(self): + from google.cloud._testing import _Monkey + from google.cloud.spanner import client as MUT + + creds = _make_credentials() + + def mock_get_credentials(): + return creds + + with _Monkey(MUT, get_credentials=mock_get_credentials): + self._constructor_test_helper( + None, None, + expected_creds=creds.with_scopes.return_value) + + def test_constructor_credentials_wo_create_scoped(self): + creds = _make_credentials() + expected_scopes = None + self._constructor_test_helper(expected_scopes, creds) + + def test_instance_admin_api(self): + from google.cloud._testing import _Monkey + from google.cloud.spanner import client as MUT + client = self._makeOne(project=self.PROJECT) + + class _Client(object): + pass + + with _Monkey(MUT, InstanceAdminClient=_Client): + api = client.instance_admin_api + + self.assertTrue(isinstance(api, _Client)) + again = client.instance_admin_api + self.assertTrue(again is api) + + def test_database_admin_api(self): + from google.cloud._testing import _Monkey + from google.cloud.spanner import client as MUT + client = self._makeOne(project=self.PROJECT) + + class _Client(object): + pass + + with _Monkey(MUT, DatabaseAdminClient=_Client): + api = client.database_admin_api + + self.assertTrue(isinstance(api, _Client)) + again = client.database_admin_api + self.assertTrue(again is api) + + def test__operations_stub(self): + from google.cloud._testing import _Monkey + from google.cloud.spanner import client as MUT + client = self._makeOne(project=self.PROJECT) + + class _Stub(object): + pass + + def _make_operations_stub(_): + return _Stub() + + with _Monkey(MUT, _make_operations_stub=_make_operations_stub): + stub = client._operations_stub + + self.assertTrue(isinstance(stub, _Stub)) + again = client._operations_stub + self.assertTrue(again is stub) + + def test_copy(self): + credentials = _Credentials('value') + client = self._makeOne( + project=self.PROJECT, + credentials=credentials, + user_agent=self.USER_AGENT) + + new_client = client.copy() + self.assertEqual(new_client._credentials, client._credentials) + self.assertEqual(new_client.project, client.project) + self.assertEqual(new_client.user_agent, client.user_agent) + + def test_credentials_property(self): + credentials = _Credentials() + client = self._makeOne(project=self.PROJECT, credentials=credentials) + self.assertTrue(client.credentials is credentials) + + def test_project_name_property(self): + credentials = _Credentials() + client = self._makeOne(project=self.PROJECT, credentials=credentials) + project_name = 'projects/' + self.PROJECT + self.assertEqual(client.project_name, project_name) + + def test_list_instance_configs_wo_paging(self): + from google.cloud._testing import _GAXPageIterator + from google.gax import INITIAL_PAGE + from google.cloud.spanner.client import InstanceConfig + credentials = _Credentials() + client = self._makeOne(project=self.PROJECT, credentials=credentials) + client.connection = object() + api = client._instance_admin_api = 
_FauxInstanceAdminAPI() + config = _InstanceConfigPB(name=self.CONFIGURATION_NAME, + display_name=self.DISPLAY_NAME) + response = _GAXPageIterator([config]) + api._list_instance_configs_response = response + + iterator = client.list_instance_configs() + configs = list(iterator) + + self.assertEqual(len(configs), 1) + config = configs[0] + self.assertTrue(isinstance(config, InstanceConfig)) + self.assertEqual(config.name, self.CONFIGURATION_NAME) + self.assertEqual(config.display_name, self.DISPLAY_NAME) + + project, page_size, options = api._listed_instance_configs + self.assertEqual(project, self.PATH) + self.assertEqual(page_size, None) + self.assertTrue(options.page_token is INITIAL_PAGE) + self.assertEqual( + options.kwargs['metadata'], + [('google-cloud-resource-prefix', client.project_name)]) + + def test_list_instance_configs_w_paging(self): + import six + from google.cloud._testing import _GAXPageIterator + from google.cloud.spanner.client import InstanceConfig + SIZE = 15 + TOKEN_RETURNED = 'TOKEN_RETURNED' + TOKEN_PASSED = 'TOKEN_PASSED' + credentials = _Credentials() + client = self._makeOne(project=self.PROJECT, credentials=credentials) + client.connection = object() + api = client._instance_admin_api = _FauxInstanceAdminAPI() + config = _InstanceConfigPB(name=self.CONFIGURATION_NAME, + display_name=self.DISPLAY_NAME) + response = _GAXPageIterator([config], page_token=TOKEN_RETURNED) + api._list_instance_configs_response = response + + iterator = client.list_instance_configs(SIZE, TOKEN_PASSED) + page = six.next(iterator.pages) + next_token = iterator.next_page_token + configs = list(page) + + self.assertEqual(len(configs), 1) + config = configs[0] + self.assertTrue(isinstance(config, InstanceConfig)) + self.assertEqual(config.name, self.CONFIGURATION_NAME) + self.assertEqual(config.display_name, self.DISPLAY_NAME) + self.assertEqual(next_token, TOKEN_RETURNED) + + project, page_size, options = api._listed_instance_configs + self.assertEqual(project, self.PATH) + self.assertEqual(page_size, SIZE) + self.assertEqual(options.page_token, TOKEN_PASSED) + self.assertEqual( + options.kwargs['metadata'], + [('google-cloud-resource-prefix', client.project_name)]) + + def test_instance_factory_defaults(self): + from google.cloud.spanner.instance import DEFAULT_NODE_COUNT + from google.cloud.spanner.instance import Instance + credentials = _Credentials() + client = self._makeOne(project=self.PROJECT, credentials=credentials) + + instance = client.instance(self.INSTANCE_ID) + + self.assertTrue(isinstance(instance, Instance)) + self.assertEqual(instance.instance_id, self.INSTANCE_ID) + self.assertIsNone(instance.configuration_name) + self.assertEqual(instance.display_name, self.INSTANCE_ID) + self.assertEqual(instance.node_count, DEFAULT_NODE_COUNT) + self.assertTrue(instance._client is client) + + def test_instance_factory_explicit(self): + from google.cloud.spanner.instance import Instance + credentials = _Credentials() + client = self._makeOne(project=self.PROJECT, credentials=credentials) + + instance = client.instance(self.INSTANCE_ID, self.CONFIGURATION_NAME, + display_name=self.DISPLAY_NAME, + node_count=self.NODE_COUNT) + + self.assertTrue(isinstance(instance, Instance)) + self.assertEqual(instance.instance_id, self.INSTANCE_ID) + self.assertEqual(instance.configuration_name, self.CONFIGURATION_NAME) + self.assertEqual(instance.display_name, self.DISPLAY_NAME) + self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertTrue(instance._client is client) + + def 
test_list_instances_wo_paging(self): + from google.cloud._testing import _GAXPageIterator + from google.gax import INITIAL_PAGE + from google.cloud.spanner.instance import Instance + credentials = _Credentials() + client = self._makeOne(project=self.PROJECT, credentials=credentials) + client.connection = object() + api = client._instance_admin_api = _FauxInstanceAdminAPI() + instance = _InstancePB(name=self.INSTANCE_NAME, + config=self.CONFIGURATION_NAME, + display_name=self.DISPLAY_NAME, + node_count=self.NODE_COUNT) + response = _GAXPageIterator([instance]) + api._list_instances_response = response + + iterator = client.list_instances(filter_='name:TEST') + instances = list(iterator) + + self.assertEqual(len(instances), 1) + instance = instances[0] + self.assertTrue(isinstance(instance, Instance)) + self.assertEqual(instance.name, self.INSTANCE_NAME) + self.assertEqual(instance.configuration_name, self.CONFIGURATION_NAME) + self.assertEqual(instance.display_name, self.DISPLAY_NAME) + self.assertEqual(instance.node_count, self.NODE_COUNT) + + project, filter_, page_size, options = api._listed_instances + self.assertEqual(project, self.PATH) + self.assertEqual(filter_, 'name:TEST') + self.assertEqual(page_size, None) + self.assertTrue(options.page_token is INITIAL_PAGE) + self.assertEqual( + options.kwargs['metadata'], + [('google-cloud-resource-prefix', client.project_name)]) + + def test_list_instances_w_paging(self): + import six + from google.cloud._testing import _GAXPageIterator + from google.cloud.spanner.instance import Instance + SIZE = 15 + TOKEN_RETURNED = 'TOKEN_RETURNED' + TOKEN_PASSED = 'TOKEN_PASSED' + credentials = _Credentials() + client = self._makeOne(project=self.PROJECT, credentials=credentials) + client.connection = object() + api = client._instance_admin_api = _FauxInstanceAdminAPI() + instance = _InstancePB(name=self.INSTANCE_NAME, + config=self.CONFIGURATION_NAME, + display_name=self.DISPLAY_NAME, + node_count=self.NODE_COUNT) + response = _GAXPageIterator([instance], page_token=TOKEN_RETURNED) + api._list_instances_response = response + + iterator = client.list_instances( + page_size=SIZE, page_token=TOKEN_PASSED) + page = six.next(iterator.pages) + next_token = iterator.next_page_token + instances = list(page) + + self.assertEqual(len(instances), 1) + instance = instances[0] + self.assertTrue(isinstance(instance, Instance)) + self.assertEqual(instance.name, self.INSTANCE_NAME) + self.assertEqual(instance.configuration_name, self.CONFIGURATION_NAME) + self.assertEqual(instance.display_name, self.DISPLAY_NAME) + self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(next_token, TOKEN_RETURNED) + + project, filter_, page_size, options = api._listed_instances + self.assertEqual(project, self.PATH) + self.assertEqual(filter_, '') + self.assertEqual(page_size, SIZE) + self.assertEqual(options.page_token, TOKEN_PASSED) + self.assertEqual( + options.kwargs['metadata'], + [('google-cloud-resource-prefix', client.project_name)]) + + +class _Client(object): + + def __init__(self, credentials, user_agent): + self.credentials = credentials + self.user_agent = user_agent + + +class _Credentials(object): + + scopes = None + + def __init__(self, access_token=None): + self._access_token = access_token + self._tokens = [] + + def create_scoped(self, scope): + self.scopes = scope + return self + + def __eq__(self, other): + return self._access_token == other._access_token + + +class _FauxInstanceAdminAPI(object): + + def list_instance_configs(self, name, 
page_size, options): + self._listed_instance_configs = (name, page_size, options) + return self._list_instance_configs_response + + def list_instances(self, name, filter_, page_size, options): + self._listed_instances = (name, filter_, page_size, options) + return self._list_instances_response + + +class _InstanceConfigPB(object): + + def __init__(self, name, display_name): + self.name = name + self.display_name = display_name + + +class _InstancePB(object): + + def __init__(self, name, config, display_name=None, node_count=None): + self.name = name + self.config = config + self.display_name = display_name + self.node_count = node_count diff --git a/spanner/unit_tests/test_database.py b/spanner/unit_tests/test_database.py new file mode 100644 index 000000000000..89e571ee59cb --- /dev/null +++ b/spanner/unit_tests/test_database.py @@ -0,0 +1,1116 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import mock + +from google.cloud._testing import _GAXBaseAPI + + +class _BaseTest(unittest.TestCase): + + PROJECT_ID = 'project-id' + PARENT = 'projects/' + PROJECT_ID + INSTANCE_ID = 'instance-id' + INSTANCE_NAME = PARENT + '/instances/' + INSTANCE_ID + DATABASE_ID = 'database_id' + DATABASE_NAME = INSTANCE_NAME + '/databases/' + DATABASE_ID + SESSION_ID = 'session_id' + SESSION_NAME = DATABASE_NAME + '/sessions/' + SESSION_ID + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + +class TestDatabase(_BaseTest): + + def _getTargetClass(self): + from google.cloud.spanner.database import Database + return Database + + def test_ctor_defaults(self): + from google.cloud.spanner.pool import BurstyPool + instance = _Instance(self.INSTANCE_NAME) + + database = self._makeOne(self.DATABASE_ID, instance) + + self.assertEqual(database.database_id, self.DATABASE_ID) + self.assertTrue(database._instance is instance) + self.assertEqual(list(database.ddl_statements), []) + self.assertIsInstance(database._pool, BurstyPool) + # BurstyPool does not create sessions during 'bind()'. 
+        self.assertTrue(database._pool._sessions.empty())
+
+    def test_ctor_w_explicit_pool(self):
+        instance = _Instance(self.INSTANCE_NAME)
+        pool = _Pool()
+        database = self._makeOne(self.DATABASE_ID, instance, pool=pool)
+        self.assertEqual(database.database_id, self.DATABASE_ID)
+        self.assertTrue(database._instance is instance)
+        self.assertEqual(list(database.ddl_statements), [])
+        self.assertIs(database._pool, pool)
+        self.assertIs(pool._bound, database)
+
+    def test_ctor_w_ddl_statements_non_string(self):
+
+        with self.assertRaises(ValueError):
+            self._makeOne(
+                self.DATABASE_ID, instance=object(),
+                ddl_statements=[object()])
+
+    def test_ctor_w_ddl_statements_w_create_database(self):
+
+        with self.assertRaises(ValueError):
+            self._makeOne(
+                self.DATABASE_ID, instance=object(),
+                ddl_statements=['CREATE DATABASE foo'])
+
+    def test_ctor_w_ddl_statements_ok(self):
+        from google.cloud.spanner._fixtures import DDL_STATEMENTS
+        instance = _Instance(self.INSTANCE_NAME)
+        pool = _Pool()
+        database = self._makeOne(
+            self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS,
+            pool=pool)
+        self.assertEqual(database.database_id, self.DATABASE_ID)
+        self.assertTrue(database._instance is instance)
+        self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS)
+
+    def test_from_pb_bad_database_name(self):
+        from google.cloud.proto.spanner.admin.database.v1 import (
+            spanner_database_admin_pb2 as admin_v1_pb2)
+        database_name = 'INCORRECT_FORMAT'
+        database_pb = admin_v1_pb2.Database(name=database_name)
+        klass = self._getTargetClass()
+
+        with self.assertRaises(ValueError):
+            klass.from_pb(database_pb, None)
+
+    def test_from_pb_project_mismatch(self):
+        from google.cloud.proto.spanner.admin.database.v1 import (
+            spanner_database_admin_pb2 as admin_v1_pb2)
+        ALT_PROJECT = 'ALT_PROJECT'
+        client = _Client(project=ALT_PROJECT)
+        instance = _Instance(self.INSTANCE_NAME, client)
+        database_pb = admin_v1_pb2.Database(name=self.DATABASE_NAME)
+        klass = self._getTargetClass()
+
+        with self.assertRaises(ValueError):
+            klass.from_pb(database_pb, instance)
+
+    def test_from_pb_instance_mismatch(self):
+        from google.cloud.proto.spanner.admin.database.v1 import (
+            spanner_database_admin_pb2 as admin_v1_pb2)
+        ALT_INSTANCE = '/projects/%s/instances/ALT-INSTANCE' % (
+            self.PROJECT_ID,)
+        client = _Client()
+        instance = _Instance(ALT_INSTANCE, client)
+        database_pb = admin_v1_pb2.Database(name=self.DATABASE_NAME)
+        klass = self._getTargetClass()
+
+        with self.assertRaises(ValueError):
+            klass.from_pb(database_pb, instance)
+
+    def test_from_pb_success_w_explicit_pool(self):
+        from google.cloud.proto.spanner.admin.database.v1 import (
+            spanner_database_admin_pb2 as admin_v1_pb2)
+        client = _Client()
+        instance = _Instance(self.INSTANCE_NAME, client)
+        database_pb = admin_v1_pb2.Database(name=self.DATABASE_NAME)
+        klass = self._getTargetClass()
+        pool = _Pool()
+
+        database = klass.from_pb(database_pb, instance, pool=pool)
+
+        self.assertTrue(isinstance(database, klass))
+        self.assertEqual(database._instance, instance)
+        self.assertEqual(database.database_id, self.DATABASE_ID)
+        self.assertIs(database._pool, pool)
+
+    def test_from_pb_success_w_hyphen_w_default_pool(self):
+        from google.cloud.proto.spanner.admin.database.v1 import (
+            spanner_database_admin_pb2 as admin_v1_pb2)
+        from google.cloud.spanner.pool import BurstyPool
+        DATABASE_ID_HYPHEN = 'database-id'
+        DATABASE_NAME_HYPHEN = (
+            self.INSTANCE_NAME + '/databases/' + DATABASE_ID_HYPHEN)
+        client = _Client()
+        instance = 
_Instance(self.INSTANCE_NAME, client) + database_pb = admin_v1_pb2.Database(name=DATABASE_NAME_HYPHEN) + klass = self._getTargetClass() + + database = klass.from_pb(database_pb, instance) + + self.assertTrue(isinstance(database, klass)) + self.assertEqual(database._instance, instance) + self.assertEqual(database.database_id, DATABASE_ID_HYPHEN) + self.assertIsInstance(database._pool, BurstyPool) + # BurstyPool does not create sessions during 'bind()'. + self.assertTrue(database._pool._sessions.empty()) + + def test_name_property(self): + instance = _Instance(self.INSTANCE_NAME) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + expected_name = self.DATABASE_NAME + self.assertEqual(database.name, expected_name) + + def test_spanner_api_property(self): + from google.cloud._testing import _Monkey + from google.cloud.spanner import database as MUT + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + _client = object() + _clients = [_client] + + def _mock_spanner_client(): + return _clients.pop(0) + + with _Monkey(MUT, SpannerClient=_mock_spanner_client): + api = database.spanner_api + self.assertTrue(api is _client) + # API instance is cached + again = database.spanner_api + self.assertTrue(again is api) + + def test___eq__(self): + instance = _Instance(self.INSTANCE_NAME) + pool1, pool2 = _Pool(), _Pool() + database1 = self._makeOne(self.DATABASE_ID, instance, pool=pool1) + database2 = self._makeOne(self.DATABASE_ID, instance, pool=pool2) + self.assertEqual(database1, database2) + + def test___eq__type_differ(self): + pool = _Pool() + database1 = self._makeOne(self.DATABASE_ID, None, pool=pool) + database2 = object() + self.assertNotEqual(database1, database2) + + def test___ne__same_value(self): + instance = _Instance(self.INSTANCE_NAME) + pool1, pool2 = _Pool(), _Pool() + database1 = self._makeOne(self.DATABASE_ID, instance, pool=pool1) + database2 = self._makeOne(self.DATABASE_ID, instance, pool=pool2) + comparison_val = (database1 != database2) + self.assertFalse(comparison_val) + + def test___ne__(self): + pool1, pool2 = _Pool(), _Pool() + database1 = self._makeOne('database_id1', 'instance1', pool=pool1) + database2 = self._makeOne('database_id2', 'instance2', pool=pool2) + self.assertNotEqual(database1, database2) + + def test_create_grpc_error(self): + from google.gax.errors import GaxError + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _random_gax_error=True) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + with self.assertRaises(GaxError): + database.create() + + (parent, create_statement, extra_statements, + options) = api._created_database + self.assertEqual(parent, self.INSTANCE_NAME) + self.assertEqual(create_statement, + 'CREATE DATABASE %s' % self.DATABASE_ID) + self.assertEqual(extra_statements, []) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_create_already_exists(self): + DATABASE_ID_HYPHEN = 'database-id' + from google.cloud.exceptions import Conflict + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _create_database_conflict=True) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(DATABASE_ID_HYPHEN, instance, pool=pool) + + with 
self.assertRaises(Conflict): + database.create() + + (parent, create_statement, extra_statements, + options) = api._created_database + self.assertEqual(parent, self.INSTANCE_NAME) + self.assertEqual(create_statement, + 'CREATE DATABASE `%s`' % DATABASE_ID_HYPHEN) + self.assertEqual(extra_statements, []) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_create_instance_not_found(self): + from google.cloud.exceptions import NotFound + + DATABASE_ID_HYPHEN = 'database-id' + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _database_not_found=True) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(DATABASE_ID_HYPHEN, instance, pool=pool) + + with self.assertRaises(NotFound): + database.create() + + (parent, create_statement, extra_statements, + options) = api._created_database + self.assertEqual(parent, self.INSTANCE_NAME) + self.assertEqual(create_statement, + 'CREATE DATABASE `%s`' % DATABASE_ID_HYPHEN) + self.assertEqual(extra_statements, []) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_create_success(self): + from google.cloud.spanner._fixtures import DDL_STATEMENTS + op_future = _FauxOperationFuture() + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _create_database_response=op_future) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne( + self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, + pool=pool) + + future = database.create() + + self.assertIs(future, op_future) + self.assertEqual(future.caller_metadata, + {'request_type': 'CreateDatabase'}) + + (parent, create_statement, extra_statements, + options) = api._created_database + self.assertEqual(parent, self.INSTANCE_NAME) + self.assertEqual(create_statement, + 'CREATE DATABASE %s' % self.DATABASE_ID) + self.assertEqual(extra_statements, DDL_STATEMENTS) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_exists_grpc_error(self): + from google.gax.errors import GaxError + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _random_gax_error=True) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + with self.assertRaises(GaxError): + database.exists() + + name, options = api._got_database_ddl + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_exists_not_found(self): + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _database_not_found=True) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + self.assertFalse(database.exists()) + + name, options = api._got_database_ddl + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_exists_success(self): + from google.cloud.proto.spanner.admin.database.v1 import ( + spanner_database_admin_pb2 as admin_v1_pb2) + from google.cloud.spanner._fixtures import DDL_STATEMENTS + client = _Client() + ddl_pb = admin_v1_pb2.GetDatabaseDdlResponse( + statements=DDL_STATEMENTS) + api = 
client.database_admin_api = _FauxDatabaseAdminAPI( + _get_database_ddl_response=ddl_pb) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + self.assertTrue(database.exists()) + + name, options = api._got_database_ddl + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_reload_grpc_error(self): + from google.gax.errors import GaxError + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _random_gax_error=True) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + with self.assertRaises(GaxError): + database.reload() + + name, options = api._got_database_ddl + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_reload_not_found(self): + from google.cloud.exceptions import NotFound + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _database_not_found=True) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + with self.assertRaises(NotFound): + database.reload() + + name, options = api._got_database_ddl + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_reload_success(self): + from google.cloud.proto.spanner.admin.database.v1 import ( + spanner_database_admin_pb2 as admin_v1_pb2) + from google.cloud.spanner._fixtures import DDL_STATEMENTS + client = _Client() + ddl_pb = admin_v1_pb2.GetDatabaseDdlResponse( + statements=DDL_STATEMENTS) + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _get_database_ddl_response=ddl_pb) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + database.reload() + + self.assertEqual(database._ddl_statements, tuple(DDL_STATEMENTS)) + + name, options = api._got_database_ddl + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_update_ddl_grpc_error(self): + from google.gax.errors import GaxError + from google.cloud.spanner._fixtures import DDL_STATEMENTS + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _random_gax_error=True) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + with self.assertRaises(GaxError): + database.update_ddl(DDL_STATEMENTS) + + name, statements, op_id, options = api._updated_database_ddl + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(statements, DDL_STATEMENTS) + self.assertEqual(op_id, '') + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_update_ddl_not_found(self): + from google.cloud.exceptions import NotFound + from google.cloud.spanner._fixtures import DDL_STATEMENTS + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _database_not_found=True) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, 
instance, pool=pool) + + with self.assertRaises(NotFound): + database.update_ddl(DDL_STATEMENTS) + + name, statements, op_id, options = api._updated_database_ddl + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(statements, DDL_STATEMENTS) + self.assertEqual(op_id, '') + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_update_ddl(self): + from google.cloud.spanner._fixtures import DDL_STATEMENTS + op_future = _FauxOperationFuture() + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _update_database_ddl_response=op_future) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + future = database.update_ddl(DDL_STATEMENTS) + + self.assertIs(future, op_future) + self.assertEqual(future.caller_metadata, + {'request_type': 'UpdateDatabaseDdl'}) + + name, statements, op_id, options = api._updated_database_ddl + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(statements, DDL_STATEMENTS) + self.assertEqual(op_id, '') + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_drop_grpc_error(self): + from google.gax.errors import GaxError + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _random_gax_error=True) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + with self.assertRaises(GaxError): + database.drop() + + name, options = api._dropped_database + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_drop_not_found(self): + from google.cloud.exceptions import NotFound + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _database_not_found=True) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + with self.assertRaises(NotFound): + database.drop() + + name, options = api._dropped_database + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_drop_success(self): + from google.protobuf.empty_pb2 import Empty + client = _Client() + api = client.database_admin_api = _FauxDatabaseAdminAPI( + _drop_database_response=Empty()) + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + database.drop() + + name, options = api._dropped_database + self.assertEqual(name, self.DATABASE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_session_factory(self): + from google.cloud.spanner.session import Session + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + session = database.session() + + self.assertTrue(isinstance(session, Session)) + self.assertTrue(session.session_id is None) + self.assertTrue(session._database is database) + + def test_execute_sql_defaults(self): + QUERY = 'SELECT * FROM employees' + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + 
pool.put(session) + session._rows = [] + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + rows = list(database.execute_sql(QUERY)) + + self.assertEqual(rows, []) + self.assertEqual(session._executed, (QUERY, None, None, None, b'')) + + def test_run_in_transaction_wo_args(self): + import datetime + NOW = datetime.datetime.now() + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + session._committed = NOW + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + _unit_of_work = object() + + committed = database.run_in_transaction(_unit_of_work) + + self.assertEqual(committed, NOW) + self.assertEqual(session._retried, (_unit_of_work, (), {})) + + def test_run_in_transaction_w_args(self): + import datetime + SINCE = datetime.datetime(2017, 1, 1) + UNTIL = datetime.datetime(2018, 1, 1) + NOW = datetime.datetime.now() + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + session._committed = NOW + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + _unit_of_work = object() + + committed = database.run_in_transaction( + _unit_of_work, SINCE, until=UNTIL) + + self.assertEqual(committed, NOW) + self.assertEqual(session._retried, + (_unit_of_work, (SINCE,), {'until': UNTIL})) + + def test_read(self): + from google.cloud.spanner.keyset import KeySet + TABLE_NAME = 'citizens' + COLUMNS = ['email', 'first_name', 'last_name', 'age'] + KEYS = ['bharney@example.com', 'phred@example.com'] + KEYSET = KeySet(keys=KEYS) + INDEX = 'email-address-index' + LIMIT = 20 + TOKEN = b'DEADBEEF' + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + rows = list(database.read( + TABLE_NAME, COLUMNS, KEYSET, INDEX, LIMIT, TOKEN)) + + self.assertEqual(rows, []) + + (table, columns, key_set, index, limit, + resume_token) = session._read_with + + self.assertEqual(table, TABLE_NAME) + self.assertEqual(columns, COLUMNS) + self.assertEqual(key_set, KEYSET) + self.assertEqual(index, INDEX) + self.assertEqual(limit, LIMIT) + self.assertEqual(resume_token, TOKEN) + + def test_batch(self): + from google.cloud.spanner.database import BatchCheckout + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + checkout = database.batch() + self.assertIsInstance(checkout, BatchCheckout) + self.assertTrue(checkout._database is database) + + def test_snapshot_defaults(self): + from google.cloud.spanner.database import SnapshotCheckout + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + checkout = database.snapshot() + self.assertIsInstance(checkout, SnapshotCheckout) + self.assertTrue(checkout._database is database) + self.assertIsNone(checkout._read_timestamp) + self.assertIsNone(checkout._min_read_timestamp) + self.assertIsNone(checkout._max_staleness) + self.assertIsNone(checkout._exact_staleness) + + def test_snapshot_w_read_timestamp(self): + import datetime + from google.cloud._helpers import UTC + from google.cloud.spanner.database import 
SnapshotCheckout + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + checkout = database.snapshot(read_timestamp=now) + + self.assertIsInstance(checkout, SnapshotCheckout) + self.assertTrue(checkout._database is database) + self.assertEqual(checkout._read_timestamp, now) + self.assertIsNone(checkout._min_read_timestamp) + self.assertIsNone(checkout._max_staleness) + self.assertIsNone(checkout._exact_staleness) + + def test_snapshot_w_min_read_timestamp(self): + import datetime + from google.cloud._helpers import UTC + from google.cloud.spanner.database import SnapshotCheckout + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + checkout = database.snapshot(min_read_timestamp=now) + + self.assertIsInstance(checkout, SnapshotCheckout) + self.assertTrue(checkout._database is database) + self.assertIsNone(checkout._read_timestamp) + self.assertEqual(checkout._min_read_timestamp, now) + self.assertIsNone(checkout._max_staleness) + self.assertIsNone(checkout._exact_staleness) + + def test_snapshot_w_max_staleness(self): + import datetime + from google.cloud.spanner.database import SnapshotCheckout + staleness = datetime.timedelta(seconds=1, microseconds=234567) + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + checkout = database.snapshot(max_staleness=staleness) + + self.assertIsInstance(checkout, SnapshotCheckout) + self.assertTrue(checkout._database is database) + self.assertIsNone(checkout._read_timestamp) + self.assertIsNone(checkout._min_read_timestamp) + self.assertEqual(checkout._max_staleness, staleness) + self.assertIsNone(checkout._exact_staleness) + + def test_snapshot_w_exact_staleness(self): + import datetime + from google.cloud.spanner.database import SnapshotCheckout + staleness = datetime.timedelta(seconds=1, microseconds=234567) + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._makeOne(self.DATABASE_ID, instance, pool=pool) + + checkout = database.snapshot(exact_staleness=staleness) + + self.assertIsInstance(checkout, SnapshotCheckout) + self.assertTrue(checkout._database is database) + self.assertIsNone(checkout._read_timestamp) + self.assertIsNone(checkout._min_read_timestamp) + self.assertIsNone(checkout._max_staleness) + self.assertEqual(checkout._exact_staleness, staleness) + + +class TestBatchCheckout(_BaseTest): + + def _getTargetClass(self): + from google.cloud.spanner.database import BatchCheckout + return BatchCheckout + + def test_ctor(self): + database = _Database(self.DATABASE_NAME) + checkout = self._makeOne(database) + self.assertTrue(checkout._database is database) + + def test_context_mgr_success(self): + import datetime + from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + TransactionOptions) + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + from 
google.cloud.spanner.batch import Batch + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + response = CommitResponse(commit_timestamp=now_pb) + database = _Database(self.DATABASE_NAME) + api = database.spanner_api = _FauxSpannerClient() + api._commit_response = response + pool = database._pool = _Pool() + session = _Session(database) + pool.put(session) + checkout = self._makeOne(database) + + with checkout as batch: + self.assertIsNone(pool._session) + self.assertIsInstance(batch, Batch) + self.assertIs(batch._session, session) + + self.assertIs(pool._session, session) + self.assertEqual(batch.committed, now) + (session_name, mutations, single_use_txn, + options) = api._committed + self.assertIs(session_name, self.SESSION_NAME) + self.assertEqual(mutations, []) + self.assertIsInstance(single_use_txn, TransactionOptions) + self.assertTrue(single_use_txn.HasField('read_write')) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_context_mgr_failure(self): + from google.cloud.spanner.batch import Batch + database = _Database(self.DATABASE_NAME) + pool = database._pool = _Pool() + session = _Session(database) + pool.put(session) + checkout = self._makeOne(database) + + class Testing(Exception): + pass + + with self.assertRaises(Testing): + with checkout as batch: + self.assertIsNone(pool._session) + self.assertIsInstance(batch, Batch) + self.assertIs(batch._session, session) + raise Testing() + + self.assertIs(pool._session, session) + self.assertIsNone(batch.committed) + + +class TestSnapshotCheckout(_BaseTest): + + def _getTargetClass(self): + from google.cloud.spanner.database import SnapshotCheckout + return SnapshotCheckout + + def test_ctor_defaults(self): + from google.cloud.spanner.snapshot import Snapshot + database = _Database(self.DATABASE_NAME) + session = _Session(database) + pool = database._pool = _Pool() + pool.put(session) + + checkout = self._makeOne(database) + self.assertTrue(checkout._database is database) + self.assertIsNone(checkout._read_timestamp) + self.assertIsNone(checkout._min_read_timestamp) + self.assertIsNone(checkout._max_staleness) + self.assertIsNone(checkout._exact_staleness) + + with checkout as snapshot: + self.assertIsNone(pool._session) + self.assertIsInstance(snapshot, Snapshot) + self.assertIs(snapshot._session, session) + self.assertTrue(snapshot._strong) + + self.assertIs(pool._session, session) + + def test_ctor_w_read_timestamp(self): + import datetime + from google.cloud._helpers import UTC + from google.cloud.spanner.snapshot import Snapshot + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + database = _Database(self.DATABASE_NAME) + session = _Session(database) + pool = database._pool = _Pool() + pool.put(session) + + checkout = self._makeOne(database, read_timestamp=now) + self.assertTrue(checkout._database is database) + self.assertEqual(checkout._read_timestamp, now) + self.assertIsNone(checkout._min_read_timestamp) + self.assertIsNone(checkout._max_staleness) + self.assertIsNone(checkout._exact_staleness) + + with checkout as snapshot: + self.assertIsNone(pool._session) + self.assertIsInstance(snapshot, Snapshot) + self.assertIs(snapshot._session, session) + self.assertFalse(snapshot._strong) + self.assertEqual(snapshot._read_timestamp, now) + + self.assertIs(pool._session, session) + + def test_ctor_w_min_read_timestamp(self): + import datetime + from google.cloud._helpers import UTC + from google.cloud.spanner.snapshot import 
Snapshot + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + database = _Database(self.DATABASE_NAME) + session = _Session(database) + pool = database._pool = _Pool() + pool.put(session) + + checkout = self._makeOne(database, min_read_timestamp=now) + self.assertTrue(checkout._database is database) + self.assertIsNone(checkout._read_timestamp) + self.assertEqual(checkout._min_read_timestamp, now) + self.assertIsNone(checkout._max_staleness) + self.assertIsNone(checkout._exact_staleness) + + with checkout as snapshot: + self.assertIsNone(pool._session) + self.assertIsInstance(snapshot, Snapshot) + self.assertIs(snapshot._session, session) + self.assertFalse(snapshot._strong) + self.assertEqual(snapshot._min_read_timestamp, now) + + self.assertIs(pool._session, session) + + def test_ctor_w_max_staleness(self): + import datetime + from google.cloud.spanner.snapshot import Snapshot + staleness = datetime.timedelta(seconds=1, microseconds=234567) + database = _Database(self.DATABASE_NAME) + session = _Session(database) + pool = database._pool = _Pool() + pool.put(session) + + checkout = self._makeOne(database, max_staleness=staleness) + self.assertTrue(checkout._database is database) + self.assertIsNone(checkout._read_timestamp) + self.assertIsNone(checkout._min_read_timestamp) + self.assertEqual(checkout._max_staleness, staleness) + self.assertIsNone(checkout._exact_staleness) + + with checkout as snapshot: + self.assertIsNone(pool._session) + self.assertIsInstance(snapshot, Snapshot) + self.assertIs(snapshot._session, session) + self.assertFalse(snapshot._strong) + self.assertEqual(snapshot._max_staleness, staleness) + + self.assertIs(pool._session, session) + + def test_ctor_w_exact_staleness(self): + import datetime + from google.cloud.spanner.snapshot import Snapshot + staleness = datetime.timedelta(seconds=1, microseconds=234567) + database = _Database(self.DATABASE_NAME) + session = _Session(database) + pool = database._pool = _Pool() + pool.put(session) + + checkout = self._makeOne(database, exact_staleness=staleness) + + self.assertIs(checkout._database, database) + self.assertIsNone(checkout._read_timestamp) + self.assertIsNone(checkout._min_read_timestamp) + self.assertIsNone(checkout._max_staleness) + self.assertEqual(checkout._exact_staleness, staleness) + + with checkout as snapshot: + self.assertIsNone(pool._session) + self.assertIsInstance(snapshot, Snapshot) + self.assertIs(snapshot._session, session) + self.assertFalse(snapshot._strong) + self.assertEqual(snapshot._exact_staleness, staleness) + + self.assertIs(pool._session, session) + + def test_context_mgr_failure(self): + from google.cloud.spanner.snapshot import Snapshot + database = _Database(self.DATABASE_NAME) + pool = database._pool = _Pool() + session = _Session(database) + pool.put(session) + checkout = self._makeOne(database) + + class Testing(Exception): + pass + + with self.assertRaises(Testing): + with checkout as snapshot: + self.assertIsNone(pool._session) + self.assertIsInstance(snapshot, Snapshot) + self.assertIs(snapshot._session, session) + raise Testing() + + self.assertIs(pool._session, session) + + +class TestBrokenResultFuture(unittest.TestCase): + def test_result_normal(self): + from google.gax import _OperationFuture + from google.cloud.spanner.database import _BrokenResultFuture + + with mock.patch.object(_OperationFuture, 'result') as super_result: + super_result.return_value = 'foo' + brf = _BrokenResultFuture(object(), object(), str, object()) + self.assertEqual(brf.result(), 'foo') + 
super_result.assert_called_once_with() + + def test_result_typeerror(self): + from google.gax import _OperationFuture + from google.cloud.spanner.database import _BrokenResultFuture + + with mock.patch.object(_OperationFuture, 'result') as super_result: + super_result.side_effect = TypeError + brf = _BrokenResultFuture(object(), object(), str, object()) + self.assertEqual(brf.result(), '') + super_result.assert_called_once_with() + + + class _Client(object): + + def __init__(self, project=TestDatabase.PROJECT_ID): + self.project = project + self.project_name = 'projects/' + self.project + + + class _Instance(object): + + def __init__(self, name, client=None): + self.name = name + self.instance_id = name.rsplit('/', 1)[1] + self._client = client + + + class _Database(object): + + def __init__(self, name, instance=None): + self.name = name + self.database_id = name.rsplit('/', 1)[1] + self._instance = instance + + + class _Pool(object): + _bound = None + + def bind(self, database): + self._bound = database + + def get(self): + session, self._session = self._session, None + return session + + def put(self, session): + self._session = session + + + class _Session(object): + + _rows = () + + def __init__(self, database=None, name=_BaseTest.SESSION_NAME): + self._database = database + self.name = name + + def execute_sql(self, sql, params, param_types, query_mode, resume_token): + self._executed = (sql, params, param_types, query_mode, resume_token) + return iter(self._rows) + + def run_in_transaction(self, func, *args, **kw): + self._retried = (func, args, kw) + return self._committed + + def read(self, table, columns, keyset, index, limit, resume_token): + self._read_with = (table, columns, keyset, index, limit, resume_token) + return iter(self._rows) + + + class _SessionPB(object): + name = TestDatabase.SESSION_NAME + + + class _FauxOperationFuture(object): + pass + + + class _FauxSpannerClient(_GAXBaseAPI): + + _committed = None + + def commit(self, session, mutations, + transaction_id='', single_use_transaction=None, options=None): + assert transaction_id == '' + self._committed = (session, mutations, single_use_transaction, options) + return self._commit_response + + + class _FauxDatabaseAdminAPI(_GAXBaseAPI): + + _create_database_conflict = False + _database_not_found = False + + def _make_grpc_already_exists(self): + from grpc.beta.interfaces import StatusCode + return self._make_grpc_error(StatusCode.ALREADY_EXISTS) + + def create_database(self, + parent, + create_statement, + extra_statements=None, + options=None): + from google.gax.errors import GaxError + self._created_database = ( + parent, create_statement, extra_statements, options) + if self._random_gax_error: + raise GaxError('error') + if self._create_database_conflict: + raise GaxError('conflict', self._make_grpc_already_exists()) + if self._database_not_found: + raise GaxError('not found', self._make_grpc_not_found()) + return self._create_database_response + + def get_database_ddl(self, database, options=None): + from google.gax.errors import GaxError + self._got_database_ddl = database, options + if self._random_gax_error: + raise GaxError('error') + if self._database_not_found: + raise GaxError('not found', self._make_grpc_not_found()) + return self._get_database_ddl_response + + def drop_database(self, database, options=None): + from google.gax.errors import GaxError + self._dropped_database = database, options + if self._random_gax_error: + raise GaxError('error') + if self._database_not_found: + raise GaxError('not found', 
self._make_grpc_not_found()) + return self._drop_database_response + + def update_database_ddl(self, database, statements, operation_id, + options=None): + from google.gax.errors import GaxError + self._updated_database_ddl = ( + database, statements, operation_id, options) + if self._random_gax_error: + raise GaxError('error') + if self._database_not_found: + raise GaxError('not found', self._make_grpc_not_found()) + return self._update_database_ddl_response diff --git a/spanner/unit_tests/test_instance.py b/spanner/unit_tests/test_instance.py new file mode 100644 index 000000000000..be275a49d023 --- /dev/null +++ b/spanner/unit_tests/test_instance.py @@ -0,0 +1,650 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from google.cloud._testing import _GAXBaseAPI + + +class TestInstance(unittest.TestCase): + + PROJECT = 'project' + PARENT = 'projects/' + PROJECT + INSTANCE_ID = 'instance-id' + INSTANCE_NAME = PARENT + '/instances/' + INSTANCE_ID + CONFIG_NAME = 'configuration-name' + LOCATION = 'projects/' + PROJECT + '/locations/' + CONFIG_NAME + DISPLAY_NAME = 'display_name' + NODE_COUNT = 5 + OP_ID = 8915 + OP_NAME = ('operations/projects/%s/instances/%s/operations/%d' % + (PROJECT, INSTANCE_ID, OP_ID)) + TABLE_ID = 'table_id' + TABLE_NAME = INSTANCE_NAME + '/tables/' + TABLE_ID + TIMEOUT_SECONDS = 1 + DATABASE_ID = 'database_id' + DATABASE_NAME = '%s/databases/%s' % (INSTANCE_NAME, DATABASE_ID) + + def _getTargetClass(self): + from google.cloud.spanner.instance import Instance + return Instance + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def test_constructor_defaults(self): + from google.cloud.spanner.instance import DEFAULT_NODE_COUNT + client = object() + instance = self._makeOne(self.INSTANCE_ID, client) + self.assertEqual(instance.instance_id, self.INSTANCE_ID) + self.assertTrue(instance._client is client) + self.assertTrue(instance.configuration_name is None) + self.assertEqual(instance.node_count, DEFAULT_NODE_COUNT) + self.assertEqual(instance.display_name, self.INSTANCE_ID) + + def test_constructor_non_default(self): + DISPLAY_NAME = 'display_name' + client = object() + + instance = self._makeOne(self.INSTANCE_ID, client, + configuration_name=self.CONFIG_NAME, + node_count=self.NODE_COUNT, + display_name=DISPLAY_NAME) + self.assertEqual(instance.instance_id, self.INSTANCE_ID) + self.assertTrue(instance._client is client) + self.assertEqual(instance.configuration_name, self.CONFIG_NAME) + self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(instance.display_name, DISPLAY_NAME) + + def test_copy(self): + DISPLAY_NAME = 'display_name' + + client = _Client(self.PROJECT) + instance = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME, + display_name=DISPLAY_NAME) + new_instance = instance.copy() + + # Make sure the client copy succeeded. 
+ self.assertFalse(new_instance._client is client) + self.assertEqual(new_instance._client, client) + # Make sure the instance copy succeeded. + self.assertFalse(instance is new_instance) + self.assertEqual(instance, new_instance) + + def test__update_from_pb_success(self): + from google.cloud.proto.spanner.admin.instance.v1 import ( + spanner_instance_admin_pb2 as admin_v1_pb2) + + display_name = 'display_name' + instance_pb = admin_v1_pb2.Instance( + display_name=display_name, + ) + + instance = self._makeOne(None, None, None, None) + self.assertEqual(instance.display_name, None) + instance._update_from_pb(instance_pb) + self.assertEqual(instance.display_name, display_name) + + def test__update_from_pb_no_display_name(self): + from google.cloud.proto.spanner.admin.instance.v1 import ( + spanner_instance_admin_pb2 as admin_v1_pb2) + + instance_pb = admin_v1_pb2.Instance() + instance = self._makeOne(None, None, None, None) + self.assertEqual(instance.display_name, None) + with self.assertRaises(ValueError): + instance._update_from_pb(instance_pb) + self.assertEqual(instance.display_name, None) + + def test_from_pb_bad_instance_name(self): + from google.cloud.proto.spanner.admin.instance.v1 import ( + spanner_instance_admin_pb2 as admin_v1_pb2) + + instance_name = 'INCORRECT_FORMAT' + instance_pb = admin_v1_pb2.Instance(name=instance_name) + + klass = self._getTargetClass() + with self.assertRaises(ValueError): + klass.from_pb(instance_pb, None) + + def test_from_pb_project_mismatch(self): + from google.cloud.proto.spanner.admin.instance.v1 import ( + spanner_instance_admin_pb2 as admin_v1_pb2) + + ALT_PROJECT = 'ALT_PROJECT' + client = _Client(project=ALT_PROJECT) + + self.assertNotEqual(self.PROJECT, ALT_PROJECT) + + instance_pb = admin_v1_pb2.Instance(name=self.INSTANCE_NAME) + + klass = self._getTargetClass() + with self.assertRaises(ValueError): + klass.from_pb(instance_pb, client) + + def test_from_pb_success(self): + from google.cloud.proto.spanner.admin.instance.v1 import ( + spanner_instance_admin_pb2 as admin_v1_pb2) + + client = _Client(project=self.PROJECT) + + instance_pb = admin_v1_pb2.Instance( + name=self.INSTANCE_NAME, + config=self.CONFIG_NAME, + display_name=self.INSTANCE_ID, + ) + + klass = self._getTargetClass() + instance = klass.from_pb(instance_pb, client) + self.assertTrue(isinstance(instance, klass)) + self.assertEqual(instance._client, client) + self.assertEqual(instance.instance_id, self.INSTANCE_ID) + self.assertEqual(instance.configuration_name, self.CONFIG_NAME) + + def test_name_property(self): + client = _Client(project=self.PROJECT) + + instance = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + self.assertEqual(instance.name, self.INSTANCE_NAME) + + def test___eq__(self): + client = object() + instance1 = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + instance2 = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + self.assertEqual(instance1, instance2) + + def test___eq__type_differ(self): + client = object() + instance1 = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + instance2 = object() + self.assertNotEqual(instance1, instance2) + + def test___ne__same_value(self): + client = object() + instance1 = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + instance2 = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + comparison_val = (instance1 != instance2) + self.assertFalse(comparison_val) + + def test___ne__(self): + instance1 = self._makeOne('instance_id1', 'client1', 
self.CONFIG_NAME) + instance2 = self._makeOne('instance_id2', 'client2', self.CONFIG_NAME) + self.assertNotEqual(instance1, instance2) + + def test_create_grpc_error(self): + from google.gax.errors import GaxError + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _random_gax_error=True) + instance = self._makeOne(self.INSTANCE_ID, client, + configuration_name=self.CONFIG_NAME) + + with self.assertRaises(GaxError): + instance.create() + + (parent, instance_id, instance, options) = api._created_instance + self.assertEqual(parent, self.PARENT) + self.assertEqual(instance_id, self.INSTANCE_ID) + self.assertEqual(instance.name, self.INSTANCE_NAME) + self.assertEqual(instance.config, self.CONFIG_NAME) + self.assertEqual(instance.display_name, self.INSTANCE_ID) + self.assertEqual(instance.node_count, 1) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_create_already_exists(self): + from google.cloud.exceptions import Conflict + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _create_instance_conflict=True) + instance = self._makeOne(self.INSTANCE_ID, client, + configuration_name=self.CONFIG_NAME) + + with self.assertRaises(Conflict): + instance.create() + + (parent, instance_id, instance, options) = api._created_instance + self.assertEqual(parent, self.PARENT) + self.assertEqual(instance_id, self.INSTANCE_ID) + self.assertEqual(instance.name, self.INSTANCE_NAME) + self.assertEqual(instance.config, self.CONFIG_NAME) + self.assertEqual(instance.display_name, self.INSTANCE_ID) + self.assertEqual(instance.node_count, 1) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_create_success(self): + op_future = _FauxOperationFuture() + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _create_instance_response=op_future) + instance = self._makeOne(self.INSTANCE_ID, client, + configuration_name=self.CONFIG_NAME, + display_name=self.DISPLAY_NAME, + node_count=self.NODE_COUNT) + + future = instance.create() + + self.assertIs(future, op_future) + self.assertEqual(future.caller_metadata, + {'request_type': 'CreateInstance'}) + + (parent, instance_id, instance, options) = api._created_instance + self.assertEqual(parent, self.PARENT) + self.assertEqual(instance_id, self.INSTANCE_ID) + self.assertEqual(instance.name, self.INSTANCE_NAME) + self.assertEqual(instance.config, self.CONFIG_NAME) + self.assertEqual(instance.display_name, self.DISPLAY_NAME) + self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_exists_instance_grpc_error(self): + from google.gax.errors import GaxError + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _random_gax_error=True) + instance = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + + with self.assertRaises(GaxError): + instance.exists() + + name, options = api._got_instance + self.assertEqual(name, self.INSTANCE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_exists_instance_not_found(self): + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _instance_not_found=True) + instance = self._makeOne(self.INSTANCE_ID, client, 
self.CONFIG_NAME) + + self.assertFalse(instance.exists()) + + name, options = api._got_instance + self.assertEqual(name, self.INSTANCE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_exists_success(self): + from google.cloud.proto.spanner.admin.instance.v1 import ( + spanner_instance_admin_pb2 as admin_v1_pb2) + client = _Client(self.PROJECT) + instance_pb = admin_v1_pb2.Instance( + name=self.INSTANCE_NAME, + config=self.CONFIG_NAME, + display_name=self.DISPLAY_NAME, + node_count=self.NODE_COUNT, + ) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _get_instance_response=instance_pb) + instance = self._makeOne(self.INSTANCE_ID, client) + + self.assertTrue(instance.exists()) + + name, options = api._got_instance + self.assertEqual(name, self.INSTANCE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_reload_instance_grpc_error(self): + from google.gax.errors import GaxError + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _random_gax_error=True) + instance = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + + with self.assertRaises(GaxError): + instance.reload() + + name, options = api._got_instance + self.assertEqual(name, self.INSTANCE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_reload_instance_not_found(self): + from google.cloud.exceptions import NotFound + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _instance_not_found=True) + instance = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + + with self.assertRaises(NotFound): + instance.reload() + + name, options = api._got_instance + self.assertEqual(name, self.INSTANCE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_reload_success(self): + from google.cloud.proto.spanner.admin.instance.v1 import ( + spanner_instance_admin_pb2 as admin_v1_pb2) + client = _Client(self.PROJECT) + instance_pb = admin_v1_pb2.Instance( + name=self.INSTANCE_NAME, + config=self.CONFIG_NAME, + display_name=self.DISPLAY_NAME, + node_count=self.NODE_COUNT, + ) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _get_instance_response=instance_pb) + instance = self._makeOne(self.INSTANCE_ID, client) + + instance.reload() + + self.assertEqual(instance.configuration_name, self.CONFIG_NAME) + self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(instance.display_name, self.DISPLAY_NAME) + + name, options = api._got_instance + self.assertEqual(name, self.INSTANCE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_update_grpc_error(self): + from google.gax.errors import GaxError + from google.cloud.spanner.instance import DEFAULT_NODE_COUNT + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _random_gax_error=True) + instance = self._makeOne(self.INSTANCE_ID, client, + configuration_name=self.CONFIG_NAME) + + with self.assertRaises(GaxError): + instance.update() + + instance, field_mask, options = api._updated_instance + self.assertEqual(field_mask.paths, + ['config', 'display_name', 'node_count']) + self.assertEqual(instance.name, self.INSTANCE_NAME) + self.assertEqual(instance.config, 
self.CONFIG_NAME) + self.assertEqual(instance.display_name, self.INSTANCE_ID) + self.assertEqual(instance.node_count, DEFAULT_NODE_COUNT) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_update_not_found(self): + from google.cloud.exceptions import NotFound + from google.cloud.spanner.instance import DEFAULT_NODE_COUNT + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _instance_not_found=True) + instance = self._makeOne(self.INSTANCE_ID, client, + configuration_name=self.CONFIG_NAME) + + with self.assertRaises(NotFound): + instance.update() + + instance, field_mask, options = api._updated_instance + self.assertEqual(field_mask.paths, + ['config', 'display_name', 'node_count']) + self.assertEqual(instance.name, self.INSTANCE_NAME) + self.assertEqual(instance.config, self.CONFIG_NAME) + self.assertEqual(instance.display_name, self.INSTANCE_ID) + self.assertEqual(instance.node_count, DEFAULT_NODE_COUNT) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_update_success(self): + op_future = _FauxOperationFuture() + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _update_instance_response=op_future) + instance = self._makeOne(self.INSTANCE_ID, client, + configuration_name=self.CONFIG_NAME, + node_count=self.NODE_COUNT, + display_name=self.DISPLAY_NAME) + + future = instance.update() + + self.assertIs(future, op_future) + self.assertEqual(future.caller_metadata, + {'request_type': 'UpdateInstance'}) + + instance, field_mask, options = api._updated_instance + self.assertEqual(field_mask.paths, + ['config', 'display_name', 'node_count']) + self.assertEqual(instance.name, self.INSTANCE_NAME) + self.assertEqual(instance.config, self.CONFIG_NAME) + self.assertEqual(instance.display_name, self.DISPLAY_NAME) + self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_delete_grpc_error(self): + from google.gax.errors import GaxError + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _random_gax_error=True) + instance = self._makeOne(self.INSTANCE_ID, client) + + with self.assertRaises(GaxError): + instance.delete() + + name, options = api._deleted_instance + self.assertEqual(name, self.INSTANCE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_delete_not_found(self): + from google.cloud.exceptions import NotFound + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _instance_not_found=True) + instance = self._makeOne(self.INSTANCE_ID, client) + + with self.assertRaises(NotFound): + instance.delete() + + name, options = api._deleted_instance + self.assertEqual(name, self.INSTANCE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_delete_success(self): + from google.protobuf.empty_pb2 import Empty + client = _Client(self.PROJECT) + api = client.instance_admin_api = _FauxInstanceAdminAPI( + _delete_instance_response=Empty()) + instance = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + + instance.delete() + + name, options = api._deleted_instance + self.assertEqual(name, self.INSTANCE_NAME) + self.assertEqual(options.kwargs['metadata'], + 
[('google-cloud-resource-prefix', instance.name)]) + + def test_database_factory_defaults(self): + from google.cloud.spanner.database import Database + from google.cloud.spanner.pool import BurstyPool + client = _Client(self.PROJECT) + instance = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + DATABASE_ID = 'database-id' + + database = instance.database(DATABASE_ID) + + self.assertTrue(isinstance(database, Database)) + self.assertEqual(database.database_id, DATABASE_ID) + self.assertTrue(database._instance is instance) + self.assertEqual(list(database.ddl_statements), []) + self.assertIsInstance(database._pool, BurstyPool) + pool = database._pool + self.assertIs(pool._database, database) + + def test_database_factory_explicit(self): + from google.cloud.spanner._fixtures import DDL_STATEMENTS + from google.cloud.spanner.database import Database + client = _Client(self.PROJECT) + instance = self._makeOne(self.INSTANCE_ID, client, self.CONFIG_NAME) + DATABASE_ID = 'database-id' + pool = _Pool() + + database = instance.database( + DATABASE_ID, ddl_statements=DDL_STATEMENTS, pool=pool) + + self.assertTrue(isinstance(database, Database)) + self.assertEqual(database.database_id, DATABASE_ID) + self.assertTrue(database._instance is instance) + self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS) + self.assertIs(database._pool, pool) + self.assertIs(pool._bound, database) + + def test_list_databases_wo_paging(self): + from google.cloud._testing import _GAXPageIterator + from google.gax import INITIAL_PAGE + from google.cloud.spanner.database import Database + NEXT_TOKEN = 'TOKEN' + database_pb = _DatabasePB(name=self.DATABASE_NAME) + response = _GAXPageIterator([database_pb], page_token=NEXT_TOKEN) + client = _Client(self.PROJECT) + api = client.database_admin_api = _FauxDatabaseAdminAPI() + api._list_databases_response = response + instance = self._makeOne(self.INSTANCE_ID, client) + + iterator = instance.list_databases() + next_token = iterator.next_page_token + databases = list(iterator) + + self.assertEqual(len(databases), 1) + database = databases[0] + self.assertTrue(isinstance(database, Database)) + self.assertEqual(database.name, self.DATABASE_NAME) + self.assertEqual(next_token, NEXT_TOKEN) + + instance_name, page_size, options = api._listed_databases + self.assertEqual(instance_name, self.INSTANCE_NAME) + self.assertEqual(page_size, None) + self.assertTrue(options.page_token is INITIAL_PAGE) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + def test_list_databases_w_paging(self): + from google.cloud._testing import _GAXPageIterator + from google.cloud.spanner.database import Database + SIZE = 15 + TOKEN = 'TOKEN' + database_pb = _DatabasePB(name=self.DATABASE_NAME) + response = _GAXPageIterator([database_pb]) + client = _Client(self.PROJECT) + api = client.database_admin_api = _FauxDatabaseAdminAPI() + api._list_databases_response = response + instance = self._makeOne(self.INSTANCE_ID, client) + + iterator = instance.list_databases( + page_size=SIZE, page_token=TOKEN) + next_token = iterator.next_page_token + databases = list(iterator) + + self.assertEqual(len(databases), 1) + database = databases[0] + self.assertTrue(isinstance(database, Database)) + self.assertEqual(database.name, self.DATABASE_NAME) + self.assertEqual(next_token, None) + + instance_name, page_size, options = api._listed_databases + self.assertEqual(instance_name, self.INSTANCE_NAME) + self.assertEqual(page_size, SIZE) + 
self.assertEqual(options.page_token, TOKEN) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', instance.name)]) + + +class _Client(object): + + def __init__(self, project, timeout_seconds=None): + self.project = project + self.project_name = 'projects/' + self.project + self.timeout_seconds = timeout_seconds + + def copy(self): + from copy import deepcopy + return deepcopy(self) + + def __eq__(self, other): + return (other.project == self.project and + other.project_name == self.project_name and + other.timeout_seconds == self.timeout_seconds) + + +class _DatabasePB(object): + + def __init__(self, name): + self.name = name + + +class _FauxInstanceAdminAPI(_GAXBaseAPI): + + _create_instance_conflict = False + _instance_not_found = False + + def _make_grpc_already_exists(self): + from grpc.beta.interfaces import StatusCode + return self._make_grpc_error(StatusCode.ALREADY_EXISTS) + + def create_instance(self, parent, instance_id, instance, options=None): + from google.gax.errors import GaxError + self._created_instance = (parent, instance_id, instance, options) + if self._random_gax_error: + raise GaxError('error') + if self._create_instance_conflict: + raise GaxError('conflict', self._make_grpc_already_exists()) + return self._create_instance_response + + def get_instance(self, name, options=None): + from google.gax.errors import GaxError + self._got_instance = (name, options) + if self._random_gax_error: + raise GaxError('error') + if self._instance_not_found: + raise GaxError('not found', self._make_grpc_not_found()) + return self._get_instance_response + + def update_instance(self, instance, field_mask, options=None): + from google.gax.errors import GaxError + self._updated_instance = (instance, field_mask, options) + if self._random_gax_error: + raise GaxError('error') + if self._instance_not_found: + raise GaxError('not found', self._make_grpc_not_found()) + return self._update_instance_response + + def delete_instance(self, name, options=None): + from google.gax.errors import GaxError + self._deleted_instance = name, options + if self._random_gax_error: + raise GaxError('error') + if self._instance_not_found: + raise GaxError('not found', self._make_grpc_not_found()) + return self._delete_instance_response + + +class _FauxDatabaseAdminAPI(object): + + def list_databases(self, name, page_size, options): + self._listed_databases = (name, page_size, options) + return self._list_databases_response + + +class _FauxOperationFuture(object): + pass + + +class _Pool(object): + _bound = None + + def bind(self, database): + self._bound = database diff --git a/spanner/unit_tests/test_keyset.py b/spanner/unit_tests/test_keyset.py new file mode 100644 index 000000000000..7da6dfd9fc85 --- /dev/null +++ b/spanner/unit_tests/test_keyset.py @@ -0,0 +1,218 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import unittest + + +class TestKeyRange(unittest.TestCase): + + def _getTargetClass(self): + from google.cloud.spanner.keyset import KeyRange + return KeyRange + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def test_ctor_no_start_no_end(self): + with self.assertRaises(ValueError): + self._makeOne() + + def test_ctor_w_start_open_and_start_closed(self): + KEY_1 = [u'key_1'] + KEY_2 = [u'key_2'] + with self.assertRaises(ValueError): + self._makeOne(start_open=KEY_1, start_closed=KEY_2) + + def test_ctor_w_end_open_and_end_closed(self): + KEY_1 = [u'key_1'] + KEY_2 = [u'key_2'] + with self.assertRaises(ValueError): + self._makeOne(end_open=KEY_1, end_closed=KEY_2) + + def test_ctor_w_only_start_open(self): + KEY_1 = [u'key_1'] + krange = self._makeOne(start_open=KEY_1) + self.assertEqual(krange.start_open, KEY_1) + self.assertEqual(krange.start_closed, None) + self.assertEqual(krange.end_open, None) + self.assertEqual(krange.end_closed, None) + + def test_ctor_w_only_start_closed(self): + KEY_1 = [u'key_1'] + krange = self._makeOne(start_closed=KEY_1) + self.assertEqual(krange.start_open, None) + self.assertEqual(krange.start_closed, KEY_1) + self.assertEqual(krange.end_open, None) + self.assertEqual(krange.end_closed, None) + + def test_ctor_w_only_end_open(self): + KEY_1 = [u'key_1'] + krange = self._makeOne(end_open=KEY_1) + self.assertEqual(krange.start_open, None) + self.assertEqual(krange.start_closed, None) + self.assertEqual(krange.end_open, KEY_1) + self.assertEqual(krange.end_closed, None) + + def test_ctor_w_only_end_closed(self): + KEY_1 = [u'key_1'] + krange = self._makeOne(end_closed=KEY_1) + self.assertEqual(krange.start_open, None) + self.assertEqual(krange.start_closed, None) + self.assertEqual(krange.end_open, None) + self.assertEqual(krange.end_closed, KEY_1) + + def test_ctor_w_start_open_and_end_closed(self): + KEY_1 = [u'key_1'] + KEY_2 = [u'key_2'] + krange = self._makeOne(start_open=KEY_1, end_closed=KEY_2) + self.assertEqual(krange.start_open, KEY_1) + self.assertEqual(krange.start_closed, None) + self.assertEqual(krange.end_open, None) + self.assertEqual(krange.end_closed, KEY_2) + + def test_ctor_w_start_closed_and_end_open(self): + KEY_1 = [u'key_1'] + KEY_2 = [u'key_2'] + krange = self._makeOne(start_closed=KEY_1, end_open=KEY_2) + self.assertEqual(krange.start_open, None) + self.assertEqual(krange.start_closed, KEY_1) + self.assertEqual(krange.end_open, KEY_2) + self.assertEqual(krange.end_closed, None) + + def test_to_pb_w_start_closed_and_end_open(self): + from google.cloud.proto.spanner.v1.keys_pb2 import KeyRange + KEY_1 = [u'key_1'] + KEY_2 = [u'key_2'] + krange = self._makeOne(start_closed=KEY_1, end_open=KEY_2) + krange_pb = krange.to_pb() + self.assertIsInstance(krange_pb, KeyRange) + self.assertEqual(len(krange_pb.start_closed), 1) + self.assertEqual(krange_pb.start_closed.values[0].string_value, + KEY_1[0]) + self.assertEqual(len(krange_pb.end_open), 1) + self.assertEqual(krange_pb.end_open.values[0].string_value, KEY_2[0]) + + def test_to_pb_w_start_open_and_end_closed(self): + from google.cloud.proto.spanner.v1.keys_pb2 import KeyRange + KEY_1 = [u'key_1'] + KEY_2 = [u'key_2'] + krange = self._makeOne(start_open=KEY_1, end_closed=KEY_2) + krange_pb = krange.to_pb() + self.assertIsInstance(krange_pb, KeyRange) + self.assertEqual(len(krange_pb.start_open), 1) + self.assertEqual(krange_pb.start_open.values[0].string_value, KEY_1[0]) + self.assertEqual(len(krange_pb.end_closed), 1) + 
self.assertEqual(krange_pb.end_closed.values[0].string_value, KEY_2[0]) + + +class TestKeySet(unittest.TestCase): + + def _getTargetClass(self): + from google.cloud.spanner.keyset import KeySet + return KeySet + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def test_ctor_w_all(self): + keyset = self._makeOne(all_=True) + + self.assertTrue(keyset.all_) + self.assertEqual(keyset.keys, []) + self.assertEqual(keyset.ranges, []) + + def test_ctor_w_keys(self): + KEYS = [[u'key1'], [u'key2']] + + keyset = self._makeOne(keys=KEYS) + + self.assertFalse(keyset.all_) + self.assertEqual(keyset.keys, KEYS) + self.assertEqual(keyset.ranges, []) + + def test_ctor_w_ranges(self): + from google.cloud.spanner.keyset import KeyRange + range_1 = KeyRange(start_closed=[u'key1'], end_open=[u'key3']) + range_2 = KeyRange(start_open=[u'key5'], end_closed=[u'key6']) + + keyset = self._makeOne(ranges=[range_1, range_2]) + + self.assertFalse(keyset.all_) + self.assertEqual(keyset.keys, []) + self.assertEqual(keyset.ranges, [range_1, range_2]) + + def test_ctor_w_all_and_keys(self): + + with self.assertRaises(ValueError): + self._makeOne(all_=True, keys=[['key1'], ['key2']]) + + def test_ctor_w_all_and_ranges(self): + from google.cloud.spanner.keyset import KeyRange + range_1 = KeyRange(start_closed=[u'key1'], end_open=[u'key3']) + range_2 = KeyRange(start_open=[u'key5'], end_closed=[u'key6']) + + with self.assertRaises(ValueError): + self._makeOne(all_=True, ranges=[range_1, range_2]) + + def test_to_pb_w_all(self): + from google.cloud.proto.spanner.v1.keys_pb2 import KeySet + keyset = self._makeOne(all_=True) + + result = keyset.to_pb() + + self.assertIsInstance(result, KeySet) + self.assertTrue(result.all) + self.assertEqual(len(result.keys), 0) + self.assertEqual(len(result.ranges), 0) + + def test_to_pb_w_only_keys(self): + from google.cloud.proto.spanner.v1.keys_pb2 import KeySet + KEYS = [[u'key1'], [u'key2']] + keyset = self._makeOne(keys=KEYS) + + result = keyset.to_pb() + + self.assertIsInstance(result, KeySet) + self.assertFalse(result.all) + self.assertEqual(len(result.keys), len(KEYS)) + + for found, expected in zip(result.keys, KEYS): + self.assertEqual(len(found), len(expected)) + self.assertEqual(found.values[0].string_value, expected[0]) + + self.assertEqual(len(result.ranges), 0) + + def test_to_pb_w_only_ranges(self): + from google.cloud.proto.spanner.v1.keys_pb2 import KeySet + from google.cloud.spanner.keyset import KeyRange + KEY_1 = u'KEY_1' + KEY_2 = u'KEY_2' + KEY_3 = u'KEY_3' + KEY_4 = u'KEY_4' + RANGES = [ + KeyRange(start_open=KEY_1, end_closed=KEY_2), + KeyRange(start_closed=KEY_3, end_open=KEY_4), + ] + keyset = self._makeOne(ranges=RANGES) + + result = keyset.to_pb() + + self.assertIsInstance(result, KeySet) + self.assertFalse(result.all) + self.assertEqual(len(result.keys), 0) + self.assertEqual(len(result.ranges), len(RANGES)) + + for found, expected in zip(result.ranges, RANGES): + self.assertEqual(found, expected.to_pb()) diff --git a/spanner/unit_tests/test_pool.py b/spanner/unit_tests/test_pool.py new file mode 100644 index 000000000000..e0a06852c031 --- /dev/null +++ b/spanner/unit_tests/test_pool.py @@ -0,0 +1,810 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + + +class TestAbstractSessionPool(unittest.TestCase): + + def _getTargetClass(self): + from google.cloud.spanner.pool import AbstractSessionPool + return AbstractSessionPool + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def test_ctor_defaults(self): + pool = self._makeOne() + self.assertIsNone(pool._database) + + def test_bind_abstract(self): + pool = self._makeOne() + database = _Database('name') + with self.assertRaises(NotImplementedError): + pool.bind(database) + + def test_get_abstract(self): + pool = self._makeOne() + with self.assertRaises(NotImplementedError): + pool.get() + + def test_put_abstract(self): + pool = self._makeOne() + session = object() + with self.assertRaises(NotImplementedError): + pool.put(session) + + def test_clear_abstract(self): + pool = self._makeOne() + with self.assertRaises(NotImplementedError): + pool.clear() + + def test_session_wo_kwargs(self): + from google.cloud.spanner.pool import SessionCheckout + pool = self._makeOne() + checkout = pool.session() + self.assertIsInstance(checkout, SessionCheckout) + self.assertIs(checkout._pool, pool) + self.assertIsNone(checkout._session) + self.assertEqual(checkout._kwargs, {}) + + def test_session_w_kwargs(self): + from google.cloud.spanner.pool import SessionCheckout + pool = self._makeOne() + checkout = pool.session(foo='bar') + self.assertIsInstance(checkout, SessionCheckout) + self.assertIs(checkout._pool, pool) + self.assertIsNone(checkout._session) + self.assertEqual(checkout._kwargs, {'foo': 'bar'}) + + +class TestFixedSizePool(unittest.TestCase): + + def _getTargetClass(self): + from google.cloud.spanner.pool import FixedSizePool + return FixedSizePool + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def test_ctor_defaults(self): + pool = self._makeOne() + self.assertIsNone(pool._database) + self.assertEqual(pool.size, 10) + self.assertEqual(pool.default_timeout, 10) + self.assertTrue(pool._sessions.empty()) + + def test_ctor_explicit(self): + pool = self._makeOne(size=4, default_timeout=30) + self.assertIsNone(pool._database) + self.assertEqual(pool.size, 4) + self.assertEqual(pool.default_timeout, 30) + self.assertTrue(pool._sessions.empty()) + + def test_bind(self): + pool = self._makeOne() + database = _Database('name') + SESSIONS = [_Session(database)] * 10 + database._sessions.extend(SESSIONS) + + pool.bind(database) + + self.assertIs(pool._database, database) + self.assertEqual(pool.size, 10) + self.assertEqual(pool.default_timeout, 10) + self.assertTrue(pool._sessions.full()) + + for session in SESSIONS: + self.assertTrue(session._created) + + def test_get_non_expired(self): + pool = self._makeOne(size=4) + database = _Database('name') + SESSIONS = [_Session(database)] * 4 + database._sessions.extend(SESSIONS) + pool.bind(database) + + session = pool.get() + + self.assertIs(session, SESSIONS[0]) + self.assertTrue(session._exists_checked) + self.assertFalse(pool._sessions.full()) + + def test_get_expired(self): + pool = self._makeOne(size=4) + database = _Database('name') + 
SESSIONS = [_Session(database)] * 5 + SESSIONS[0]._exists = False + database._sessions.extend(SESSIONS) + pool.bind(database) + + session = pool.get() + + self.assertIs(session, SESSIONS[4]) + self.assertTrue(session._created) + self.assertTrue(SESSIONS[0]._exists_checked) + self.assertFalse(pool._sessions.full()) + + def test_get_empty_default_timeout(self): + from six.moves.queue import Empty + pool = self._makeOne(size=1) + queue = pool._sessions = _Queue() + + with self.assertRaises(Empty): + pool.get() + + self.assertEqual(queue._got, {'block': True, 'timeout': 10}) + + def test_get_empty_explicit_timeout(self): + from six.moves.queue import Empty + pool = self._makeOne(size=1, default_timeout=0.1) + queue = pool._sessions = _Queue() + + with self.assertRaises(Empty): + pool.get(timeout=1) + + self.assertEqual(queue._got, {'block': True, 'timeout': 1}) + + def test_put_full(self): + from six.moves.queue import Full + pool = self._makeOne(size=4) + database = _Database('name') + SESSIONS = [_Session(database)] * 4 + database._sessions.extend(SESSIONS) + pool.bind(database) + + with self.assertRaises(Full): + pool.put(_Session(database)) + + self.assertTrue(pool._sessions.full()) + + def test_put_non_full(self): + pool = self._makeOne(size=4) + database = _Database('name') + SESSIONS = [_Session(database)] * 4 + database._sessions.extend(SESSIONS) + pool.bind(database) + pool._sessions.get() + + pool.put(_Session(database)) + + self.assertTrue(pool._sessions.full()) + + def test_clear(self): + pool = self._makeOne() + database = _Database('name') + SESSIONS = [_Session(database)] * 10 + database._sessions.extend(SESSIONS) + pool.bind(database) + self.assertTrue(pool._sessions.full()) + + for session in SESSIONS: + self.assertTrue(session._created) + + pool.clear() + + for session in SESSIONS: + self.assertTrue(session._deleted) + + +class TestBurstyPool(unittest.TestCase): + + def _getTargetClass(self): + from google.cloud.spanner.pool import BurstyPool + return BurstyPool + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def test_ctor_defaults(self): + pool = self._makeOne() + self.assertIsNone(pool._database) + self.assertEqual(pool.target_size, 10) + self.assertTrue(pool._sessions.empty()) + + def test_ctor_explicit(self): + pool = self._makeOne(target_size=4) + self.assertIsNone(pool._database) + self.assertEqual(pool.target_size, 4) + self.assertTrue(pool._sessions.empty()) + + def test_get_empty(self): + pool = self._makeOne() + database = _Database('name') + database._sessions.append(_Session(database)) + pool.bind(database) + + session = pool.get() + + self.assertIsInstance(session, _Session) + self.assertIs(session._database, database) + self.assertTrue(session._created) + self.assertTrue(pool._sessions.empty()) + + def test_get_non_empty_session_exists(self): + pool = self._makeOne() + database = _Database('name') + previous = _Session(database) + pool.bind(database) + pool.put(previous) + + session = pool.get() + + self.assertIs(session, previous) + self.assertFalse(session._created) + self.assertTrue(session._exists_checked) + self.assertTrue(pool._sessions.empty()) + + def test_get_non_empty_session_expired(self): + pool = self._makeOne() + database = _Database('name') + previous = _Session(database, exists=False) + newborn = _Session(database) + database._sessions.append(newborn) + pool.bind(database) + pool.put(previous) + + session = pool.get() + + self.assertTrue(previous._exists_checked) + self.assertIs(session, newborn) + 
self.assertTrue(session._created) + self.assertFalse(session._exists_checked) + self.assertTrue(pool._sessions.empty()) + + def test_put_empty(self): + pool = self._makeOne() + database = _Database('name') + pool.bind(database) + session = _Session(database) + + pool.put(session) + + self.assertFalse(pool._sessions.empty()) + + def test_put_full(self): + pool = self._makeOne(target_size=1) + database = _Database('name') + pool.bind(database) + older = _Session(database) + pool.put(older) + self.assertFalse(pool._sessions.empty()) + + younger = _Session(database) + pool.put(younger) # discarded silently + + self.assertTrue(younger._deleted) + self.assertIs(pool.get(), older) + + def test_put_full_expired(self): + pool = self._makeOne(target_size=1) + database = _Database('name') + pool.bind(database) + older = _Session(database) + pool.put(older) + self.assertFalse(pool._sessions.empty()) + + younger = _Session(database, exists=False) + pool.put(younger) # discarded silently + + self.assertTrue(younger._deleted) + self.assertIs(pool.get(), older) + + def test_clear(self): + pool = self._makeOne() + database = _Database('name') + pool.bind(database) + previous = _Session(database) + pool.put(previous) + + pool.clear() + + self.assertTrue(previous._deleted) + + +class TestPingingPool(unittest.TestCase): + + def _getTargetClass(self): + from google.cloud.spanner.pool import PingingPool + return PingingPool + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def test_ctor_defaults(self): + pool = self._makeOne() + self.assertIsNone(pool._database) + self.assertEqual(pool.size, 10) + self.assertEqual(pool.default_timeout, 10) + self.assertEqual(pool._delta.seconds, 3000) + self.assertTrue(pool._sessions.empty()) + + def test_ctor_explicit(self): + pool = self._makeOne(size=4, default_timeout=30, ping_interval=1800) + self.assertIsNone(pool._database) + self.assertEqual(pool.size, 4) + self.assertEqual(pool.default_timeout, 30) + self.assertEqual(pool._delta.seconds, 1800) + self.assertTrue(pool._sessions.empty()) + + def test_bind(self): + pool = self._makeOne() + database = _Database('name') + SESSIONS = [_Session(database)] * 10 + database._sessions.extend(SESSIONS) + + pool.bind(database) + + self.assertIs(pool._database, database) + self.assertEqual(pool.size, 10) + self.assertEqual(pool.default_timeout, 10) + self.assertEqual(pool._delta.seconds, 3000) + self.assertTrue(pool._sessions.full()) + + for session in SESSIONS: + self.assertTrue(session._created) + + def test_get_hit_no_ping(self): + pool = self._makeOne(size=4) + database = _Database('name') + SESSIONS = [_Session(database)] * 4 + database._sessions.extend(SESSIONS) + pool.bind(database) + + session = pool.get() + + self.assertIs(session, SESSIONS[0]) + self.assertFalse(session._exists_checked) + self.assertFalse(pool._sessions.full()) + + def test_get_hit_w_ping(self): + import datetime + from google.cloud._testing import _Monkey + from google.cloud.spanner import pool as MUT + pool = self._makeOne(size=4) + database = _Database('name') + SESSIONS = [_Session(database)] * 4 + database._sessions.extend(SESSIONS) + + sessions_created = ( + datetime.datetime.utcnow() - datetime.timedelta(seconds=4000)) + + with _Monkey(MUT, _NOW=lambda: sessions_created): + pool.bind(database) + + session = pool.get() + + self.assertIs(session, SESSIONS[0]) + self.assertTrue(session._exists_checked) + self.assertFalse(pool._sessions.full()) + + def test_get_hit_w_ping_expired(self): + import datetime + from 
google.cloud._testing import _Monkey + from google.cloud.spanner import pool as MUT + pool = self._makeOne(size=4) + database = _Database('name') + SESSIONS = [_Session(database)] * 5 + SESSIONS[0]._exists = False + database._sessions.extend(SESSIONS) + + sessions_created = ( + datetime.datetime.utcnow() - datetime.timedelta(seconds=4000)) + + with _Monkey(MUT, _NOW=lambda: sessions_created): + pool.bind(database) + + session = pool.get() + + self.assertIs(session, SESSIONS[4]) + self.assertTrue(session._created) + self.assertTrue(SESSIONS[0]._exists_checked) + self.assertFalse(pool._sessions.full()) + + def test_get_empty_default_timeout(self): + from six.moves.queue import Empty + pool = self._makeOne(size=1) + queue = pool._sessions = _Queue() + + with self.assertRaises(Empty): + pool.get() + + self.assertEqual(queue._got, {'block': True, 'timeout': 10}) + + def test_get_empty_explicit_timeout(self): + from six.moves.queue import Empty + pool = self._makeOne(size=1, default_timeout=0.1) + queue = pool._sessions = _Queue() + + with self.assertRaises(Empty): + pool.get(timeout=1) + + self.assertEqual(queue._got, {'block': True, 'timeout': 1}) + + def test_put_full(self): + from six.moves.queue import Full + pool = self._makeOne(size=4) + database = _Database('name') + SESSIONS = [_Session(database)] * 4 + database._sessions.extend(SESSIONS) + pool.bind(database) + + with self.assertRaises(Full): + pool.put(_Session(database)) + + self.assertTrue(pool._sessions.full()) + + def test_put_non_full(self): + import datetime + from google.cloud._testing import _Monkey + from google.cloud.spanner import pool as MUT + pool = self._makeOne(size=1) + queue = pool._sessions = _Queue() + + now = datetime.datetime.utcnow() + database = _Database('name') + session = _Session(database) + + with _Monkey(MUT, _NOW=lambda: now): + pool.put(session) + + self.assertEqual(len(queue._items), 1) + ping_after, queued = queue._items[0] + self.assertEqual(ping_after, now + datetime.timedelta(seconds=3000)) + self.assertIs(queued, session) + + def test_clear(self): + pool = self._makeOne() + database = _Database('name') + SESSIONS = [_Session(database)] * 10 + database._sessions.extend(SESSIONS) + pool.bind(database) + self.assertTrue(pool._sessions.full()) + + for session in SESSIONS: + self.assertTrue(session._created) + + pool.clear() + + for session in SESSIONS: + self.assertTrue(session._deleted) + + def test_ping_empty(self): + pool = self._makeOne(size=1) + pool.ping() # Does not raise 'Empty' + + def test_ping_oldest_fresh(self): + pool = self._makeOne(size=1) + database = _Database('name') + SESSIONS = [_Session(database)] * 1 + database._sessions.extend(SESSIONS) + pool.bind(database) + + pool.ping() + + self.assertFalse(SESSIONS[0]._exists_checked) + + def test_ping_oldest_stale_but_exists(self): + import datetime + from google.cloud._testing import _Monkey + from google.cloud.spanner import pool as MUT + pool = self._makeOne(size=1) + database = _Database('name') + SESSIONS = [_Session(database)] * 1 + database._sessions.extend(SESSIONS) + pool.bind(database) + + later = datetime.datetime.utcnow() + datetime.timedelta(seconds=4000) + with _Monkey(MUT, _NOW=lambda: later): + pool.ping() + + self.assertTrue(SESSIONS[0]._exists_checked) + + def test_ping_oldest_stale_and_not_exists(self): + import datetime + from google.cloud._testing import _Monkey + from google.cloud.spanner import pool as MUT + pool = self._makeOne(size=1) + database = _Database('name') + SESSIONS = [_Session(database)] * 2 + 
SESSIONS[0]._exists = False + database._sessions.extend(SESSIONS) + pool.bind(database) + + later = datetime.datetime.utcnow() + datetime.timedelta(seconds=4000) + with _Monkey(MUT, _NOW=lambda: later): + pool.ping() + + self.assertTrue(SESSIONS[0]._exists_checked) + self.assertTrue(SESSIONS[1]._created) + + +class TestTransactionPingingPool(unittest.TestCase): + + def _getTargetClass(self): + from google.cloud.spanner.pool import TransactionPingingPool + return TransactionPingingPool + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def test_ctor_defaults(self): + pool = self._makeOne() + self.assertIsNone(pool._database) + self.assertEqual(pool.size, 10) + self.assertEqual(pool.default_timeout, 10) + self.assertEqual(pool._delta.seconds, 3000) + self.assertTrue(pool._sessions.empty()) + self.assertTrue(pool._pending_sessions.empty()) + + def test_ctor_explicit(self): + pool = self._makeOne(size=4, default_timeout=30, ping_interval=1800) + self.assertIsNone(pool._database) + self.assertEqual(pool.size, 4) + self.assertEqual(pool.default_timeout, 30) + self.assertEqual(pool._delta.seconds, 1800) + self.assertTrue(pool._sessions.empty()) + self.assertTrue(pool._pending_sessions.empty()) + + def test_bind(self): + pool = self._makeOne() + database = _Database('name') + SESSIONS = [_Session(database) for _ in range(10)] + database._sessions.extend(SESSIONS) + + pool.bind(database) + + self.assertIs(pool._database, database) + self.assertEqual(pool.size, 10) + self.assertEqual(pool.default_timeout, 10) + self.assertEqual(pool._delta.seconds, 3000) + self.assertTrue(pool._sessions.full()) + + for session in SESSIONS: + self.assertTrue(session._created) + txn = session._transaction + self.assertTrue(txn._begun) + + self.assertTrue(pool._pending_sessions.empty()) + + def test_put_full(self): + from six.moves.queue import Full + pool = self._makeOne(size=4) + database = _Database('name') + SESSIONS = [_Session(database) for _ in range(4)] + database._sessions.extend(SESSIONS) + pool.bind(database) + + with self.assertRaises(Full): + pool.put(_Session(database)) + + self.assertTrue(pool._sessions.full()) + + def test_put_non_full_w_active_txn(self): + pool = self._makeOne(size=1) + queue = pool._sessions = _Queue() + pending = pool._pending_sessions = _Queue() + database = _Database('name') + session = _Session(database) + txn = session.transaction() + + pool.put(session) + + self.assertEqual(len(queue._items), 1) + _, queued = queue._items[0] + self.assertIs(queued, session) + + self.assertEqual(len(pending._items), 0) + self.assertFalse(txn._begun) + + def test_put_non_full_w_committed_txn(self): + pool = self._makeOne(size=1) + queue = pool._sessions = _Queue() + pending = pool._pending_sessions = _Queue() + database = _Database('name') + session = _Session(database) + committed = session.transaction() + committed._committed = True + + pool.put(session) + + self.assertEqual(len(queue._items), 0) + + self.assertEqual(len(pending._items), 1) + self.assertIs(pending._items[0], session) + self.assertIsNot(session._transaction, committed) + self.assertFalse(session._transaction._begun) + + def test_put_non_full(self): + pool = self._makeOne(size=1) + queue = pool._sessions = _Queue() + pending = pool._pending_sessions = _Queue() + database = _Database('name') + session = _Session(database) + + pool.put(session) + + self.assertEqual(len(queue._items), 0) + self.assertEqual(len(pending._items), 1) + self.assertIs(pending._items[0], session) + + 
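+ # The freshly put session has no begun transaction, so the pool is + # expected to park it in '_pending_sessions' rather than the main + # queue; 'begin_pending_transactions()' (tested below) drains it.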
self.assertFalse(pending.empty()) + + def test_begin_pending_transactions_empty(self): + pool = self._makeOne(size=1) + pool.begin_pending_transactions() # no raise + + def test_begin_pending_transactions_non_empty(self): + pool = self._makeOne(size=1) + pool._sessions = _Queue() + + database = _Database('name') + TRANSACTIONS = [_Transaction()] + PENDING_SESSIONS = [ + _Session(database, transaction=txn) for txn in TRANSACTIONS] + + pending = pool._pending_sessions = _Queue(*PENDING_SESSIONS) + self.assertFalse(pending.empty()) + + pool.begin_pending_transactions() # no raise + + for txn in TRANSACTIONS: + self.assertTrue(txn._begun) + + self.assertTrue(pending.empty()) + + +class TestSessionCheckout(unittest.TestCase): + + def _getTargetClass(self): + from google.cloud.spanner.pool import SessionCheckout + return SessionCheckout + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def test_ctor_wo_kwargs(self): + pool = _Pool() + checkout = self._makeOne(pool) + self.assertIs(checkout._pool, pool) + self.assertIsNone(checkout._session) + self.assertEqual(checkout._kwargs, {}) + + def test_ctor_w_kwargs(self): + pool = _Pool() + checkout = self._makeOne(pool, foo='bar') + self.assertIs(checkout._pool, pool) + self.assertIsNone(checkout._session) + self.assertEqual(checkout._kwargs, {'foo': 'bar'}) + + def test_context_manager_wo_kwargs(self): + session = object() + pool = _Pool(session) + checkout = self._makeOne(pool) + + self.assertEqual(len(pool._items), 1) + self.assertIs(pool._items[0], session) + + with checkout as borrowed: + self.assertIs(borrowed, session) + self.assertEqual(len(pool._items), 0) + + self.assertEqual(len(pool._items), 1) + self.assertIs(pool._items[0], session) + self.assertEqual(pool._got, {}) + + def test_context_manager_w_kwargs(self): + session = object() + pool = _Pool(session) + checkout = self._makeOne(pool, foo='bar') + + self.assertEqual(len(pool._items), 1) + self.assertIs(pool._items[0], session) + + with checkout as borrowed: + self.assertIs(borrowed, session) + self.assertEqual(len(pool._items), 0) + + self.assertEqual(len(pool._items), 1) + self.assertIs(pool._items[0], session) + self.assertEqual(pool._got, {'foo': 'bar'}) + + +class _Transaction(object): + + _begun = False + _committed = False + _rolled_back = False + + def begin(self): + self._begun = True + + def committed(self): + return self._committed + + +class _Session(object): + + _transaction = None + + def __init__(self, database, exists=True, transaction=None): + self._database = database + self._exists = exists + self._exists_checked = False + self._created = False + self._deleted = False + self._transaction = transaction + + def create(self): + self._created = True + + def exists(self): + self._exists_checked = True + return self._exists + + def delete(self): + from google.cloud.exceptions import NotFound + self._deleted = True + if not self._exists: + raise NotFound("unknown session") + + def transaction(self): + txn = self._transaction = _Transaction() + return txn + + +class _Database(object): + + def __init__(self, name): + self.name = name + self._sessions = [] + + def session(self): + return self._sessions.pop() + + +class _Queue(object): + + _size = 1 + + def __init__(self, *items): + self._items = list(items) + + def empty(self): + return len(self._items) == 0 + + def full(self): + return len(self._items) >= self._size + + def get(self, **kwargs): + from six.moves.queue import Empty + self._got = kwargs + try: + return self._items.pop() 
+ except IndexError: + raise Empty() + + def put(self, item, **kwargs): + self._put = kwargs + self._items.append(item) + + def put_nowait(self, item, **kwargs): + self._put_nowait = kwargs + self._items.append(item) + + +class _Pool(_Queue): + + _database = None diff --git a/spanner/unit_tests/test_session.py b/spanner/unit_tests/test_session.py new file mode 100644 index 000000000000..0c1f500e12e6 --- /dev/null +++ b/spanner/unit_tests/test_session.py @@ -0,0 +1,858 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from google.cloud._testing import _GAXBaseAPI + + +class TestSession(unittest.TestCase): + + PROJECT_ID = 'project-id' + INSTANCE_ID = 'instance-id' + INSTANCE_NAME = ('projects/' + PROJECT_ID + '/instances/' + INSTANCE_ID) + DATABASE_ID = 'database-id' + DATABASE_NAME = INSTANCE_NAME + '/databases/' + DATABASE_ID + SESSION_ID = 'session-id' + SESSION_NAME = DATABASE_NAME + '/sessions/' + SESSION_ID + + def _getTargetClass(self): + from google.cloud.spanner.session import Session + return Session + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def test_constructor(self): + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + self.assertTrue(session.session_id is None) + self.assertTrue(session._database is database) + + def test_name_property_wo_session_id(self): + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + with self.assertRaises(ValueError): + _ = session.name + + def test_name_property_w_session_id(self): + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + session._session_id = self.SESSION_ID + self.assertEqual(session.name, self.SESSION_NAME) + + def test_create_w_session_id(self): + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + session._session_id = self.SESSION_ID + with self.assertRaises(ValueError): + session.create() + + def test_create_ok(self): + session_pb = _SessionPB(self.SESSION_NAME) + gax_api = _SpannerApi(_create_session_response=session_pb) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + + session.create() + + self.assertEqual(session.session_id, self.SESSION_ID) + + database_name, options = gax_api._create_session_called_with + self.assertEqual(database_name, self.DATABASE_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_create_error(self): + from google.gax.errors import GaxError + gax_api = _SpannerApi(_random_gax_error=True) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + + with self.assertRaises(GaxError): + session.create() + + database_name, options = gax_api._create_session_called_with + self.assertEqual(database_name, self.DATABASE_NAME) + self.assertEqual(options.kwargs['metadata'], + 
[('google-cloud-resource-prefix', database.name)]) + + def test_exists_wo_session_id(self): + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + self.assertFalse(session.exists()) + + def test_exists_hit(self): + session_pb = _SessionPB(self.SESSION_NAME) + gax_api = _SpannerApi(_get_session_response=session_pb) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = self.SESSION_ID + + self.assertTrue(session.exists()) + + session_name, options = gax_api._get_session_called_with + self.assertEqual(session_name, self.SESSION_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_exists_miss(self): + gax_api = _SpannerApi() + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = self.SESSION_ID + + self.assertFalse(session.exists()) + + session_name, options = gax_api._get_session_called_with + self.assertEqual(session_name, self.SESSION_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_exists_error(self): + from google.gax.errors import GaxError + gax_api = _SpannerApi(_random_gax_error=True) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = self.SESSION_ID + + with self.assertRaises(GaxError): + session.exists() + + session_name, options = gax_api._get_session_called_with + self.assertEqual(session_name, self.SESSION_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_delete_wo_session_id(self): + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + with self.assertRaises(ValueError): + session.delete() + + def test_delete_hit(self): + gax_api = _SpannerApi(_delete_session_ok=True) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = self.SESSION_ID + + session.delete() + + session_name, options = gax_api._delete_session_called_with + self.assertEqual(session_name, self.SESSION_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_delete_miss(self): + from google.cloud.exceptions import NotFound + gax_api = _SpannerApi(_delete_session_ok=False) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = self.SESSION_ID + + with self.assertRaises(NotFound): + session.delete() + + session_name, options = gax_api._delete_session_called_with + self.assertEqual(session_name, self.SESSION_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_delete_error(self): + from google.gax.errors import GaxError + gax_api = _SpannerApi(_random_gax_error=True) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = self.SESSION_ID + + with self.assertRaises(GaxError): + session.delete() + + session_name, options = gax_api._delete_session_called_with + self.assertEqual(session_name, self.SESSION_NAME) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_snapshot_not_created(self): + database = 
_Database(self.DATABASE_NAME) + session = self._makeOne(database) + + with self.assertRaises(ValueError): + session.snapshot() + + def test_snapshot_created(self): + from google.cloud.spanner.snapshot import Snapshot + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + session._session_id = 'DEADBEEF' # emulate 'session.create()' + + snapshot = session.snapshot() + + self.assertIsInstance(snapshot, Snapshot) + self.assertTrue(snapshot._session is session) + self.assertTrue(snapshot._strong) + + def test_read_not_created(self): + from google.cloud.spanner.keyset import KeySet + TABLE_NAME = 'citizens' + COLUMNS = ['email', 'first_name', 'last_name', 'age'] + KEYS = ['bharney@example.com', 'phred@example.com'] + KEYSET = KeySet(keys=KEYS) + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + + with self.assertRaises(ValueError): + session.read(TABLE_NAME, COLUMNS, KEYSET) + + def test_read(self): + from google.cloud.spanner import session as MUT + from google.cloud._testing import _Monkey + from google.cloud.spanner.keyset import KeySet + TABLE_NAME = 'citizens' + COLUMNS = ['email', 'first_name', 'last_name', 'age'] + KEYS = ['bharney@example.com', 'phred@example.com'] + KEYSET = KeySet(keys=KEYS) + INDEX = 'email-address-index' + LIMIT = 20 + TOKEN = b'DEADBEEF' + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + + _read_with = [] + expected = object() + + class _Snapshot(object): + + def __init__(self, session, **kwargs): + self._session = session + self._kwargs = kwargs.copy() + + def read(self, table, columns, keyset, index='', limit=0, + resume_token=b''): + _read_with.append( + (table, columns, keyset, index, limit, resume_token)) + return expected + + with _Monkey(MUT, Snapshot=_Snapshot): + found = session.read( + TABLE_NAME, COLUMNS, KEYSET, + index=INDEX, limit=LIMIT, resume_token=TOKEN) + + self.assertIs(found, expected) + + self.assertEqual(len(_read_with), 1) + (table, columns, key_set, index, limit, resume_token) = _read_with[0] + + self.assertEqual(table, TABLE_NAME) + self.assertEqual(columns, COLUMNS) + self.assertEqual(key_set, KEYSET) + self.assertEqual(index, INDEX) + self.assertEqual(limit, LIMIT) + self.assertEqual(resume_token, TOKEN) + + def test_execute_sql_not_created(self): + SQL = 'SELECT first_name, age FROM citizens' + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + + with self.assertRaises(ValueError): + session.execute_sql(SQL) + + def test_execute_sql_defaults(self): + from google.cloud.spanner import session as MUT + from google.cloud._testing import _Monkey + SQL = 'SELECT first_name, age FROM citizens' + TOKEN = b'DEADBEEF' + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + + _executed_sql_with = [] + expected = object() + + class _Snapshot(object): + + def __init__(self, session, **kwargs): + self._session = session + self._kwargs = kwargs.copy() + + def execute_sql( + self, sql, params=None, param_types=None, query_mode=None, + resume_token=None): + _executed_sql_with.append( + (sql, params, param_types, query_mode, resume_token)) + return expected + + with _Monkey(MUT, Snapshot=_Snapshot): + found = session.execute_sql(SQL, resume_token=TOKEN) + + self.assertIs(found, expected) + + self.assertEqual(len(_executed_sql_with), 1) + sql, params, param_types, query_mode, token = _executed_sql_with[0] + + self.assertEqual(sql, SQL) + 
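+ # The monkeypatched '_Snapshot' recorded the delegated call above, + # so the remaining defaults forwarded by 'session.execute_sql()' + # can be checked field by field.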
self.assertEqual(params, None) + self.assertEqual(param_types, None) + self.assertEqual(query_mode, None) + self.assertEqual(token, TOKEN) + + def test_batch_not_created(self): + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + + with self.assertRaises(ValueError): + session.batch() + + def test_batch_created(self): + from google.cloud.spanner.batch import Batch + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + + batch = session.batch() + + self.assertIsInstance(batch, Batch) + self.assertTrue(batch._session is session) + + def test_transaction_not_created(self): + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + + with self.assertRaises(ValueError): + session.transaction() + + def test_transaction_created(self): + from google.cloud.spanner.transaction import Transaction + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + + transaction = session.transaction() + + self.assertIsInstance(transaction, Transaction) + self.assertTrue(transaction._session is session) + self.assertTrue(session._transaction is transaction) + + def test_transaction_w_existing_txn(self): + database = _Database(self.DATABASE_NAME) + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + + existing = session.transaction() + another = session.transaction() # invalidates existing txn + + self.assertTrue(session._transaction is another) + self.assertTrue(existing._rolled_back) + + def test_retry_transaction_w_commit_error_txn_already_begun(self): + from google.gax.errors import GaxError + from google.cloud.spanner.transaction import Transaction + TABLE_NAME = 'citizens' + COLUMNS = ['email', 'first_name', 'last_name', 'age'] + VALUES = [ + ['phred@example.com', 'Phred', 'Phlyntstone', 32], + ['bharney@example.com', 'Bharney', 'Rhubble', 31], + ] + gax_api = _SpannerApi( + _commit_error=True, + ) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + begun_txn = session._transaction = Transaction(session) + begun_txn._id = b'FACEDACE' + + called_with = [] + + def unit_of_work(txn, *args, **kw): + called_with.append((txn, args, kw)) + txn.insert(TABLE_NAME, COLUMNS, VALUES) + + with self.assertRaises(GaxError): + session.run_in_transaction(unit_of_work) + + self.assertEqual(len(called_with), 1) + txn, args, kw = called_with[0] + self.assertIs(txn, begun_txn) + self.assertEqual(txn.committed, None) + self.assertEqual(args, ()) + self.assertEqual(kw, {}) + + def test_run_in_transaction_callback_raises_abort(self): + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + Transaction as TransactionPB) + from google.cloud.spanner.transaction import Transaction + TABLE_NAME = 'citizens' + COLUMNS = ['email', 'first_name', 'last_name', 'age'] + VALUES = [ + ['phred@example.com', 'Phred', 'Phlyntstone', 32], + ['bharney@example.com', 'Bharney', 'Rhubble', 31], + ] + TRANSACTION_ID = b'FACEDACE' + transaction_pb = TransactionPB(id=TRANSACTION_ID) + gax_api = _SpannerApi( + _begin_transaction_response=transaction_pb, + _rollback_response=None, + ) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + + called_with = [] + + class Testing(Exception): + pass + + def unit_of_work(txn, *args, **kw): + called_with.append((txn, args, kw)) +
txn.insert(TABLE_NAME, COLUMNS, VALUES) + raise Testing() + + with self.assertRaises(Testing): + session.run_in_transaction(unit_of_work) + + self.assertEqual(len(called_with), 1) + txn, args, kw = called_with[0] + self.assertIsInstance(txn, Transaction) + self.assertIsNone(txn.committed) + self.assertTrue(txn._rolled_back) + self.assertEqual(args, ()) + self.assertEqual(kw, {}) + + def test_run_in_transaction_w_args_w_kwargs_wo_abort(self): + import datetime + from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + Transaction as TransactionPB) + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.spanner.transaction import Transaction + TABLE_NAME = 'citizens' + COLUMNS = ['email', 'first_name', 'last_name', 'age'] + VALUES = [ + ['phred@example.com', 'Phred', 'Phlyntstone', 32], + ['bharney@example.com', 'Bharney', 'Rhubble', 31], + ] + TRANSACTION_ID = b'FACEDACE' + transaction_pb = TransactionPB(id=TRANSACTION_ID) + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + response = CommitResponse(commit_timestamp=now_pb) + gax_api = _SpannerApi( + _begin_transaction_response=transaction_pb, + _commit_response=response, + ) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + + called_with = [] + + def unit_of_work(txn, *args, **kw): + called_with.append((txn, args, kw)) + txn.insert(TABLE_NAME, COLUMNS, VALUES) + + committed = session.run_in_transaction( + unit_of_work, 'abc', some_arg='def') + + self.assertEqual(committed, now) + self.assertEqual(len(called_with), 1) + txn, args, kw = called_with[0] + self.assertIsInstance(txn, Transaction) + self.assertEqual(txn.committed, committed) + self.assertEqual(args, ('abc',)) + self.assertEqual(kw, {'some_arg': 'def'}) + + def test_run_in_transaction_w_abort_no_retry_metadata(self): + import datetime + from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + Transaction as TransactionPB) + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.spanner.transaction import Transaction + TABLE_NAME = 'citizens' + COLUMNS = ['email', 'first_name', 'last_name', 'age'] + VALUES = [ + ['phred@example.com', 'Phred', 'Phlyntstone', 32], + ['bharney@example.com', 'Bharney', 'Rhubble', 31], + ] + TRANSACTION_ID = b'FACEDACE' + transaction_pb = TransactionPB(id=TRANSACTION_ID) + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + response = CommitResponse(commit_timestamp=now_pb) + gax_api = _SpannerApi( + _begin_transaction_response=transaction_pb, + _commit_abort_count=1, + _commit_response=response, + ) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + + called_with = [] + + def unit_of_work(txn, *args, **kw): + called_with.append((txn, args, kw)) + txn.insert(TABLE_NAME, COLUMNS, VALUES) + + committed = session.run_in_transaction( + unit_of_work, 'abc', some_arg='def') + + self.assertEqual(committed, now) + self.assertEqual(len(called_with), 2) + for index, (txn, args, kw) in enumerate(called_with): + self.assertIsInstance(txn, Transaction) + if index == 1: + self.assertEqual(txn.committed,
committed) + else: + self.assertIsNone(txn.committed) + self.assertEqual(args, ('abc',)) + self.assertEqual(kw, {'some_arg': 'def'}) + + def test_run_in_transaction_w_abort_w_retry_metadata(self): + import datetime + from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + Transaction as TransactionPB) + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.spanner.transaction import Transaction + from google.cloud.spanner import session as MUT + from google.cloud._testing import _Monkey + TABLE_NAME = 'citizens' + COLUMNS = ['email', 'first_name', 'last_name', 'age'] + VALUES = [ + ['phred@example.com', 'Phred', 'Phlyntstone', 32], + ['bharney@example.com', 'Bharney', 'Rhubble', 31], + ] + TRANSACTION_ID = b'FACEDACE' + RETRY_SECONDS = 12 + RETRY_NANOS = 3456 + transaction_pb = TransactionPB(id=TRANSACTION_ID) + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + response = CommitResponse(commit_timestamp=now_pb) + gax_api = _SpannerApi( + _begin_transaction_response=transaction_pb, + _commit_abort_count=1, + _commit_abort_retry_seconds=RETRY_SECONDS, + _commit_abort_retry_nanos=RETRY_NANOS, + _commit_response=response, + ) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + + called_with = [] + + def unit_of_work(txn, *args, **kw): + called_with.append((txn, args, kw)) + txn.insert(TABLE_NAME, COLUMNS, VALUES) + + time_module = _FauxTimeModule() + + with _Monkey(MUT, time=time_module): + committed = session.run_in_transaction( + unit_of_work, 'abc', some_arg='def') + + self.assertEqual(time_module._slept, + RETRY_SECONDS + RETRY_NANOS / 1.0e9) + self.assertEqual(committed, now) + self.assertEqual(len(called_with), 2) + for index, (txn, args, kw) in enumerate(called_with): + self.assertIsInstance(txn, Transaction) + if index == 1: + self.assertEqual(txn.committed, committed) + else: + self.assertIsNone(txn.committed) + self.assertEqual(args, ('abc',)) + self.assertEqual(kw, {'some_arg': 'def'}) + + def test_run_in_transaction_w_callback_raises_abort_wo_metadata(self): + import datetime + from google.gax.errors import GaxError + from grpc import StatusCode + from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + Transaction as TransactionPB) + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.spanner.transaction import Transaction + from google.cloud.spanner import session as MUT + from google.cloud._testing import _Monkey + TABLE_NAME = 'citizens' + COLUMNS = ['email', 'first_name', 'last_name', 'age'] + VALUES = [ + ['phred@example.com', 'Phred', 'Phlyntstone', 32], + ['bharney@example.com', 'Bharney', 'Rhubble', 31], + ] + TRANSACTION_ID = b'FACEDACE' + RETRY_SECONDS = 1 + RETRY_NANOS = 3456 + transaction_pb = TransactionPB(id=TRANSACTION_ID) + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + response = CommitResponse(commit_timestamp=now_pb) + gax_api = _SpannerApi( + _begin_transaction_response=transaction_pb, + _commit_abort_retry_seconds=RETRY_SECONDS, + _commit_abort_retry_nanos=RETRY_NANOS, + _commit_response=response, + ) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session
= self._makeOne(database) + session._session_id = 'DEADBEEF' + + called_with = [] + + def unit_of_work(txn, *args, **kw): + called_with.append((txn, args, kw)) + if len(called_with) < 2: + grpc_error = gax_api._make_grpc_error( + StatusCode.ABORTED, + trailing=gax_api._trailing_metadata()) + raise GaxError('conflict', grpc_error) + txn.insert(TABLE_NAME, COLUMNS, VALUES) + + time_module = _FauxTimeModule() + + with _Monkey(MUT, time=time_module): + committed = session.run_in_transaction(unit_of_work) + + self.assertEqual(committed, now) + self.assertEqual(time_module._slept, + RETRY_SECONDS + RETRY_NANOS / 1.0e9) + self.assertEqual(len(called_with), 2) + for index, (txn, args, kw) in enumerate(called_with): + self.assertIsInstance(txn, Transaction) + if index == 0: + self.assertIsNone(txn.committed) + else: + self.assertEqual(txn.committed, now) + self.assertEqual(args, ()) + self.assertEqual(kw, {}) + + def test_run_in_transaction_w_abort_w_retry_metadata_deadline(self): + import datetime + from google.gax.errors import GaxError + from google.gax.grpc import exc_to_code + from grpc import StatusCode + from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + Transaction as TransactionPB) + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.spanner.transaction import Transaction + from google.cloud.spanner import session as MUT + from google.cloud._testing import _Monkey + TABLE_NAME = 'citizens' + COLUMNS = ['email', 'first_name', 'last_name', 'age'] + VALUES = [ + ['phred@example.com', 'Phred', 'Phlyntstone', 32], + ['bharney@example.com', 'Bharney', 'Rhubble', 31], + ] + TRANSACTION_ID = b'FACEDACE' + RETRY_SECONDS = 1 + RETRY_NANOS = 3456 + transaction_pb = TransactionPB(id=TRANSACTION_ID) + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + response = CommitResponse(commit_timestamp=now_pb) + gax_api = _SpannerApi( + _begin_transaction_response=transaction_pb, + _commit_abort_count=1, + _commit_abort_retry_seconds=RETRY_SECONDS, + _commit_abort_retry_nanos=RETRY_NANOS, + _commit_response=response, + ) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + + called_with = [] + + def unit_of_work(txn, *args, **kw): + called_with.append((txn, args, kw)) + txn.insert(TABLE_NAME, COLUMNS, VALUES) + + time_module = _FauxTimeModule() + + with _Monkey(MUT, time=time_module): + with self.assertRaises(GaxError) as exc: + session.run_in_transaction( + unit_of_work, 'abc', some_arg='def', timeout_secs=0.01) + + self.assertEqual(exc_to_code(exc.exception.cause), StatusCode.ABORTED) + self.assertIsNone(time_module._slept) + self.assertEqual(len(called_with), 1) + txn, args, kw = called_with[0] + self.assertIsInstance(txn, Transaction) + self.assertIsNone(txn.committed) + self.assertEqual(args, ('abc',)) + self.assertEqual(kw, {'some_arg': 'def'}) + + def test_run_in_transaction_w_timeout(self): + from google.gax.errors import GaxError + from google.gax.grpc import exc_to_code + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + Transaction as TransactionPB) + from grpc import StatusCode + from google.cloud.spanner.transaction import Transaction + TABLE_NAME = 'citizens' + COLUMNS = ['email', 'first_name', 'last_name', 'age'] + VALUES = [ + ['phred@example.com', 'Phred', 'Phlyntstone', 32], +
['bharney@example.com', 'Bharney', 'Rhubble', 31], + ] + TRANSACTION_ID = b'FACEDACE' + transaction_pb = TransactionPB(id=TRANSACTION_ID) + gax_api = _SpannerApi( + _begin_transaction_response=transaction_pb, + _commit_abort_count=1e6, + ) + database = _Database(self.DATABASE_NAME) + database.spanner_api = gax_api + session = self._makeOne(database) + session._session_id = 'DEADBEEF' + + called_with = [] + + def unit_of_work(txn, *args, **kw): + called_with.append((txn, args, kw)) + txn.insert(TABLE_NAME, COLUMNS, VALUES) + + with self.assertRaises(GaxError) as exc: + session.run_in_transaction(unit_of_work, timeout_secs=0.01) + + self.assertEqual(exc_to_code(exc.exception.cause), StatusCode.ABORTED) + + self.assertGreater(len(called_with), 1) + for txn, args, kw in called_with: + self.assertIsInstance(txn, Transaction) + self.assertIsNone(txn.committed) + self.assertEqual(args, ()) + self.assertEqual(kw, {}) + + +class _Database(object): + + def __init__(self, name): + self.name = name + + +class _SpannerApi(_GAXBaseAPI): + + _commit_abort_count = 0 + _commit_abort_retry_seconds = None + _commit_abort_retry_nanos = None + _random_gax_error = _commit_error = False + + def create_session(self, database, options=None): + from google.gax.errors import GaxError + self._create_session_called_with = database, options + if self._random_gax_error: + raise GaxError('error') + return self._create_session_response + + def get_session(self, name, options=None): + from google.gax.errors import GaxError + self._get_session_called_with = name, options + if self._random_gax_error: + raise GaxError('error') + try: + return self._get_session_response + except AttributeError: + raise GaxError('miss', self._make_grpc_not_found()) + + def delete_session(self, name, options=None): + from google.gax.errors import GaxError + self._delete_session_called_with = name, options + if self._random_gax_error: + raise GaxError('error') + if not self._delete_session_ok: + raise GaxError('miss', self._make_grpc_not_found()) + + def begin_transaction(self, session, options_, options=None): + self._begun = (session, options_, options) + return self._begin_transaction_response + + def _trailing_metadata(self): + from google.protobuf.duration_pb2 import Duration + from google.rpc.error_details_pb2 import RetryInfo + from grpc._common import cygrpc_metadata + if self._commit_abort_retry_nanos is None: + return cygrpc_metadata(()) + retry_info = RetryInfo( + retry_delay=Duration( + seconds=self._commit_abort_retry_seconds, + nanos=self._commit_abort_retry_nanos)) + return cygrpc_metadata([ + ('google.rpc.retryinfo-bin', retry_info.SerializeToString())]) + + def commit(self, session, mutations, + transaction_id='', single_use_transaction=None, options=None): + from grpc import StatusCode + from google.gax.errors import GaxError + assert single_use_transaction is None + self._committed = (session, mutations, transaction_id, options) + if self._commit_error: + raise GaxError('error', self._make_grpc_error(StatusCode.UNKNOWN)) + if self._commit_abort_count > 0: + self._commit_abort_count -= 1 + grpc_error = self._make_grpc_error( + StatusCode.ABORTED, trailing=self._trailing_metadata()) + raise GaxError('conflict', grpc_error) + return self._commit_response + + def rollback(self, session, transaction_id, options=None): + self._rolled_back = (session, transaction_id, options) + return self._rollback_response + + +class _SessionPB(object): + + def __init__(self, name): + self.name = name + + +class _FauxTimeModule(object): + + _slept = 
None + + def time(self): + import time + return time.time() + + def sleep(self, seconds): + self._slept = seconds diff --git a/spanner/unit_tests/test_snapshot.py b/spanner/unit_tests/test_snapshot.py new file mode 100644 index 000000000000..3e8fe26583ef --- /dev/null +++ b/spanner/unit_tests/test_snapshot.py @@ -0,0 +1,460 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from google.cloud._testing import _GAXBaseAPI + + +TABLE_NAME = 'citizens' +COLUMNS = ['email', 'first_name', 'last_name', 'age'] +SQL_QUERY = """\ +SELECT first_name, last_name, age FROM citizens ORDER BY age""" +SQL_QUERY_WITH_PARAM = """ +SELECT first_name, last_name, email FROM citizens WHERE age <= @max_age""" +PARAMS = {'max_age': 30} +PARAM_TYPES = {'max_age': 'INT64'} +SQL_QUERY_WITH_BYTES_PARAM = """\ +SELECT image_name FROM images WHERE @bytes IN image_data""" +PARAMS_WITH_BYTES = {'bytes': b'DEADBEEF'} + + +class Test_SnapshotBase(unittest.TestCase): + + PROJECT_ID = 'project-id' + INSTANCE_ID = 'instance-id' + INSTANCE_NAME = 'projects/' + PROJECT_ID + '/instances/' + INSTANCE_ID + DATABASE_ID = 'database-id' + DATABASE_NAME = INSTANCE_NAME + '/databases/' + DATABASE_ID + SESSION_ID = 'session-id' + SESSION_NAME = DATABASE_NAME + '/sessions/' + SESSION_ID + + def _getTargetClass(self): + from google.cloud.spanner.snapshot import _SnapshotBase + return _SnapshotBase + + def _makeOne(self, session): + return self._getTargetClass()(session) + + def _makeDerived(self, session): + + class _Derived(self._getTargetClass()): + + def _make_txn_selector(self): + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + TransactionOptions, TransactionSelector) + options = TransactionOptions( + read_only=TransactionOptions.ReadOnly(strong=True)) + return TransactionSelector(single_use=options) + + return _Derived(session) + + def test_ctor(self): + session = _Session() + base = self._makeOne(session) + self.assertTrue(base._session is session) + + def test__make_txn_selector_virtual(self): + session = _Session() + base = self._makeOne(session) + with self.assertRaises(NotImplementedError): + base._make_txn_selector() + + def test_read_grpc_error(self): + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + TransactionSelector) + from google.gax.errors import GaxError + from google.cloud.spanner.keyset import KeySet + KEYSET = KeySet(all_=True) + database = _Database() + api = database.spanner_api = _FauxSpannerAPI( + _random_gax_error=True) + session = _Session(database) + derived = self._makeDerived(session) + + with self.assertRaises(GaxError): + derived.read(TABLE_NAME, COLUMNS, KEYSET) + + (r_session, table, columns, key_set, transaction, index, + limit, resume_token, options) = api._streaming_read_with + + self.assertEqual(r_session, self.SESSION_NAME) + self.assertTrue(transaction.single_use.read_only.strong) + self.assertEqual(table, TABLE_NAME) + self.assertEqual(columns, COLUMNS) + self.assertEqual(key_set, KEYSET.to_pb()) + 
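+ # Even though the RPC raised, the faux API captured the request + # arguments above, so the transaction selector and routing metadata + # can still be verified.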
self.assertIsInstance(transaction, TransactionSelector) + self.assertEqual(index, '') + self.assertEqual(limit, 0) + self.assertEqual(resume_token, b'') + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_read_normal(self): + from google.protobuf.struct_pb2 import Struct + from google.cloud.proto.spanner.v1.result_set_pb2 import ( + PartialResultSet, ResultSetMetadata, ResultSetStats) + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + TransactionSelector) + from google.cloud.proto.spanner.v1.type_pb2 import Type, StructType + from google.cloud.proto.spanner.v1.type_pb2 import STRING, INT64 + from google.cloud.spanner.keyset import KeySet + from google.cloud.spanner._helpers import _make_value_pb + VALUES = [ + [u'bharney', 31], + [u'phred', 32], + ] + VALUE_PBS = [ + [_make_value_pb(item) for item in row] + for row in VALUES + ] + struct_type_pb = StructType(fields=[ + StructType.Field(name='name', type=Type(code=STRING)), + StructType.Field(name='age', type=Type(code=INT64)), + ]) + metadata_pb = ResultSetMetadata(row_type=struct_type_pb) + stats_pb = ResultSetStats( + query_stats=Struct(fields={ + 'rows_returned': _make_value_pb(2), + })) + result_sets = [ + PartialResultSet(values=VALUE_PBS[0], metadata=metadata_pb), + PartialResultSet(values=VALUE_PBS[1], stats=stats_pb), + ] + KEYS = ['bharney@example.com', 'phred@example.com'] + KEYSET = KeySet(keys=KEYS) + INDEX = 'email-address-index' + LIMIT = 20 + TOKEN = b'DEADBEEF' + database = _Database() + api = database.spanner_api = _FauxSpannerAPI( + _streaming_read_response=_MockCancellableIterator(*result_sets)) + session = _Session(database) + derived = self._makeDerived(session) + + result_set = derived.read( + TABLE_NAME, COLUMNS, KEYSET, + index=INDEX, limit=LIMIT, resume_token=TOKEN) + + result_set.consume_all() + self.assertEqual(list(result_set.rows), VALUES) + self.assertEqual(result_set.metadata, metadata_pb) + self.assertEqual(result_set.stats, stats_pb) + + (r_session, table, columns, key_set, transaction, index, + limit, resume_token, options) = api._streaming_read_with + + self.assertEqual(r_session, self.SESSION_NAME) + self.assertEqual(table, TABLE_NAME) + self.assertEqual(columns, COLUMNS) + self.assertEqual(key_set, KEYSET.to_pb()) + self.assertIsInstance(transaction, TransactionSelector) + self.assertTrue(transaction.single_use.read_only.strong) + self.assertEqual(index, INDEX) + self.assertEqual(limit, LIMIT) + self.assertEqual(resume_token, TOKEN) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_execute_sql_grpc_error(self): + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + TransactionSelector) + from google.gax.errors import GaxError + database = _Database() + api = database.spanner_api = _FauxSpannerAPI( + _random_gax_error=True) + session = _Session(database) + derived = self._makeDerived(session) + + with self.assertRaises(GaxError): + derived.execute_sql(SQL_QUERY) + + (r_session, sql, transaction, params, param_types, + resume_token, query_mode, options) = api._executed_streaming_sql_with + + self.assertEqual(r_session, self.SESSION_NAME) + self.assertEqual(sql, SQL_QUERY) + self.assertIsInstance(transaction, TransactionSelector) + self.assertTrue(transaction.single_use.read_only.strong) + self.assertEqual(params, None) + self.assertEqual(param_types, None) + self.assertEqual(resume_token, b'') + self.assertEqual(query_mode, None) + 
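+ # Every Spanner RPC issued by the client is expected to carry the + # 'google-cloud-resource-prefix' routing header, verified next.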
self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_execute_sql_w_params_wo_param_types(self): + database = _Database() + session = _Session(database) + derived = self._makeDerived(session) + + with self.assertRaises(ValueError): + derived.execute_sql(SQL_QUERY_WITH_PARAM, PARAMS) + + def test_execute_sql_normal(self): + from google.protobuf.struct_pb2 import Struct + from google.cloud.proto.spanner.v1.result_set_pb2 import ( + PartialResultSet, ResultSetMetadata, ResultSetStats) + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + TransactionSelector) + from google.cloud.proto.spanner.v1.type_pb2 import Type, StructType + from google.cloud.proto.spanner.v1.type_pb2 import STRING, INT64 + from google.cloud.spanner._helpers import _make_value_pb + VALUES = [ + [u'bharney', u'rhubbyl', 31], + [u'phred', u'phlyntstone', 32], + ] + VALUE_PBS = [ + [_make_value_pb(item) for item in row] + for row in VALUES + ] + MODE = 2 # PROFILE + TOKEN = b'DEADBEEF' + struct_type_pb = StructType(fields=[ + StructType.Field(name='first_name', type=Type(code=STRING)), + StructType.Field(name='last_name', type=Type(code=STRING)), + StructType.Field(name='age', type=Type(code=INT64)), + ]) + metadata_pb = ResultSetMetadata(row_type=struct_type_pb) + stats_pb = ResultSetStats( + query_stats=Struct(fields={ + 'rows_returned': _make_value_pb(2), + })) + result_sets = [ + PartialResultSet(values=VALUE_PBS[0], metadata=metadata_pb), + PartialResultSet(values=VALUE_PBS[1], stats=stats_pb), + ] + iterator = _MockCancellableIterator(*result_sets) + database = _Database() + api = database.spanner_api = _FauxSpannerAPI( + _execute_streaming_sql_response=iterator) + session = _Session(database) + derived = self._makeDerived(session) + + result_set = derived.execute_sql( + SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, + query_mode=MODE, resume_token=TOKEN) + + result_set.consume_all() + self.assertEqual(list(result_set.rows), VALUES) + self.assertEqual(result_set.metadata, metadata_pb) + self.assertEqual(result_set.stats, stats_pb) + + (r_session, sql, transaction, params, param_types, + resume_token, query_mode, options) = api._executed_streaming_sql_with + + self.assertEqual(r_session, self.SESSION_NAME) + self.assertEqual(sql, SQL_QUERY_WITH_PARAM) + self.assertIsInstance(transaction, TransactionSelector) + self.assertTrue(transaction.single_use.read_only.strong) + expected_params = Struct(fields={ + key: _make_value_pb(value) for (key, value) in PARAMS.items()}) + self.assertEqual(params, expected_params) + self.assertEqual(param_types, PARAM_TYPES) + self.assertEqual(query_mode, MODE) + self.assertEqual(resume_token, TOKEN) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + +class _MockCancellableIterator(object): + + cancel_calls = 0 + + def __init__(self, *values): + self.iter_values = iter(values) + + def next(self): + return next(self.iter_values) + + def __next__(self): # pragma: NO COVER Py3k + return self.next() + + +class TestSnapshot(unittest.TestCase): + + PROJECT_ID = 'project-id' + INSTANCE_ID = 'instance-id' + INSTANCE_NAME = 'projects/' + PROJECT_ID + '/instances/' + INSTANCE_ID + DATABASE_ID = 'database-id' + DATABASE_NAME = INSTANCE_NAME + '/databases/' + DATABASE_ID + SESSION_ID = 'session-id' + SESSION_NAME = DATABASE_NAME + '/sessions/' + SESSION_ID + + def _getTargetClass(self): + from google.cloud.spanner.snapshot import Snapshot + return Snapshot + + def _makeOne(self, *args, 
**kwargs): + return self._getTargetClass()(*args, **kwargs) + + def _makeTimestamp(self): + import datetime + from google.cloud._helpers import UTC + return datetime.datetime.utcnow().replace(tzinfo=UTC) + + def _makeDuration(self, seconds=1, microseconds=0): + import datetime + return datetime.timedelta(seconds=seconds, microseconds=microseconds) + + def test_ctor_defaults(self): + session = _Session() + snapshot = self._makeOne(session) + self.assertTrue(snapshot._session is session) + self.assertTrue(snapshot._strong) + self.assertIsNone(snapshot._read_timestamp) + self.assertIsNone(snapshot._min_read_timestamp) + self.assertIsNone(snapshot._max_staleness) + self.assertIsNone(snapshot._exact_staleness) + + def test_ctor_w_multiple_options(self): + timestamp = self._makeTimestamp() + duration = self._makeDuration() + session = _Session() + + with self.assertRaises(ValueError): + self._makeOne( + session, read_timestamp=timestamp, max_staleness=duration) + + def test_ctor_w_read_timestamp(self): + timestamp = self._makeTimestamp() + session = _Session() + snapshot = self._makeOne(session, read_timestamp=timestamp) + self.assertTrue(snapshot._session is session) + self.assertFalse(snapshot._strong) + self.assertEqual(snapshot._read_timestamp, timestamp) + self.assertIsNone(snapshot._min_read_timestamp) + self.assertIsNone(snapshot._max_staleness) + self.assertIsNone(snapshot._exact_staleness) + + def test_ctor_w_min_read_timestamp(self): + timestamp = self._makeTimestamp() + session = _Session() + snapshot = self._makeOne(session, min_read_timestamp=timestamp) + self.assertTrue(snapshot._session is session) + self.assertFalse(snapshot._strong) + self.assertIsNone(snapshot._read_timestamp) + self.assertEqual(snapshot._min_read_timestamp, timestamp) + self.assertIsNone(snapshot._max_staleness) + self.assertIsNone(snapshot._exact_staleness) + + def test_ctor_w_max_staleness(self): + duration = self._makeDuration() + session = _Session() + snapshot = self._makeOne(session, max_staleness=duration) + self.assertTrue(snapshot._session is session) + self.assertFalse(snapshot._strong) + self.assertIsNone(snapshot._read_timestamp) + self.assertIsNone(snapshot._min_read_timestamp) + self.assertEqual(snapshot._max_staleness, duration) + self.assertIsNone(snapshot._exact_staleness) + + def test_ctor_w_exact_staleness(self): + duration = self._makeDuration() + session = _Session() + snapshot = self._makeOne(session, exact_staleness=duration) + self.assertTrue(snapshot._session is session) + self.assertFalse(snapshot._strong) + self.assertIsNone(snapshot._read_timestamp) + self.assertIsNone(snapshot._min_read_timestamp) + self.assertIsNone(snapshot._max_staleness) + self.assertEqual(snapshot._exact_staleness, duration) + + def test__make_txn_selector_strong(self): + session = _Session() + snapshot = self._makeOne(session) + selector = snapshot._make_txn_selector() + options = selector.single_use + self.assertTrue(options.read_only.strong) + + def test__make_txn_selector_w_read_timestamp(self): + from google.cloud._helpers import _pb_timestamp_to_datetime + timestamp = self._makeTimestamp() + session = _Session() + snapshot = self._makeOne(session, read_timestamp=timestamp) + selector = snapshot._make_txn_selector() + options = selector.single_use + self.assertEqual( + _pb_timestamp_to_datetime(options.read_only.read_timestamp), + timestamp) + + def test__make_txn_selector_w_min_read_timestamp(self): + from google.cloud._helpers import _pb_timestamp_to_datetime + timestamp = self._makeTimestamp() + 
session = _Session() + snapshot = self._makeOne(session, min_read_timestamp=timestamp) + selector = snapshot._make_txn_selector() + options = selector.single_use + self.assertEqual( + _pb_timestamp_to_datetime(options.read_only.min_read_timestamp), + timestamp) + + def test__make_txn_selector_w_max_staleness(self): + duration = self._makeDuration(seconds=3, microseconds=123456) + session = _Session() + snapshot = self._makeOne(session, max_staleness=duration) + selector = snapshot._make_txn_selector() + options = selector.single_use + self.assertEqual(options.read_only.max_staleness.seconds, 3) + self.assertEqual(options.read_only.max_staleness.nanos, 123456000) + + def test__make_txn_selector_w_exact_staleness(self): + duration = self._makeDuration(seconds=3, microseconds=123456) + session = _Session() + snapshot = self._makeOne(session, exact_staleness=duration) + selector = snapshot._make_txn_selector() + options = selector.single_use + self.assertEqual(options.read_only.exact_staleness.seconds, 3) + self.assertEqual(options.read_only.exact_staleness.nanos, 123456000) + + +class _Session(object): + + def __init__(self, database=None, name=TestSnapshot.SESSION_NAME): + self._database = database + self.name = name + + +class _Database(object): + name = 'testing' + + +class _FauxSpannerAPI(_GAXBaseAPI): + + _read_with = None + + # pylint: disable=too-many-arguments + def streaming_read(self, session, table, columns, key_set, + transaction=None, index='', limit=0, + resume_token='', options=None): + from google.gax.errors import GaxError + self._streaming_read_with = ( + session, table, columns, key_set, transaction, index, + limit, resume_token, options) + if self._random_gax_error: + raise GaxError('error') + return self._streaming_read_response + # pylint: enable=too-many-arguments + + def execute_streaming_sql(self, session, sql, transaction=None, + params=None, param_types=None, + resume_token='', query_mode=None, options=None): + from google.gax.errors import GaxError + self._executed_streaming_sql_with = ( + session, sql, transaction, params, param_types, resume_token, + query_mode, options) + if self._random_gax_error: + raise GaxError('error') + return self._execute_streaming_sql_response diff --git a/spanner/unit_tests/test_streamed.py b/spanner/unit_tests/test_streamed.py new file mode 100644 index 000000000000..115eda9b96f0 --- /dev/null +++ b/spanner/unit_tests/test_streamed.py @@ -0,0 +1,966 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
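+ # The tests below exercise 'StreamedResultSet._merge_chunk': when a value + # is split across 'PartialResultSet' responses, the trailing fragment is + # held as '_pending_chunk' and merged with the next chunk according to + # the column type: concatenation for STRING (and for INT64, which travels + # as a decimal string on the wire), recursive element-wise merging for + # ARRAY and STRUCT.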
+ + +import unittest + + +class TestStreamedResultSet(unittest.TestCase): + + def _getTargetClass(self): + from google.cloud.spanner.streamed import StreamedResultSet + return StreamedResultSet + + def _makeOne(self, *args, **kwargs): + return self._getTargetClass()(*args, **kwargs) + + def test_ctor_defaults(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + self.assertIs(streamed._response_iterator, iterator) + self.assertEqual(streamed.rows, []) + self.assertIsNone(streamed.metadata) + self.assertIsNone(streamed.stats) + self.assertIsNone(streamed.resume_token) + + def test_fields_unset(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + with self.assertRaises(AttributeError): + _ = streamed.fields + + @staticmethod + def _makeScalarField(name, type_): + from google.cloud.proto.spanner.v1.type_pb2 import StructType + from google.cloud.proto.spanner.v1.type_pb2 import Type + return StructType.Field(name=name, type=Type(code=type_)) + + @staticmethod + def _makeArrayField(name, element_type_code=None, element_type=None): + from google.cloud.proto.spanner.v1.type_pb2 import StructType + from google.cloud.proto.spanner.v1.type_pb2 import Type + if element_type is None: + element_type = Type(code=element_type_code) + array_type = Type( + code='ARRAY', array_element_type=element_type) + return StructType.Field(name=name, type=array_type) + + @staticmethod + def _makeStructType(struct_type_fields): + from google.cloud.proto.spanner.v1.type_pb2 import StructType + from google.cloud.proto.spanner.v1.type_pb2 import Type + fields = [ + StructType.Field(name=key, type=Type(code=value)) + for key, value in struct_type_fields + ] + struct_type = StructType(fields=fields) + return Type(code='STRUCT', struct_type=struct_type) + + @staticmethod + def _makeValue(value): + from google.cloud.spanner._helpers import _make_value_pb + return _make_value_pb(value) + + @staticmethod + def _makeListValue(values=(), value_pbs=None): + from google.protobuf.struct_pb2 import ListValue + from google.protobuf.struct_pb2 import Value + from google.cloud.spanner._helpers import _make_list_value_pb + if value_pbs is not None: + return Value(list_value=ListValue(values=value_pbs)) + return Value(list_value=_make_list_value_pb(values)) + + def test_properties_set(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + ] + metadata = streamed._metadata = _ResultSetMetadataPB(FIELDS) + stats = streamed._stats = _ResultSetStatsPB() + self.assertEqual(list(streamed.fields), FIELDS) + self.assertIs(streamed.metadata, metadata) + self.assertIs(streamed.stats, stats) + + def test__merge_chunk_bool(self): + from google.cloud.spanner.streamed import Unmergeable + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('registered_voter', 'BOOL'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + streamed._pending_chunk = self._makeValue(True) + chunk = self._makeValue(False) + + with self.assertRaises(Unmergeable): + streamed._merge_chunk(chunk) + + def test__merge_chunk_int64(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('age', 'INT64'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + streamed._pending_chunk = self._makeValue(42) + chunk = self._makeValue(13) + + merged = 
streamed._merge_chunk(chunk) + self.assertEqual(merged.string_value, '4213') + self.assertIsNone(streamed._pending_chunk) + + def test__merge_chunk_float64_nan_string(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('weight', 'FLOAT64'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + streamed._pending_chunk = self._makeValue(u'Na') + chunk = self._makeValue(u'N') + + merged = streamed._merge_chunk(chunk) + self.assertEqual(merged.string_value, u'NaN') + + def test__merge_chunk_float64_w_empty(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('weight', 'FLOAT64'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + streamed._pending_chunk = self._makeValue(3.14159) + chunk = self._makeValue('') + + merged = streamed._merge_chunk(chunk) + self.assertEqual(merged.number_value, 3.14159) + + def test__merge_chunk_float64_w_float64(self): + from google.cloud.spanner.streamed import Unmergeable + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('weight', 'FLOAT64'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + streamed._pending_chunk = self._makeValue(3.14159) + chunk = self._makeValue(2.71828) + + with self.assertRaises(Unmergeable): + streamed._merge_chunk(chunk) + + def test__merge_chunk_string(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('name', 'STRING'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + streamed._pending_chunk = self._makeValue(u'phred') + chunk = self._makeValue(u'wylma') + + merged = streamed._merge_chunk(chunk) + + self.assertEqual(merged.string_value, u'phredwylma') + self.assertIsNone(streamed._pending_chunk) + + def test__merge_chunk_array_of_bool(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeArrayField('name', element_type_code='BOOL'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + streamed._pending_chunk = self._makeListValue([True, True]) + chunk = self._makeListValue([False, False, False]) + + merged = streamed._merge_chunk(chunk) + + expected = self._makeListValue([True, True, False, False, False]) + self.assertEqual(merged, expected) + self.assertIsNone(streamed._pending_chunk) + + def test__merge_chunk_array_of_int(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeArrayField('name', element_type_code='INT64'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + streamed._pending_chunk = self._makeListValue([0, 1, 2]) + chunk = self._makeListValue([3, 4, 5]) + + merged = streamed._merge_chunk(chunk) + + expected = self._makeListValue([0, 1, 23, 4, 5]) + self.assertEqual(merged, expected) + self.assertIsNone(streamed._pending_chunk) + + def test__merge_chunk_array_of_float(self): + import math + PI = math.pi + EULER = math.e + SQRT_2 = math.sqrt(2.0) + LOG_10 = math.log(10) + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeArrayField('name', element_type_code='FLOAT64'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + streamed._pending_chunk = self._makeListValue([PI, SQRT_2]) + chunk = self._makeListValue(['', EULER, LOG_10]) + + merged = streamed._merge_chunk(chunk) + + expected = self._makeListValue([PI, SQRT_2, EULER, LOG_10]) + 
self.assertEqual(merged, expected)
+        self.assertIsNone(streamed._pending_chunk)
+
+    def test__merge_chunk_array_of_string(self):
+        iterator = _MockCancellableIterator()
+        streamed = self._makeOne(iterator)
+        FIELDS = [
+            self._makeArrayField('name', element_type_code='STRING'),
+        ]
+        streamed._metadata = _ResultSetMetadataPB(FIELDS)
+        streamed._pending_chunk = self._makeListValue([u'A', u'B', u'C'])
+        chunk = self._makeListValue([u'D', u'E'])
+
+        merged = streamed._merge_chunk(chunk)
+
+        expected = self._makeListValue([u'A', u'B', u'CD', u'E'])
+        self.assertEqual(merged, expected)
+        self.assertIsNone(streamed._pending_chunk)
+
+    def test__merge_chunk_array_of_string_with_null(self):
+        iterator = _MockCancellableIterator()
+        streamed = self._makeOne(iterator)
+        FIELDS = [
+            self._makeArrayField('name', element_type_code='STRING'),
+        ]
+        streamed._metadata = _ResultSetMetadataPB(FIELDS)
+        streamed._pending_chunk = self._makeListValue([u'A', u'B', u'C'])
+        chunk = self._makeListValue([None, u'D', u'E'])
+
+        merged = streamed._merge_chunk(chunk)
+
+        expected = self._makeListValue([u'A', u'B', u'C', None, u'D', u'E'])
+        self.assertEqual(merged, expected)
+        self.assertIsNone(streamed._pending_chunk)
+
+    def test__merge_chunk_array_of_array_of_int(self):
+        from google.cloud.proto.spanner.v1.type_pb2 import StructType
+        from google.cloud.proto.spanner.v1.type_pb2 import Type
+        subarray_type = Type(
+            code='ARRAY', array_element_type=Type(code='INT64'))
+        array_type = Type(code='ARRAY', array_element_type=subarray_type)
+        iterator = _MockCancellableIterator()
+        streamed = self._makeOne(iterator)
+        FIELDS = [
+            StructType.Field(name='loloi', type=array_type)
+        ]
+        streamed._metadata = _ResultSetMetadataPB(FIELDS)
+        streamed._pending_chunk = self._makeListValue(value_pbs=[
+            self._makeListValue([0, 1]),
+            self._makeListValue([2]),
+        ])
+        chunk = self._makeListValue(value_pbs=[
+            self._makeListValue([3]),
+            self._makeListValue([4, 5]),
+        ])
+
+        merged = streamed._merge_chunk(chunk)
+
+        expected = self._makeListValue(value_pbs=[
+            self._makeListValue([0, 1]),
+            self._makeListValue([23]),
+            self._makeListValue([4, 5]),
+        ])
+        self.assertEqual(merged, expected)
+        self.assertIsNone(streamed._pending_chunk)
+
+    def test__merge_chunk_array_of_array_of_string(self):
+        from google.cloud.proto.spanner.v1.type_pb2 import StructType
+        from google.cloud.proto.spanner.v1.type_pb2 import Type
+        subarray_type = Type(
+            code='ARRAY', array_element_type=Type(code='STRING'))
+        array_type = Type(code='ARRAY', array_element_type=subarray_type)
+        iterator = _MockCancellableIterator()
+        streamed = self._makeOne(iterator)
+        FIELDS = [
+            StructType.Field(name='lolos', type=array_type)
+        ]
+        streamed._metadata = _ResultSetMetadataPB(FIELDS)
+        streamed._pending_chunk = self._makeListValue(value_pbs=[
+            self._makeListValue([u'A', u'B']),
+            self._makeListValue([u'C']),
+        ])
+        chunk = self._makeListValue(value_pbs=[
+            self._makeListValue([u'D']),
+            self._makeListValue([u'E', u'F']),
+        ])
+
+        merged = streamed._merge_chunk(chunk)
+
+        expected = self._makeListValue(value_pbs=[
+            self._makeListValue([u'A', u'B']),
+            self._makeListValue([u'CD']),
+            self._makeListValue([u'E', u'F']),
+        ])
+        self.assertEqual(merged, expected)
+        self.assertIsNone(streamed._pending_chunk)
+
+    def test__merge_chunk_array_of_struct(self):
+        iterator = _MockCancellableIterator()
+        streamed = self._makeOne(iterator)
+        struct_type = self._makeStructType([
+            ('name', 'STRING'),
+            ('age', 'INT64'),
+        ])
+        FIELDS = [
+            self._makeArrayField('test',
element_type=struct_type), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + partial = self._makeListValue([u'Phred ']) + streamed._pending_chunk = self._makeListValue(value_pbs=[partial]) + rest = self._makeListValue([u'Phlyntstone', 31]) + chunk = self._makeListValue(value_pbs=[rest]) + + merged = streamed._merge_chunk(chunk) + + struct = self._makeListValue([u'Phred Phlyntstone', 31]) + expected = self._makeListValue(value_pbs=[struct]) + self.assertEqual(merged, expected) + self.assertIsNone(streamed._pending_chunk) + + def test__merge_chunk_array_of_struct_unmergeable(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + struct_type = self._makeStructType([ + ('name', 'STRING'), + ('registered', 'BOOL'), + ('voted', 'BOOL'), + ]) + FIELDS = [ + self._makeArrayField('test', element_type=struct_type), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + partial = self._makeListValue([u'Phred Phlyntstone', True]) + streamed._pending_chunk = self._makeListValue(value_pbs=[partial]) + rest = self._makeListValue([True]) + chunk = self._makeListValue(value_pbs=[rest]) + + merged = streamed._merge_chunk(chunk) + + struct = self._makeListValue([u'Phred Phlyntstone', True, True]) + expected = self._makeListValue(value_pbs=[struct]) + self.assertEqual(merged, expected) + self.assertIsNone(streamed._pending_chunk) + + def test_merge_values_empty_and_empty(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + streamed._current_row = [] + streamed._merge_values([]) + self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._current_row, []) + + def test_merge_values_empty_and_partial(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + BARE = [u'Phred Phlyntstone', 42] + VALUES = [self._makeValue(bare) for bare in BARE] + streamed._current_row = [] + streamed._merge_values(VALUES) + self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._current_row, BARE) + + def test_merge_values_empty_and_filled(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + BARE = [u'Phred Phlyntstone', 42, True] + VALUES = [self._makeValue(bare) for bare in BARE] + streamed._current_row = [] + streamed._merge_values(VALUES) + self.assertEqual(streamed.rows, [BARE]) + self.assertEqual(streamed._current_row, []) + + def test_merge_values_empty_and_filled_plus(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + BARE = [ + u'Phred Phlyntstone', 42, True, + u'Bharney Rhubble', 39, True, + u'Wylma Phlyntstone', + ] + VALUES = [self._makeValue(bare) for bare in BARE] + streamed._current_row = [] + streamed._merge_values(VALUES) + 
self.assertEqual(streamed.rows, [BARE[0:3], BARE[3:6]]) + self.assertEqual(streamed._current_row, BARE[6:]) + + def test_merge_values_partial_and_empty(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + BEFORE = [ + u'Phred Phlyntstone' + ] + streamed._current_row[:] = BEFORE + streamed._merge_values([]) + self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._current_row, BEFORE) + + def test_merge_values_partial_and_partial(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + BEFORE = [u'Phred Phlyntstone'] + streamed._current_row[:] = BEFORE + MERGED = [42] + TO_MERGE = [self._makeValue(item) for item in MERGED] + streamed._merge_values(TO_MERGE) + self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._current_row, BEFORE + MERGED) + + def test_merge_values_partial_and_filled(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + BEFORE = [ + u'Phred Phlyntstone' + ] + streamed._current_row[:] = BEFORE + MERGED = [42, True] + TO_MERGE = [self._makeValue(item) for item in MERGED] + streamed._merge_values(TO_MERGE) + self.assertEqual(streamed.rows, [BEFORE + MERGED]) + self.assertEqual(streamed._current_row, []) + + def test_merge_values_partial_and_filled_plus(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + streamed._metadata = _ResultSetMetadataPB(FIELDS) + BEFORE = [ + self._makeValue(u'Phred Phlyntstone') + ] + streamed._current_row[:] = BEFORE + MERGED = [ + 42, True, + u'Bharney Rhubble', 39, True, + u'Wylma Phlyntstone', + ] + TO_MERGE = [self._makeValue(item) for item in MERGED] + VALUES = BEFORE + MERGED + streamed._merge_values(TO_MERGE) + self.assertEqual(streamed.rows, [VALUES[0:3], VALUES[3:6]]) + self.assertEqual(streamed._current_row, VALUES[6:]) + + def test_consume_next_empty(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + with self.assertRaises(StopIteration): + streamed.consume_next() + + def test_consume_next_first_set_partial(self): + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + metadata = _ResultSetMetadataPB(FIELDS) + BARE = [u'Phred Phlyntstone', 42] + VALUES = [self._makeValue(bare) for bare in BARE] + result_set = _PartialResultSetPB(VALUES, metadata=metadata) + iterator = _MockCancellableIterator(result_set) + streamed = self._makeOne(iterator) + streamed.consume_next() + self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._current_row, BARE) + self.assertTrue(streamed.metadata is metadata) + self.assertEqual(streamed.resume_token, result_set.resume_token) + + def test_consume_next_w_partial_result(self): + FIELDS 
= [
+            self._makeScalarField('full_name', 'STRING'),
+            self._makeScalarField('age', 'INT64'),
+            self._makeScalarField('married', 'BOOL'),
+        ]
+        VALUES = [
+            self._makeValue(u'Phred '),
+        ]
+        result_set = _PartialResultSetPB(VALUES, chunked_value=True)
+        iterator = _MockCancellableIterator(result_set)
+        streamed = self._makeOne(iterator)
+        streamed._metadata = _ResultSetMetadataPB(FIELDS)
+        streamed.consume_next()
+        self.assertEqual(streamed.rows, [])
+        self.assertEqual(streamed._current_row, [])
+        self.assertEqual(streamed._pending_chunk, VALUES[0])
+        self.assertEqual(streamed.resume_token, result_set.resume_token)
+
+    def test_consume_next_w_pending_chunk(self):
+        FIELDS = [
+            self._makeScalarField('full_name', 'STRING'),
+            self._makeScalarField('age', 'INT64'),
+            self._makeScalarField('married', 'BOOL'),
+        ]
+        BARE = [
+            u'Phlyntstone', 42, True,
+            u'Bharney Rhubble', 39, True,
+            u'Wylma Phlyntstone',
+        ]
+        VALUES = [self._makeValue(bare) for bare in BARE]
+        result_set = _PartialResultSetPB(VALUES)
+        iterator = _MockCancellableIterator(result_set)
+        streamed = self._makeOne(iterator)
+        streamed._metadata = _ResultSetMetadataPB(FIELDS)
+        streamed._pending_chunk = self._makeValue(u'Phred ')
+        streamed.consume_next()
+        self.assertEqual(streamed.rows, [
+            [u'Phred Phlyntstone', BARE[1], BARE[2]],
+            [BARE[3], BARE[4], BARE[5]],
+        ])
+        self.assertEqual(streamed._current_row, [BARE[6]])
+        self.assertIsNone(streamed._pending_chunk)
+        self.assertEqual(streamed.resume_token, result_set.resume_token)
+
+    def test_consume_next_last_set(self):
+        FIELDS = [
+            self._makeScalarField('full_name', 'STRING'),
+            self._makeScalarField('age', 'INT64'),
+            self._makeScalarField('married', 'BOOL'),
+        ]
+        metadata = _ResultSetMetadataPB(FIELDS)
+        stats = _ResultSetStatsPB(
+            rows_returned="1",
+            elapsed_time="1.23 secs",
+            cpu_time="0.98 secs",
+        )
+        BARE = [u'Phred Phlyntstone', 42, True]
+        VALUES = [self._makeValue(bare) for bare in BARE]
+        result_set = _PartialResultSetPB(VALUES, stats=stats)
+        iterator = _MockCancellableIterator(result_set)
+        streamed = self._makeOne(iterator)
+        streamed._metadata = metadata
+        streamed.consume_next()
+        self.assertEqual(streamed.rows, [BARE])
+        self.assertEqual(streamed._current_row, [])
+        self.assertTrue(streamed._stats is stats)
+        self.assertEqual(streamed.resume_token, result_set.resume_token)
+
+    def test_consume_all_empty(self):
+        iterator = _MockCancellableIterator()
+        streamed = self._makeOne(iterator)
+        streamed.consume_all()
+
+    def test_consume_all_one_result_set_partial(self):
+        FIELDS = [
+            self._makeScalarField('full_name', 'STRING'),
+            self._makeScalarField('age', 'INT64'),
+            self._makeScalarField('married', 'BOOL'),
+        ]
+        metadata = _ResultSetMetadataPB(FIELDS)
+        BARE = [u'Phred Phlyntstone', 42]
+        VALUES = [self._makeValue(bare) for bare in BARE]
+        result_set = _PartialResultSetPB(VALUES, metadata=metadata)
+        iterator = _MockCancellableIterator(result_set)
+        streamed = self._makeOne(iterator)
+        streamed.consume_all()
+        self.assertEqual(streamed.rows, [])
+        self.assertEqual(streamed._current_row, BARE)
+        self.assertTrue(streamed.metadata is metadata)
+
+    def test_consume_all_multiple_result_sets_filled(self):
+        FIELDS = [
+            self._makeScalarField('full_name', 'STRING'),
+            self._makeScalarField('age', 'INT64'),
+            self._makeScalarField('married', 'BOOL'),
+        ]
+        metadata = _ResultSetMetadataPB(FIELDS)
+        BARE = [
+            u'Phred Phlyntstone', 42, True,
+            u'Bharney Rhubble', 39, True,
+            u'Wylma Phlyntstone', 41, True,
+        ]
+        VALUES = [self._makeValue(bare) for bare in BARE]
+
result_set1 = _PartialResultSetPB(VALUES[:4], metadata=metadata) + result_set2 = _PartialResultSetPB(VALUES[4:]) + iterator = _MockCancellableIterator(result_set1, result_set2) + streamed = self._makeOne(iterator) + streamed.consume_all() + self.assertEqual(streamed.rows, [ + [BARE[0], BARE[1], BARE[2]], + [BARE[3], BARE[4], BARE[5]], + [BARE[6], BARE[7], BARE[8]], + ]) + self.assertEqual(streamed._current_row, []) + self.assertIsNone(streamed._pending_chunk) + + def test___iter___empty(self): + iterator = _MockCancellableIterator() + streamed = self._makeOne(iterator) + found = list(streamed) + self.assertEqual(found, []) + + def test___iter___one_result_set_partial(self): + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + metadata = _ResultSetMetadataPB(FIELDS) + BARE = [u'Phred Phlyntstone', 42] + VALUES = [self._makeValue(bare) for bare in BARE] + result_set = _PartialResultSetPB(VALUES, metadata=metadata) + iterator = _MockCancellableIterator(result_set) + streamed = self._makeOne(iterator) + found = list(streamed) + self.assertEqual(found, []) + self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._current_row, BARE) + self.assertTrue(streamed.metadata is metadata) + + def test___iter___multiple_result_sets_filled(self): + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + metadata = _ResultSetMetadataPB(FIELDS) + BARE = [ + u'Phred Phlyntstone', 42, True, + u'Bharney Rhubble', 39, True, + u'Wylma Phlyntstone', 41, True, + ] + VALUES = [self._makeValue(bare) for bare in BARE] + result_set1 = _PartialResultSetPB(VALUES[:4], metadata=metadata) + result_set2 = _PartialResultSetPB(VALUES[4:]) + iterator = _MockCancellableIterator(result_set1, result_set2) + streamed = self._makeOne(iterator) + found = list(streamed) + self.assertEqual(found, [ + [BARE[0], BARE[1], BARE[2]], + [BARE[3], BARE[4], BARE[5]], + [BARE[6], BARE[7], BARE[8]], + ]) + self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._current_row, []) + self.assertIsNone(streamed._pending_chunk) + + def test___iter___w_existing_rows_read(self): + FIELDS = [ + self._makeScalarField('full_name', 'STRING'), + self._makeScalarField('age', 'INT64'), + self._makeScalarField('married', 'BOOL'), + ] + metadata = _ResultSetMetadataPB(FIELDS) + ALREADY = [ + [u'Pebbylz Phlyntstone', 4, False], + [u'Dino Rhubble', 4, False], + ] + BARE = [ + u'Phred Phlyntstone', 42, True, + u'Bharney Rhubble', 39, True, + u'Wylma Phlyntstone', 41, True, + ] + VALUES = [self._makeValue(bare) for bare in BARE] + result_set1 = _PartialResultSetPB(VALUES[:4], metadata=metadata) + result_set2 = _PartialResultSetPB(VALUES[4:]) + iterator = _MockCancellableIterator(result_set1, result_set2) + streamed = self._makeOne(iterator) + streamed._rows[:] = ALREADY + found = list(streamed) + self.assertEqual(found, ALREADY + [ + [BARE[0], BARE[1], BARE[2]], + [BARE[3], BARE[4], BARE[5]], + [BARE[6], BARE[7], BARE[8]], + ]) + self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._current_row, []) + self.assertIsNone(streamed._pending_chunk) + + +class _MockCancellableIterator(object): + + cancel_calls = 0 + + def __init__(self, *values): + self.iter_values = iter(values) + + def next(self): + return next(self.iter_values) + + def __next__(self): # pragma: NO COVER Py3k + return self.next() + + +class _ResultSetMetadataPB(object): + + def 
__init__(self, fields):
+        from google.cloud.proto.spanner.v1.type_pb2 import StructType
+        self.row_type = StructType(fields=fields)
+
+
+class _ResultSetStatsPB(object):
+
+    def __init__(self, query_plan=None, **query_stats):
+        from google.protobuf.struct_pb2 import Struct
+        from google.cloud.spanner._helpers import _make_value_pb
+        self.query_plan = query_plan
+        self.query_stats = Struct(fields={
+            key: _make_value_pb(value) for key, value in query_stats.items()})
+
+
+class _PartialResultSetPB(object):
+
+    resume_token = b'DEADBEEF'
+
+    def __init__(self, values, metadata=None, stats=None, chunked_value=False):
+        self.values = values
+        self.metadata = metadata
+        self.stats = stats
+        self.chunked_value = chunked_value
+
+    def HasField(self, name):
+        assert name == 'stats'
+        return self.stats is not None
+
+
+class TestStreamedResultSet_JSON_acceptance_tests(unittest.TestCase):
+
+    _json_tests = None
+
+    def _getTargetClass(self):
+        from google.cloud.spanner.streamed import StreamedResultSet
+        return StreamedResultSet
+
+    def _makeOne(self, *args, **kwargs):
+        return self._getTargetClass()(*args, **kwargs)
+
+    def _load_json_test(self, test_name):
+        import os
+        if self.__class__._json_tests is None:
+            dirname = os.path.dirname(__file__)
+            filename = os.path.join(
+                dirname, 'streaming-read-acceptance-test.json')
+            raw = _parse_streaming_read_acceptance_tests(filename)
+            tests = self.__class__._json_tests = {}
+            for (name, partial_result_sets, results) in raw:
+                tests[name] = partial_result_sets, results
+        return self.__class__._json_tests[test_name]
+
+    # Non-error cases
+
+    def _match_results(self, testcase_name, assert_equality=None):
+        partial_result_sets, expected = self._load_json_test(testcase_name)
+        iterator = _MockCancellableIterator(*partial_result_sets)
+        partial = self._makeOne(iterator)
+        partial.consume_all()
+        if assert_equality is not None:
+            assert_equality(partial.rows, expected)
+        else:
+            self.assertEqual(partial.rows, expected)
+
+    def test_basic(self):
+        self._match_results('Basic Test')
+
+    def test_string_chunking(self):
+        self._match_results('String Chunking Test')
+
+    def test_string_array_chunking(self):
+        self._match_results('String Array Chunking Test')
+
+    def test_string_array_chunking_with_nulls(self):
+        self._match_results('String Array Chunking Test With Nulls')
+
+    def test_string_array_chunking_with_empty_strings(self):
+        self._match_results('String Array Chunking Test With Empty Strings')
+
+    def test_string_array_chunking_with_one_large_string(self):
+        self._match_results('String Array Chunking Test With One Large String')
+
+    def test_int64_array_chunking(self):
+        self._match_results('INT64 Array Chunking Test')
+
+    def test_float64_array_chunking(self):
+        import math
+
+        def assert_float_equality(lhs, rhs):
+            # NaN, +Inf, and -Inf can't be tested for equality
+            if lhs is None:
+                self.assertIsNone(rhs)
+            elif math.isnan(lhs):
+                self.assertTrue(math.isnan(rhs))
+            elif math.isinf(lhs):
+                self.assertTrue(math.isinf(rhs))
+                # but +Inf and -Inf can be tested for sign
+                self.assertTrue((lhs > 0) == (rhs > 0))
+            else:
+                self.assertEqual(lhs, rhs)
+
+        def assert_rows_equality(lhs, rhs):
+            self.assertEqual(len(lhs), len(rhs))
+            for l_rows, r_rows in zip(lhs, rhs):
+                self.assertEqual(len(l_rows), len(r_rows))
+                for l_row, r_row in zip(l_rows, r_rows):
+                    self.assertEqual(len(l_row), len(r_row))
+                    for l_cell, r_cell in zip(l_row, r_row):
+                        assert_float_equality(l_cell, r_cell)
+
+        self._match_results(
+            'FLOAT64 Array Chunking Test', assert_rows_equality)
+
def test_struct_array_chunking(self): + self._match_results('Struct Array Chunking Test') + + def test_nested_struct_array(self): + self._match_results('Nested Struct Array Test') + + def test_nested_struct_array_chunking(self): + self._match_results('Nested Struct Array Chunking Test') + + def test_struct_array_and_string_chunking(self): + self._match_results('Struct Array And String Chunking Test') + + def test_multiple_row_single_chunk(self): + self._match_results('Multiple Row Single Chunk') + + def test_multiple_row_multiple_chunks(self): + self._match_results('Multiple Row Multiple Chunks') + + def test_multiple_row_chunks_non_chunks_interleaved(self): + self._match_results('Multiple Row Chunks/Non Chunks Interleaved') + + +def _generate_partial_result_sets(prs_text_pbs): + from google.protobuf.json_format import Parse + from google.cloud.proto.spanner.v1.result_set_pb2 import PartialResultSet + + partial_result_sets = [] + + for prs_text_pb in prs_text_pbs: + prs = PartialResultSet() + partial_result_sets.append(Parse(prs_text_pb, prs)) + + return partial_result_sets + + +def _normalize_int_array(cell): + normalized = [] + for subcell in cell: + if subcell is not None: + subcell = int(subcell) + normalized.append(subcell) + return normalized + + +def _normalize_float(cell): + if cell == u'Infinity': + return float('inf') + if cell == u'-Infinity': + return float('-inf') + if cell == u'NaN': + return float('nan') + if cell is not None: + return float(cell) + + +def _normalize_results(rows_data, fields): + """Helper for _parse_streaming_read_acceptance_tests""" + from google.cloud.proto.spanner.v1 import type_pb2 + normalized = [] + for row_data in rows_data: + row = [] + assert len(row_data) == len(fields) + for cell, field in zip(row_data, fields): + if field.type.code == type_pb2.INT64: + cell = int(cell) + if field.type.code == type_pb2.FLOAT64: + cell = _normalize_float(cell) + elif field.type.code == type_pb2.BYTES: + cell = cell.encode('utf8') + elif field.type.code == type_pb2.ARRAY: + if field.type.array_element_type.code == type_pb2.INT64: + cell = _normalize_int_array(cell) + elif field.type.array_element_type.code == type_pb2.FLOAT64: + cell = [_normalize_float(subcell) for subcell in cell] + row.append(cell) + normalized.append(row) + return normalized + + +def _parse_streaming_read_acceptance_tests(filename): + """Parse acceptance tests from JSON + + See: streaming-read-acceptance-test.json + """ + import json + + with open(filename) as json_file: + test_json = json.load(json_file) + + for test in test_json['tests']: + name = test['name'] + partial_result_sets = _generate_partial_result_sets(test['chunks']) + fields = partial_result_sets[0].metadata.row_type.fields + result = _normalize_results(test['result']['value'], fields) + yield name, partial_result_sets, result diff --git a/spanner/unit_tests/test_transaction.py b/spanner/unit_tests/test_transaction.py new file mode 100644 index 000000000000..265c0d8a6967 --- /dev/null +++ b/spanner/unit_tests/test_transaction.py @@ -0,0 +1,392 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+from google.cloud._testing import _GAXBaseAPI
+
+
+TABLE_NAME = 'citizens'
+COLUMNS = ['email', 'first_name', 'last_name', 'age']
+VALUES = [
+    ['phred@example.com', 'Phred', 'Phlyntstone', 32],
+    ['bharney@example.com', 'Bharney', 'Rhubble', 31],
+]
+
+
+class TestTransaction(unittest.TestCase):
+
+    PROJECT_ID = 'project-id'
+    INSTANCE_ID = 'instance-id'
+    INSTANCE_NAME = 'projects/' + PROJECT_ID + '/instances/' + INSTANCE_ID
+    DATABASE_ID = 'database-id'
+    DATABASE_NAME = INSTANCE_NAME + '/databases/' + DATABASE_ID
+    SESSION_ID = 'session-id'
+    SESSION_NAME = DATABASE_NAME + '/sessions/' + SESSION_ID
+    TRANSACTION_ID = b'DEADBEEF'
+
+    def _getTargetClass(self):
+        from google.cloud.spanner.transaction import Transaction
+        return Transaction
+
+    def _makeOne(self, *args, **kwargs):
+        return self._getTargetClass()(*args, **kwargs)
+
+    def test_ctor_defaults(self):
+        session = _Session()
+        transaction = self._makeOne(session)
+        self.assertTrue(transaction._session is session)
+        self.assertIsNone(transaction._id)
+        self.assertIsNone(transaction.committed)
+        self.assertEqual(transaction._rolled_back, False)
+
+    def test__check_state_not_begun(self):
+        session = _Session()
+        transaction = self._makeOne(session)
+        with self.assertRaises(ValueError):
+            transaction._check_state()
+
+    def test__check_state_already_committed(self):
+        session = _Session()
+        transaction = self._makeOne(session)
+        transaction._id = b'DEADBEEF'
+        transaction.committed = object()
+        with self.assertRaises(ValueError):
+            transaction._check_state()
+
+    def test__check_state_already_rolled_back(self):
+        session = _Session()
+        transaction = self._makeOne(session)
+        transaction._id = b'DEADBEEF'
+        transaction._rolled_back = True
+        with self.assertRaises(ValueError):
+            transaction._check_state()
+
+    def test__check_state_ok(self):
+        session = _Session()
+        transaction = self._makeOne(session)
+        transaction._id = b'DEADBEEF'
+        transaction._check_state()  # does not raise
+
+    def test__make_txn_selector(self):
+        session = _Session()
+        transaction = self._makeOne(session)
+        transaction._id = self.TRANSACTION_ID
+        selector = transaction._make_txn_selector()
+        self.assertEqual(selector.id, self.TRANSACTION_ID)
+
+    def test_begin_already_begun(self):
+        session = _Session()
+        transaction = self._makeOne(session)
+        transaction._id = self.TRANSACTION_ID
+        with self.assertRaises(ValueError):
+            transaction.begin()
+
+    def test_begin_already_rolled_back(self):
+        session = _Session()
+        transaction = self._makeOne(session)
+        transaction._rolled_back = True
+        with self.assertRaises(ValueError):
+            transaction.begin()
+
+    def test_begin_already_committed(self):
+        session = _Session()
+        transaction = self._makeOne(session)
+        transaction.committed = object()
+        with self.assertRaises(ValueError):
+            transaction.begin()
+
+    def test_begin_w_gax_error(self):
+        from google.gax.errors import GaxError
+        database = _Database()
+        api = database.spanner_api = _FauxSpannerAPI(
+            _random_gax_error=True)
+        session = _Session(database)
+        transaction = self._makeOne(session)
+
+        with self.assertRaises(GaxError):
transaction.begin() + + session_id, txn_options, options = api._begun + self.assertEqual(session_id, session.name) + self.assertTrue(txn_options.HasField('read_write')) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_begin_ok(self): + from google.cloud.proto.spanner.v1.transaction_pb2 import ( + Transaction as TransactionPB) + transaction_pb = TransactionPB(id=self.TRANSACTION_ID) + database = _Database() + api = database.spanner_api = _FauxSpannerAPI( + _begin_transaction_response=transaction_pb) + session = _Session(database) + transaction = self._makeOne(session) + + txn_id = transaction.begin() + + self.assertEqual(txn_id, self.TRANSACTION_ID) + self.assertEqual(transaction._id, self.TRANSACTION_ID) + + session_id, txn_options, options = api._begun + self.assertEqual(session_id, session.name) + self.assertTrue(txn_options.HasField('read_write')) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_rollback_not_begun(self): + session = _Session() + transaction = self._makeOne(session) + with self.assertRaises(ValueError): + transaction.rollback() + + def test_rollback_already_committed(self): + session = _Session() + transaction = self._makeOne(session) + transaction._id = self.TRANSACTION_ID + transaction.committed = object() + with self.assertRaises(ValueError): + transaction.rollback() + + def test_rollback_already_rolled_back(self): + session = _Session() + transaction = self._makeOne(session) + transaction._id = self.TRANSACTION_ID + transaction._rolled_back = True + with self.assertRaises(ValueError): + transaction.rollback() + + def test_rollback_w_gax_error(self): + from google.gax.errors import GaxError + database = _Database() + api = database.spanner_api = _FauxSpannerAPI( + _random_gax_error=True) + session = _Session(database) + transaction = self._makeOne(session) + transaction._id = self.TRANSACTION_ID + transaction.insert(TABLE_NAME, COLUMNS, VALUES) + + with self.assertRaises(GaxError): + transaction.rollback() + + self.assertFalse(transaction._rolled_back) + + session_id, txn_id, options = api._rolled_back + self.assertEqual(session_id, session.name) + self.assertEqual(txn_id, self.TRANSACTION_ID) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_rollback_ok(self): + from google.protobuf.empty_pb2 import Empty + empty_pb = Empty() + database = _Database() + api = database.spanner_api = _FauxSpannerAPI( + _rollback_response=empty_pb) + session = _Session(database) + transaction = self._makeOne(session) + transaction._id = self.TRANSACTION_ID + transaction.replace(TABLE_NAME, COLUMNS, VALUES) + + transaction.rollback() + + self.assertTrue(transaction._rolled_back) + + session_id, txn_id, options = api._rolled_back + self.assertEqual(session_id, session.name) + self.assertEqual(txn_id, self.TRANSACTION_ID) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + def test_commit_not_begun(self): + session = _Session() + transaction = self._makeOne(session) + with self.assertRaises(ValueError): + transaction.commit() + + def test_commit_already_committed(self): + session = _Session() + transaction = self._makeOne(session) + transaction._id = self.TRANSACTION_ID + transaction.committed = object() + with self.assertRaises(ValueError): + transaction.commit() + + def test_commit_already_rolled_back(self): + session = _Session() + transaction = 
self._makeOne(session)
+        transaction._id = self.TRANSACTION_ID
+        transaction._rolled_back = True
+        with self.assertRaises(ValueError):
+            transaction.commit()
+
+    def test_commit_no_mutations(self):
+        session = _Session()
+        transaction = self._makeOne(session)
+        transaction._id = self.TRANSACTION_ID
+        with self.assertRaises(ValueError):
+            transaction.commit()
+
+    def test_commit_w_gax_error(self):
+        from google.gax.errors import GaxError
+        database = _Database()
+        api = database.spanner_api = _FauxSpannerAPI(
+            _random_gax_error=True)
+        session = _Session(database)
+        transaction = self._makeOne(session)
+        transaction._id = self.TRANSACTION_ID
+        transaction.replace(TABLE_NAME, COLUMNS, VALUES)
+
+        with self.assertRaises(GaxError):
+            transaction.commit()
+
+        self.assertIsNone(transaction.committed)
+
+        session_id, mutations, txn_id, options = api._committed
+        self.assertEqual(session_id, session.name)
+        self.assertEqual(txn_id, self.TRANSACTION_ID)
+        self.assertEqual(mutations, transaction._mutations)
+        self.assertEqual(options.kwargs['metadata'],
+                         [('google-cloud-resource-prefix', database.name)])
+
+    def test_commit_ok(self):
+        import datetime
+        from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse
+        from google.cloud.spanner.keyset import KeySet
+        from google.cloud._helpers import UTC
+        from google.cloud._helpers import _datetime_to_pb_timestamp
+        now = datetime.datetime.utcnow().replace(tzinfo=UTC)
+        now_pb = _datetime_to_pb_timestamp(now)
+        keys = [[0], [1], [2]]
+        keyset = KeySet(keys=keys)
+        response = CommitResponse(commit_timestamp=now_pb)
+        database = _Database()
+        api = database.spanner_api = _FauxSpannerAPI(
+            _commit_response=response)
+        session = _Session(database)
+        transaction = self._makeOne(session)
+        transaction._id = self.TRANSACTION_ID
+        transaction.delete(TABLE_NAME, keyset)
+
+        transaction.commit()
+
+        self.assertEqual(transaction.committed, now)
+
+        session_id, mutations, txn_id, options = api._committed
+        self.assertEqual(session_id, session.name)
+        self.assertEqual(txn_id, self.TRANSACTION_ID)
+        self.assertEqual(mutations, transaction._mutations)
+        self.assertEqual(options.kwargs['metadata'],
+                         [('google-cloud-resource-prefix', database.name)])
+
+    def test_context_mgr_success(self):
+        import datetime
+        from google.cloud.proto.spanner.v1.spanner_pb2 import CommitResponse
+        from google.cloud.proto.spanner.v1.transaction_pb2 import (
+            Transaction as TransactionPB)
+        from google.cloud._helpers import UTC
+        from google.cloud._helpers import _datetime_to_pb_timestamp
+        transaction_pb = TransactionPB(id=self.TRANSACTION_ID)
+        now = datetime.datetime.utcnow().replace(tzinfo=UTC)
+        now_pb = _datetime_to_pb_timestamp(now)
+        response = CommitResponse(commit_timestamp=now_pb)
+        database = _Database()
+        api = database.spanner_api = _FauxSpannerAPI(
+            _begin_transaction_response=transaction_pb,
+            _commit_response=response)
+        session = _Session(database)
+        transaction = self._makeOne(session)
+
+        with transaction:
+            transaction.insert(TABLE_NAME, COLUMNS, VALUES)
+
+        self.assertEqual(transaction.committed, now)
+
+        session_id, mutations, txn_id, options = api._committed
+        self.assertEqual(session_id, self.SESSION_NAME)
+        self.assertEqual(txn_id, self.TRANSACTION_ID)
+        self.assertEqual(mutations, transaction._mutations)
+        self.assertEqual(options.kwargs['metadata'],
+                         [('google-cloud-resource-prefix', database.name)])
+
+    def test_context_mgr_failure(self):
+        from google.protobuf.empty_pb2 import Empty
+        empty_pb = Empty()
+        from
google.cloud.proto.spanner.v1.transaction_pb2 import ( + Transaction as TransactionPB) + transaction_pb = TransactionPB(id=self.TRANSACTION_ID) + database = _Database() + api = database.spanner_api = _FauxSpannerAPI( + _begin_transaction_response=transaction_pb, + _rollback_response=empty_pb) + session = _Session(database) + transaction = self._makeOne(session) + + with self.assertRaises(Exception): + with transaction: + transaction.insert(TABLE_NAME, COLUMNS, VALUES) + raise Exception("bail out") + + self.assertEqual(transaction.committed, None) + self.assertTrue(transaction._rolled_back) + self.assertEqual(len(transaction._mutations), 1) + + self.assertEqual(api._committed, None) + + session_id, txn_id, options = api._rolled_back + self.assertEqual(session_id, session.name) + self.assertEqual(txn_id, self.TRANSACTION_ID) + self.assertEqual(options.kwargs['metadata'], + [('google-cloud-resource-prefix', database.name)]) + + +class _Database(object): + name = 'testing' + + +class _Session(object): + + def __init__(self, database=None, name=TestTransaction.SESSION_NAME): + self._database = database + self.name = name + + +class _FauxSpannerAPI(_GAXBaseAPI): + + _committed = None + + def begin_transaction(self, session, options_, options=None): + from google.gax.errors import GaxError + self._begun = (session, options_, options) + if self._random_gax_error: + raise GaxError('error') + return self._begin_transaction_response + + def rollback(self, session, transaction_id, options=None): + from google.gax.errors import GaxError + self._rolled_back = (session, transaction_id, options) + if self._random_gax_error: + raise GaxError('error') + return self._rollback_response + + def commit(self, session, mutations, + transaction_id='', single_use_transaction=None, options=None): + from google.gax.errors import GaxError + assert single_use_transaction is None + self._committed = (session, mutations, transaction_id, options) + if self._random_gax_error: + raise GaxError('error') + return self._commit_response diff --git a/system_tests/attempt_system_tests.py b/system_tests/attempt_system_tests.py index 1c552f88dfa5..747d70b62a34 100644 --- a/system_tests/attempt_system_tests.py +++ b/system_tests/attempt_system_tests.py @@ -68,6 +68,7 @@ 'translate', 'monitoring', 'bigtable', + 'spanner', ) SCRIPTS_DIR = os.path.dirname(__file__) diff --git a/system_tests/run_system_test.py b/system_tests/run_system_test.py index c0dce7c6caaf..2ea9999e9e56 100644 --- a/system_tests/run_system_test.py +++ b/system_tests/run_system_test.py @@ -23,6 +23,7 @@ import logging_ import monitoring import pubsub +import spanner import speech import storage import system_test_utils @@ -31,17 +32,18 @@ TEST_MODULES = { - 'datastore': datastore, - 'speech': speech, - 'vision': vision, - 'storage': storage, - 'pubsub': pubsub, 'bigquery': bigquery, 'bigtable': bigtable, + 'datastore': datastore, 'language': language, 'logging': logging_, 'monitoring': monitoring, + 'pubsub': pubsub, + 'spanner': spanner, + 'speech': speech, + 'storage': storage, 'translate': translate, + 'vision': vision, } @@ -53,7 +55,7 @@ def get_parser(): parser = argparse.ArgumentParser( description='google-cloud test runner against actual project.') parser.add_argument('--package', dest='package', - choices=TEST_MODULES.keys(), + choices=sorted(TEST_MODULES), default='datastore', help='Package to be tested.') parser.add_argument( '--ignore-requirements', diff --git a/system_tests/spanner.py b/system_tests/spanner.py new file mode 100644 index 
000000000000..47d694d717ab --- /dev/null +++ b/system_tests/spanner.py @@ -0,0 +1,445 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator +import os +import unittest + +from google.cloud.proto.spanner.v1.type_pb2 import STRING +from google.cloud.proto.spanner.v1.type_pb2 import Type +from google.cloud.spanner.client import Client +from google.cloud.spanner.pool import BurstyPool +from google.cloud.spanner._fixtures import DDL_STATEMENTS + +from retry import RetryErrors +from retry import RetryInstanceState +from retry import RetryResult +from system_test_utils import unique_resource_id + +IS_TRAVIS = os.getenv('TRAVIS') == 'true' +CREATE_INSTANCE = IS_TRAVIS or os.getenv( + 'GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE') is not None + +if CREATE_INSTANCE: + INSTANCE_ID = 'google-cloud' + unique_resource_id('-') +else: + INSTANCE_ID = os.environ.get('GOOGLE_CLOUD_TESTS_SPANNER_INSTANCE', + 'google-cloud-python-systest') +DATABASE_ID = 'test_database' +EXISTING_INSTANCES = [] + + +class Config(object): + """Run-time configuration to be modified at set-up. + + This is a mutable stand-in to allow test set-up to modify + global state. + """ + CLIENT = None + INSTANCE_CONFIG = None + INSTANCE = None + + +def _retry_on_unavailable(exc): + """Retry only errors whose status code is 'UNAVAILABLE'.""" + from grpc import StatusCode + return exc.code() == StatusCode.UNAVAILABLE + + +def _has_all_ddl(database): + return len(database.ddl_statements) == len(DDL_STATEMENTS) + + +def setUpModule(): + from grpc._channel import _Rendezvous + Config.CLIENT = Client() + retry = RetryErrors(_Rendezvous, error_predicate=_retry_on_unavailable) + + configs = list(retry(Config.CLIENT.list_instance_configs)()) + + if len(configs) < 1: + raise ValueError('List instance configs failed in module set up.') + + Config.INSTANCE_CONFIG = configs[0] + config_name = configs[0].name + + def _list_instances(): + return list(Config.CLIENT.list_instances()) + + instances = retry(_list_instances)() + EXISTING_INSTANCES[:] = instances + + if CREATE_INSTANCE: + Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID, config_name) + created_op = Config.INSTANCE.create() + created_op.result(30) # block until completion + + else: + Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID) + Config.INSTANCE.reload() + + +def tearDownModule(): + if CREATE_INSTANCE: + Config.INSTANCE.delete() + + +class TestInstanceAdminAPI(unittest.TestCase): + + def setUp(self): + self.instances_to_delete = [] + + def tearDown(self): + for instance in self.instances_to_delete: + instance.delete() + + def test_list_instances(self): + instances = list(Config.CLIENT.list_instances()) + # We have added one new instance in `setUpModule`. 
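+        # (It only counts toward the total when CREATE_INSTANCE is set;
+        # either way, every listed instance should be one captured in
+        # EXISTING_INSTANCES or the instance under test.)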
+        if CREATE_INSTANCE:
+            self.assertEqual(len(instances), len(EXISTING_INSTANCES) + 1)
+        for instance in instances:
+            instance_existence = (instance in EXISTING_INSTANCES or
+                                  instance == Config.INSTANCE)
+            self.assertTrue(instance_existence)
+
+    def test_reload_instance(self):
+        # Use same arguments as Config.INSTANCE (created in `setUpModule`)
+        # so we can use reload() on a fresh instance.
+        instance = Config.CLIENT.instance(
+            INSTANCE_ID, Config.INSTANCE_CONFIG.name)
+        # Make sure metadata unset before reloading.
+        instance.display_name = None
+
+        instance.reload()
+        self.assertEqual(instance.display_name, Config.INSTANCE.display_name)
+
+    @unittest.skipUnless(CREATE_INSTANCE, 'Skipping instance creation')
+    def test_create_instance(self):
+        ALT_INSTANCE_ID = 'new' + unique_resource_id('-')
+        instance = Config.CLIENT.instance(
+            ALT_INSTANCE_ID, Config.INSTANCE_CONFIG.name)
+        operation = instance.create()
+        # Make sure this instance gets deleted after the test case.
+        self.instances_to_delete.append(instance)
+
+        # We want to make sure the operation completes.
+        operation.result(30)  # raises on failure / timeout.
+
+        # Create a new instance object and make sure it is the same.
+        instance_alt = Config.CLIENT.instance(
+            ALT_INSTANCE_ID, Config.INSTANCE_CONFIG.name)
+        instance_alt.reload()
+
+        self.assertEqual(instance, instance_alt)
+        self.assertEqual(instance.display_name, instance_alt.display_name)
+
+    def test_update_instance(self):
+        OLD_DISPLAY_NAME = Config.INSTANCE.display_name
+        NEW_DISPLAY_NAME = 'Foo Bar Baz'
+        Config.INSTANCE.display_name = NEW_DISPLAY_NAME
+        operation = Config.INSTANCE.update()
+
+        # We want to make sure the operation completes.
+        operation.result(30)  # raises on failure / timeout.
+
+        # Create a new instance object and reload it.
+        instance_alt = Config.CLIENT.instance(INSTANCE_ID, None)
+        self.assertNotEqual(instance_alt.display_name, NEW_DISPLAY_NAME)
+        instance_alt.reload()
+        self.assertEqual(instance_alt.display_name, NEW_DISPLAY_NAME)
+
+        # Make sure to put the instance back the way it was for the
+        # other test cases.
+        Config.INSTANCE.display_name = OLD_DISPLAY_NAME
+        Config.INSTANCE.update()
+
+
+class TestDatabaseAdminAPI(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        pool = BurstyPool()
+        cls._db = Config.INSTANCE.database(DATABASE_ID, pool=pool)
+        cls._db.create()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls._db.drop()
+
+    def setUp(self):
+        self.to_delete = []
+
+    def tearDown(self):
+        for doomed in self.to_delete:
+            doomed.drop()
+
+    def test_list_databases(self):
+        # Since `Config.INSTANCE` is newly created in `setUpModule`, the
+        # database created in `setUpClass` here will be the only one.
+        databases = list(Config.INSTANCE.list_databases())
+        self.assertEqual(databases, [self._db])
+
+    def test_create_database(self):
+        pool = BurstyPool()
+        temp_db_id = 'temp-db'  # test w/ hyphen
+        temp_db = Config.INSTANCE.database(temp_db_id, pool=pool)
+        operation = temp_db.create()
+        self.to_delete.append(temp_db)
+
+        # We want to make sure the operation completes.
+        operation.result(30)  # raises on failure / timeout.
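+
+        # The freshly created database should now appear alongside the
+        # class-level one when the instance's databases are listed.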
+ + name_attr = operator.attrgetter('name') + expected = sorted([temp_db, self._db], key=name_attr) + + databases = list(Config.INSTANCE.list_databases()) + found = sorted(databases, key=name_attr) + self.assertEqual(found, expected) + + def test_update_database_ddl(self): + pool = BurstyPool() + temp_db_id = 'temp_db' + temp_db = Config.INSTANCE.database(temp_db_id, pool=pool) + create_op = temp_db.create() + self.to_delete.append(temp_db) + + # We want to make sure the operation completes. + create_op.result(60) # raises on failure / timeout. + + operation = temp_db.update_ddl(DDL_STATEMENTS) + + # We want to make sure the operation completes. + operation.result(30) # raises on failure / timeout. + + temp_db.reload() + + self.assertEqual(len(temp_db.ddl_statements), len(DDL_STATEMENTS)) + + +class TestSessionAPI(unittest.TestCase): + TABLE = 'contacts' + COLUMNS = ('contact_id', 'first_name', 'last_name', 'email') + ROW_DATA = ( + (1, u'Phred', u'Phlyntstone', u'phred@example.com'), + (2, u'Bharney', u'Rhubble', u'bharney@example.com'), + (3, u'Wylma', u'Phlyntstone', u'wylma@example.com'), + ) + SQL = 'SELECT * FROM contacts ORDER BY contact_id' + + @classmethod + def setUpClass(cls): + pool = BurstyPool() + cls._db = Config.INSTANCE.database( + DATABASE_ID, ddl_statements=DDL_STATEMENTS, pool=pool) + operation = cls._db.create() + operation.result(30) # raises on failure / timeout. + + @classmethod + def tearDownClass(cls): + cls._db.drop() + + def setUp(self): + self.to_delete = [] + + def tearDown(self): + for doomed in self.to_delete: + doomed.delete() + + def _check_row_data(self, row_data): + self.assertEqual(len(row_data), len(self.ROW_DATA)) + for found, expected in zip(row_data, self.ROW_DATA): + self.assertEqual(len(found), len(expected)) + for f_cell, e_cell in zip(found, expected): + self.assertEqual(f_cell, e_cell) + + def test_session_crud(self): + retry_true = RetryResult(operator.truth) + retry_false = RetryResult(operator.not_) + session = self._db.session() + self.assertFalse(session.exists()) + session.create() + retry_true(session.exists)() + session.delete() + retry_false(session.exists)() + + def test_batch_insert_then_read(self): + from google.cloud.spanner import KeySet + keyset = KeySet(all_=True) + + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + batch = session.batch() + batch.delete(self.TABLE, keyset) + batch.insert(self.TABLE, self.COLUMNS, self.ROW_DATA) + batch.commit() + + snapshot = session.snapshot(read_timestamp=batch.committed) + rows = list(snapshot.read(self.TABLE, self.COLUMNS, keyset)) + self._check_row_data(rows) + + def test_batch_insert_or_update_then_query(self): + + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + with session.batch() as batch: + batch.insert_or_update(self.TABLE, self.COLUMNS, self.ROW_DATA) + + snapshot = session.snapshot(read_timestamp=batch.committed) + rows = list(snapshot.execute_sql(self.SQL)) + self._check_row_data(rows) + + def test_transaction_read_and_insert_then_rollback(self): + from google.cloud.spanner import KeySet + keyset = KeySet(all_=True) + + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + with session.batch() as batch: + batch.delete(self.TABLE, keyset) + + transaction = 
session.transaction()
+        transaction.begin()
+        rows = list(transaction.read(self.TABLE, self.COLUMNS, keyset))
+        self.assertEqual(rows, [])
+
+        transaction.insert(self.TABLE, self.COLUMNS, self.ROW_DATA)
+
+        # Inserted rows can't be read until after commit.
+        rows = list(transaction.read(self.TABLE, self.COLUMNS, keyset))
+        self.assertEqual(rows, [])
+        transaction.rollback()
+
+        rows = list(session.read(self.TABLE, self.COLUMNS, keyset))
+        self.assertEqual(rows, [])
+
+    def test_transaction_read_and_insert_or_update_then_commit(self):
+        from google.cloud.spanner import KeySet
+        keyset = KeySet(all_=True)
+
+        retry = RetryInstanceState(_has_all_ddl)
+        retry(self._db.reload)()
+
+        session = self._db.session()
+        session.create()
+        self.to_delete.append(session)
+
+        with session.batch() as batch:
+            batch.delete(self.TABLE, keyset)
+
+        with session.transaction() as transaction:
+            rows = list(transaction.read(self.TABLE, self.COLUMNS, keyset))
+            self.assertEqual(rows, [])
+
+            transaction.insert_or_update(
+                self.TABLE, self.COLUMNS, self.ROW_DATA)
+
+            # Inserted rows can't be read until after commit.
+            rows = list(transaction.read(self.TABLE, self.COLUMNS, keyset))
+            self.assertEqual(rows, [])
+
+        rows = list(session.read(self.TABLE, self.COLUMNS, keyset))
+        self._check_row_data(rows)
+
+    def _set_up_table(self, row_count):
+        from google.cloud.spanner import KeySet
+
+        def _row_data(max_index):
+            for index in range(max_index):
+                yield [index, 'First%09d' % (index,), 'Last%09d' % (index,),
+                       'test-%09d@example.com' % (index,)]
+
+        keyset = KeySet(all_=True)
+
+        retry = RetryInstanceState(_has_all_ddl)
+        retry(self._db.reload)()
+
+        session = self._db.session()
+        session.create()
+        self.to_delete.append(session)
+
+        with session.transaction() as transaction:
+            transaction.delete(self.TABLE, keyset)
+            transaction.insert(self.TABLE, self.COLUMNS, _row_data(row_count))
+
+        return session, keyset, transaction.committed
+
+    def test_read_w_manual_consume(self):
+        ROW_COUNT = 4000
+        session, keyset, committed = self._set_up_table(ROW_COUNT)
+
+        snapshot = session.snapshot(read_timestamp=committed)
+        streamed = snapshot.read(self.TABLE, self.COLUMNS, keyset)
+
+        retrieved = 0
+        while True:
+            try:
+                streamed.consume_next()
+            except StopIteration:
+                break
+            retrieved += len(streamed.rows)
+            streamed.rows[:] = ()
+
+        self.assertEqual(retrieved, ROW_COUNT)
+        self.assertEqual(streamed._current_row, [])
+        self.assertEqual(streamed._pending_chunk, None)
+
+    def test_execute_sql_w_manual_consume(self):
+        ROW_COUNT = 4000
+        session, _, committed = self._set_up_table(ROW_COUNT)
+
+        snapshot = session.snapshot(read_timestamp=committed)
+        streamed = snapshot.execute_sql(self.SQL)
+
+        retrieved = 0
+        while True:
+            try:
+                streamed.consume_next()
+            except StopIteration:
+                break
+            retrieved += len(streamed.rows)
+            streamed.rows[:] = ()
+
+        self.assertEqual(retrieved, ROW_COUNT)
+        self.assertEqual(streamed._current_row, [])
+        self.assertEqual(streamed._pending_chunk, None)
+
+    def test_execute_sql_w_query_param(self):
+        SQL = 'SELECT * FROM contacts WHERE first_name = @first_name'
+        ROW_COUNT = 10
+        session, _, committed = self._set_up_table(ROW_COUNT)
+
+        snapshot = session.snapshot(read_timestamp=committed)
+        rows = list(snapshot.execute_sql(
+            SQL,
+            params={'first_name': 'First%09d' % (0,)},
+            param_types={'first_name': Type(code=STRING)},
+        ))
+
+        self.assertEqual(len(rows), 1)
diff --git a/tox.ini b/tox.ini
index 043c424d05a1..401b2d07ea07 100644
--- a/tox.ini
+++ b/tox.ini
@@ -20,6 +20,7 @@ deps =
{toxinidir}/translate {toxinidir}/speech {toxinidir}/runtimeconfig + {toxinidir}/spanner mock pytest passenv = @@ -141,6 +142,12 @@ covercmd = --cov-append \ --cov-config {toxinidir}/.coveragerc \ runtimeconfig/unit_tests + py.test --quiet \ + --cov=google.cloud \ + --cov=unit_tests \ + --cov-append \ + --cov-config {toxinidir}/.coveragerc \ + spanner/unit_tests coverage report --show-missing --fail-under=100 [testenv]