diff --git a/CHANGES.md b/CHANGES.md index 50d04b6f0e45..252d34aa9af6 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -72,7 +72,7 @@ ## New Features / Improvements -* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Support configuring Firestore database on ReadFn transforms (Java) ([#36904](https://github.com/apache/beam/issues/36904)). ## Breaking Changes diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreStatefulComponentFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreStatefulComponentFactory.java index 390e102b6010..fd124cb9236f 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreStatefulComponentFactory.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreStatefulComponentFactory.java @@ -63,12 +63,10 @@ private FirestoreStatefulComponentFactory() {} *

The instance returned by this method is expected to bind to the lifecycle of a bundle. * * @param options The instance of options to read from + * @param configuredProjectId The project to target, if null, falls back to value in options. + * @param configuredDatabaseId The database to target, if null, falls back to value in options. * @return a new {@link FirestoreStub} pre-configured with values from the provided options */ - FirestoreStub getFirestoreStub(PipelineOptions options) { - return getFirestoreStub(options, null, null); - } - FirestoreStub getFirestoreStub( PipelineOptions options, @Nullable String configuredProjectId, diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java index 446d097a8ed8..3f22e636e8ab 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java @@ -595,8 +595,11 @@ private ListCollectionIds( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } @Override @@ -613,7 +616,8 @@ public PCollection expand(PCollection input) { @Override public Builder toBuilder() { - return new Builder(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + return new Builder( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } /** @@ -653,8 +657,16 @@ private Builder( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, + firestoreStatefulComponentFactory, + rpcQosOptions, + readTime, + projectId, + databaseId); } @Override @@ -667,9 +679,16 @@ ListCollectionIds buildSafe( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { return new ListCollectionIds( - clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + clock, + firestoreStatefulComponentFactory, + rpcQosOptions, + readTime, + projectId, + databaseId); } } } @@ -710,8 +729,11 @@ private ListDocuments( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } @Override @@ -728,7 +750,8 @@ public PCollection expand(PCollection input) { @Override public Builder toBuilder() { - return new Builder(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + return new Builder( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } /** @@ -768,8 +791,16 @@ private Builder( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, + firestoreStatefulComponentFactory, + rpcQosOptions, + readTime, + projectId, + databaseId); } @Override @@ -782,8 +813,16 @@ ListDocuments buildSafe( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - return new ListDocuments(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + return new ListDocuments( + clock, + firestoreStatefulComponentFactory, + rpcQosOptions, + readTime, + projectId, + databaseId); } } } @@ -824,8 +863,11 @@ private RunQuery( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } @Override @@ -841,7 +883,8 @@ public PCollection expand(PCollection input) @Override public Builder toBuilder() { - return new Builder(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + return new Builder( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } /** @@ -881,8 +924,16 @@ private Builder( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, + firestoreStatefulComponentFactory, + rpcQosOptions, + readTime, + projectId, + databaseId); } @Override @@ -895,8 +946,16 @@ RunQuery buildSafe( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - return new RunQuery(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + return new RunQuery( + clock, + firestoreStatefulComponentFactory, + rpcQosOptions, + readTime, + projectId, + databaseId); } } } @@ -937,8 +996,11 @@ private BatchGetDocuments( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } @Override @@ -955,7 +1017,8 @@ public PCollection expand( @Override public Builder toBuilder() { - return new Builder(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + return new Builder( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } /** @@ -995,8 +1058,16 @@ public Builder( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, + firestoreStatefulComponentFactory, + rpcQosOptions, + readTime, + projectId, + databaseId); } @Override @@ -1009,9 +1080,16 @@ BatchGetDocuments buildSafe( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { return new BatchGetDocuments( - clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + clock, + firestoreStatefulComponentFactory, + rpcQosOptions, + readTime, + projectId, + databaseId); } } } @@ -1061,8 +1139,11 @@ private PartitionQuery( FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, boolean nameOnlyQuery, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); this.nameOnlyQuery = nameOnlyQuery; } @@ -1106,7 +1187,13 @@ public RunQueryRequest apply(RunQueryRequest input) { @Override public Builder toBuilder() { return new Builder( - clock, firestoreStatefulComponentFactory, rpcQosOptions, nameOnlyQuery, readTime); + clock, + firestoreStatefulComponentFactory, + rpcQosOptions, + nameOnlyQuery, + readTime, + projectId, + databaseId); } /** @@ -1149,8 +1236,16 @@ public Builder( FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, boolean nameOnlyQuery, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, + firestoreStatefulComponentFactory, + rpcQosOptions, + readTime, + projectId, + databaseId); this.nameOnlyQuery = nameOnlyQuery; } @@ -1175,9 +1270,17 @@ PartitionQuery buildSafe( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { return new PartitionQuery( - clock, firestoreStatefulComponentFactory, rpcQosOptions, nameOnlyQuery, readTime); + clock, + firestoreStatefulComponentFactory, + rpcQosOptions, + nameOnlyQuery, + readTime, + projectId, + databaseId); } } @@ -1365,18 +1468,13 @@ public static final class BatchWriteWithSummary BatchWriteWithSummary, BatchWriteWithSummary.Builder> { - private final @Nullable String projectId; - private final @Nullable String databaseId; - public BatchWriteWithSummary( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, @Nullable String projectId, @Nullable String databaseId) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions); - this.projectId = projectId; - this.databaseId = databaseId; + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId); } @Override @@ -1396,7 +1494,8 @@ public PCollection expand( @Override public Builder toBuilder() { - return new Builder(clock, firestoreStatefulComponentFactory, rpcQosOptions); + return new Builder( + clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId); } /** @@ -1429,9 +1528,6 @@ public static final class Builder BatchWriteWithSummary, BatchWriteWithSummary.Builder> { - private @Nullable String projectId; - private @Nullable String databaseId; - private Builder() { super(); } @@ -1439,39 +1535,15 @@ private Builder() { private Builder( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, - RpcQosOptions rpcQosOptions) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions); - } - - /** Set the GCP project ID to be used by the Firestore client. */ - private Builder setProjectId(@Nullable String projectId) { - this.projectId = projectId; - return this; - } - - /** Set the Firestore database ID (e.g., "(default)"). */ - private Builder setDatabaseId(@Nullable String databaseId) { - this.databaseId = databaseId; - return this; - } - - @VisibleForTesting - @Nullable - String getProjectId() { - return this.projectId; - } - - @VisibleForTesting - @Nullable - String getDatabaseId() { - return this.databaseId; + RpcQosOptions rpcQosOptions, + @Nullable String projectId, + @Nullable String databaseId) { + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId); } public BatchWriteWithDeadLetterQueue.Builder withDeadLetterQueue() { return new BatchWriteWithDeadLetterQueue.Builder( - clock, firestoreStatefulComponentFactory, rpcQosOptions) - .setProjectId(projectId) - .setDatabaseId(databaseId); + clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId); } @Override @@ -1530,18 +1602,13 @@ public static final class BatchWriteWithDeadLetterQueue BatchWriteWithDeadLetterQueue, BatchWriteWithDeadLetterQueue.Builder> { - private final @Nullable String projectId; - private final @Nullable String databaseId; - private BatchWriteWithDeadLetterQueue( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, @Nullable String projectId, @Nullable String databaseId) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions); - this.projectId = projectId; - this.databaseId = databaseId; + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId); } @Override @@ -1560,7 +1627,8 @@ public PCollection expand(PCollection { - private @Nullable String projectId; - private @Nullable String databaseId; - private Builder() { super(); } - private Builder setProjectId(@Nullable String projectId) { - this.projectId = projectId; - return this; - } - - private Builder setDatabaseId(@Nullable String databaseId) { - this.databaseId = databaseId; - return this; - } - - @VisibleForTesting - @Nullable - String getProjectId() { - return this.projectId; - } - - @VisibleForTesting - @Nullable - String getDatabaseId() { - return this.databaseId; - } - private Builder( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, - RpcQosOptions rpcQosOptions) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions); + RpcQosOptions rpcQosOptions, + @Nullable String projectId, + @Nullable String databaseId) { + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId); } @Override @@ -1790,14 +1835,20 @@ private abstract static class Transform< final JodaClock clock; final FirestoreStatefulComponentFactory firestoreStatefulComponentFactory; final RpcQosOptions rpcQosOptions; + final @Nullable String projectId; + final @Nullable String databaseId; Transform( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, - RpcQosOptions rpcQosOptions) { + RpcQosOptions rpcQosOptions, + @Nullable String projectId, + @Nullable String databaseId) { this.clock = clock; this.firestoreStatefulComponentFactory = firestoreStatefulComponentFactory; this.rpcQosOptions = rpcQosOptions; + this.projectId = projectId; + this.databaseId = databaseId; } @Override @@ -1838,20 +1889,28 @@ abstract static class Builder< JodaClock clock; FirestoreStatefulComponentFactory firestoreStatefulComponentFactory; RpcQosOptions rpcQosOptions; + @Nullable String projectId; + @Nullable String databaseId; Builder() { clock = JodaClock.DEFAULT; firestoreStatefulComponentFactory = FirestoreStatefulComponentFactory.INSTANCE; rpcQosOptions = RpcQosOptions.defaultOptions(); + projectId = null; + databaseId = null; } private Builder( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, - RpcQosOptions rpcQosOptions) { + RpcQosOptions rpcQosOptions, + @Nullable String projectId, + @Nullable String databaseId) { this.clock = clock; this.firestoreStatefulComponentFactory = firestoreStatefulComponentFactory; this.rpcQosOptions = rpcQosOptions; + this.projectId = projectId; + this.databaseId = databaseId; } /** @@ -1934,6 +1993,28 @@ public final BldrT withRpcQosOptions(RpcQosOptions rpcQosOptions) { this.rpcQosOptions = rpcQosOptions; return self(); } + + public final BldrT setProjectId(@Nullable String projectId) { + this.projectId = projectId; + return self(); + } + + public final BldrT setDatabaseId(@Nullable String databaseId) { + this.databaseId = databaseId; + return self(); + } + + @VisibleForTesting + @Nullable + String getProjectId() { + return this.projectId; + } + + @VisibleForTesting + @Nullable + String getDatabaseId() { + return this.databaseId; + } } } @@ -1950,8 +2031,10 @@ private abstract static class ReadTransform< JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId); this.readTime = readTime; } @@ -1975,8 +2058,10 @@ private Builder( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId); this.readTime = readTime; } @@ -1986,7 +2071,9 @@ final TrfmT genericBuild() { requireNonNull(clock, "clock must be non null"), requireNonNull(firestoreStatefulComponentFactory, "firestoreFactory must be non null"), requireNonNull(rpcQosOptions, "rpcQosOptions must be non null"), - readTime); + readTime, + projectId, + databaseId); } @Override @@ -2001,12 +2088,24 @@ abstract TrfmT buildSafe( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId); public final BldrT withReadTime(@Nullable Instant readTime) { this.readTime = readTime; return self(); } + + public final BldrT withProjectId(@Nullable String projectId) { + this.projectId = projectId; + return self(); + } + + public final BldrT withDatabaseId(@Nullable String databaseId) { + this.databaseId = databaseId; + return self(); + } } } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java index 51e5efa380e8..84e1cb1be0ac 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java @@ -100,6 +100,17 @@ static final class RunQueryFn super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); } + RunQueryFn( + JodaClock clock, + FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, + RpcQosOptions rpcQosOptions, + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); + } + @Override public Context getRpcAttemptContext() { return FirestoreV1RpcAttemptContexts.V1FnRpcAttemptContext.RunQuery; @@ -167,7 +178,7 @@ public PartitionQueryFn( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, null); + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, null, null, null); } public PartitionQueryFn( @@ -175,7 +186,18 @@ public PartitionQueryFn( FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, null, null); + } + + public PartitionQueryFn( + JodaClock clock, + FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, + RpcQosOptions rpcQosOptions, + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } @Override @@ -266,7 +288,7 @@ static final class ListDocumentsFn JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, null); + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, null, null, null); } ListDocumentsFn( @@ -274,7 +296,18 @@ static final class ListDocumentsFn FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, null, null); + } + + ListDocumentsFn( + JodaClock clock, + FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, + RpcQosOptions rpcQosOptions, + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } @Override @@ -320,7 +353,7 @@ static final class ListCollectionIdsFn JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, null); + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, null, null, null); } ListCollectionIdsFn( @@ -328,7 +361,18 @@ static final class ListCollectionIdsFn FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, null, null); + } + + ListCollectionIdsFn( + JodaClock clock, + FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, + RpcQosOptions rpcQosOptions, + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } @Override @@ -383,6 +427,17 @@ static final class BatchGetDocumentsFn super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); } + BatchGetDocumentsFn( + JodaClock clock, + FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, + RpcQosOptions rpcQosOptions, + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); + } + @Override public Context getRpcAttemptContext() { return FirestoreV1RpcAttemptContexts.V1FnRpcAttemptContext.BatchGetDocuments; @@ -458,7 +513,7 @@ protected StreamingFirestoreV1ReadFn( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, null); + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, null, null, null); } protected StreamingFirestoreV1ReadFn( @@ -466,7 +521,18 @@ protected StreamingFirestoreV1ReadFn( FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, null, null); + } + + protected StreamingFirestoreV1ReadFn( + JodaClock clock, + FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, + RpcQosOptions rpcQosOptions, + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } protected abstract ServerStreamingCallable getCallable(FirestoreStub firestoreStub); @@ -539,8 +605,11 @@ protected PaginatedFirestoreV1ReadFn( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { - super(clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime); + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { + super( + clock, firestoreStatefulComponentFactory, rpcQosOptions, readTime, projectId, databaseId); } protected abstract UnaryCallable getCallable( @@ -610,6 +679,7 @@ abstract static class BaseFirestoreV1ReadFn protected transient FirestoreStub firestoreStub; protected transient RpcQos rpcQos; protected transient String projectId; + protected transient @Nullable String databaseId; @SuppressWarnings( "initialization.fields.uninitialized") // allow transient fields to be managed by component @@ -618,12 +688,18 @@ protected BaseFirestoreV1ReadFn( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - @Nullable Instant readTime) { + @Nullable Instant readTime, + @Nullable String projectId, + @Nullable String databaseId) { this.clock = requireNonNull(clock, "clock must be non null"); this.firestoreStatefulComponentFactory = requireNonNull(firestoreStatefulComponentFactory, "firestoreFactory must be non null"); this.rpcQosOptions = requireNonNull(rpcQosOptions, "rpcQosOptions must be non null"); this.readTime = readTime; + if (projectId != null) { + this.projectId = projectId; + } + this.databaseId = databaseId; } /** {@inheritDoc} */ @@ -635,7 +711,10 @@ public void setup() { /** {@inheritDoc} */ @Override public final void startBundle(StartBundleContext c) { - String project = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject(); + String project = + this.projectId != null + ? this.projectId + : c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject(); if (project == null) { project = c.getPipelineOptions().as(GcpOptions.class).getProject(); } @@ -643,7 +722,15 @@ public final void startBundle(StartBundleContext c) { requireNonNull( project, "project must be defined on FirestoreOptions or GcpOptions of PipelineOptions"); - firestoreStub = firestoreStatefulComponentFactory.getFirestoreStub(c.getPipelineOptions()); + databaseId = + this.databaseId != null + ? this.databaseId + : c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreDb(); + requireNonNull( + databaseId, "firestoreDb must be defined on FirestoreOptions of PipelineOptions"); + firestoreStub = + firestoreStatefulComponentFactory.getFirestoreStub( + c.getPipelineOptions(), projectId, databaseId); } /** {@inheritDoc} */ @@ -651,6 +738,7 @@ public final void startBundle(StartBundleContext c) { @Override public void finishBundle() throws Exception { projectId = null; + databaseId = null; firestoreStub.close(); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1ReadFnTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1ReadFnTest.java index 0aab59d3aacd..5c28d3fc99ea 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1ReadFnTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1ReadFnTest.java @@ -45,7 +45,7 @@ abstract class BaseFirestoreV1ReadFnTest public final void attemptsExhaustedForRetryableError() throws Exception { BaseFirestoreV1ReadFn fn = getFn(clock, ff, rpcQosOptions); V1RpcFnTestCtx ctx = newCtx(); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); when(ff.getRpcQos(any())).thenReturn(rpcQos); when(rpcQos.newReadAttempt(fn.getRpcAttemptContext())).thenReturn(attempt); ctx.mockRpcToCallable(stub); @@ -79,7 +79,7 @@ public final void attemptsExhaustedForRetryableError() throws Exception { public final void noRequestIsSentIfNotSafeToProceed() throws Exception { BaseFirestoreV1ReadFn fn = getFn(clock, ff, rpcQosOptions); V1RpcFnTestCtx ctx = newCtx(); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); when(ff.getRpcQos(any())).thenReturn(rpcQos); when(rpcQos.newReadAttempt(fn.getRpcAttemptContext())).thenReturn(attempt); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchGetDocumentsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchGetDocumentsTest.java index b9c950e92fd5..1dec02ad40a4 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchGetDocumentsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchGetDocumentsTest.java @@ -54,7 +54,12 @@ @RunWith(Parameterized.class) public final class FirestoreV1FnBatchGetDocumentsTest extends BaseFirestoreV1ReadFnTest { - @Parameterized.Parameter public Instant readTime; + + @Parameterized.Parameter(0) + public Instant readTime; + + @Parameterized.Parameter(1) + public boolean setDatabaseOnFn; @Rule public MockitoRule rule = MockitoJUnit.rule(); @@ -65,9 +70,12 @@ public final class FirestoreV1FnBatchGetDocumentsTest @Mock private ServerStream responseStream2; @Mock private ServerStream responseStream3; - @Parameterized.Parameters(name = "readTime = {0}") - public static Collection data() { - return Arrays.asList(null, Instant.now()); + @Parameterized.Parameters(name = "readTime = {0}, setDatabaseOnFn = {1}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + {null, false}, {null, true}, {Instant.now(), false}, {Instant.now(), true} + }); } private BatchGetDocumentsRequest withReadTime( @@ -98,7 +106,7 @@ public void endToEnd() throws Exception { when(stub.batchGetDocumentsCallable()).thenReturn(callable); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); when(ff.getRpcQos(any())) .thenReturn(FirestoreStatefulComponentFactory.INSTANCE.getRpcQos(rpcQosOptions)); @@ -108,7 +116,7 @@ public void endToEnd() throws Exception { when(processContext.element()).thenReturn(request); - runFunction(new BatchGetDocumentsFn(clock, ff, rpcQosOptions, readTime)); + runFunction(getFnWithParameters()); List allValues = responsesCaptor.getAllValues(); assertEquals(responses, allValues); @@ -184,7 +192,7 @@ protected BatchGetDocumentsResponse computeNext() { when(stub.batchGetDocumentsCallable()).thenReturn(callable); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); when(ff.getRpcQos(any())).thenReturn(rpcQos); when(rpcQos.newReadAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); @@ -196,7 +204,7 @@ protected BatchGetDocumentsResponse computeNext() { when(processContext.element()).thenReturn(request1); - BatchGetDocumentsFn fn = new BatchGetDocumentsFn(clock, ff, rpcQosOptions, readTime); + BatchGetDocumentsFn fn = getFnWithParameters(); runFunction(fn); @@ -246,6 +254,14 @@ protected BatchGetDocumentsFn getFn( return new BatchGetDocumentsFn(clock, firestoreStatefulComponentFactory, rpcQosOptions); } + private BatchGetDocumentsFn getFnWithParameters() { + if (setDatabaseOnFn) { + return new BatchGetDocumentsFn(clock, ff, rpcQosOptions, readTime, projectId, "(default)"); + } else { + return new BatchGetDocumentsFn(clock, ff, rpcQosOptions, readTime); + } + } + private static BatchGetDocumentsResponse newFound(int docNumber) { String docName = docName(docNumber); return BatchGetDocumentsResponse.newBuilder() diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnListCollectionIdsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnListCollectionIdsTest.java index eb3cd2692c8e..e99d42427316 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnListCollectionIdsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnListCollectionIdsTest.java @@ -57,6 +57,9 @@ public final class FirestoreV1FnListCollectionIdsTest @Parameter public Instant readTime; + @Parameter(1) + public boolean setDatabaseOnFn; + @Rule public MockitoRule rule = MockitoJUnit.rule(); @Mock private UnaryCallable callable; @@ -65,9 +68,12 @@ public final class FirestoreV1FnListCollectionIdsTest @Mock private ListCollectionIdsPagedResponse pagedResponse2; @Mock private ListCollectionIdsPage page2; - @Parameters(name = "readTime = {0}") - public static Collection data() { - return Arrays.asList(null, Instant.now()); + @Parameters(name = "readTime = {0}, setDatabaseOnFn = {1}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + {null, false}, {null, true}, {Instant.now(), false}, {Instant.now(), true} + }); } private ListCollectionIdsRequest withReadTime(ListCollectionIdsRequest input, Instant readTime) { @@ -104,7 +110,7 @@ public void endToEnd() throws Exception { when(stub.listCollectionIdsPagedCallable()).thenReturn(callable); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); RpcQosOptions options = RpcQosOptions.defaultOptions(); when(ff.getRpcQos(any())) .thenReturn(FirestoreStatefulComponentFactory.INSTANCE.getRpcQos(options)); @@ -116,7 +122,7 @@ public void endToEnd() throws Exception { when(processContext.element()).thenReturn(request1); - ListCollectionIdsFn fn = new ListCollectionIdsFn(clock, ff, options, readTime); + ListCollectionIdsFn fn = getFnWithParameters(); runFunction(fn); @@ -127,7 +133,7 @@ public void endToEnd() throws Exception { @Override public void resumeFromLastReadValue() throws Exception { - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); when(ff.getRpcQos(any())).thenReturn(rpcQos); when(rpcQos.newReadAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); @@ -186,7 +192,7 @@ protected ListCollectionIdsPage computeNext() { when(stub.listCollectionIdsPagedCallable()).thenReturn(callable); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); ArgumentCaptor responses = ArgumentCaptor.forClass(ListCollectionIdsResponse.class); @@ -195,7 +201,7 @@ protected ListCollectionIdsPage computeNext() { when(processContext.element()).thenReturn(request1); - ListCollectionIdsFn fn = new ListCollectionIdsFn(clock, ff, rpcQosOptions, readTime); + ListCollectionIdsFn fn = getFnWithParameters(); runFunction(fn); @@ -238,4 +244,12 @@ protected ListCollectionIdsFn getFn( RpcQosOptions rpcQosOptions) { return new ListCollectionIdsFn(clock, firestoreStatefulComponentFactory, rpcQosOptions); } + + private ListCollectionIdsFn getFnWithParameters() { + if (setDatabaseOnFn) { + return new ListCollectionIdsFn(clock, ff, rpcQosOptions, readTime, projectId, "(default)"); + } else { + return new ListCollectionIdsFn(clock, ff, rpcQosOptions, readTime); + } + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnListDocumentsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnListDocumentsTest.java index 2faa7c3e2f1b..54827f6d6017 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnListDocumentsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnListDocumentsTest.java @@ -47,6 +47,8 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; @@ -58,6 +60,9 @@ public final class FirestoreV1FnListDocumentsTest @Parameterized.Parameter public Instant readTime; + @Parameter(1) + public boolean setDatabaseOnFn; + @Rule public MockitoRule rule = MockitoJUnit.rule(); @Mock private UnaryCallable callable; @@ -66,9 +71,12 @@ public final class FirestoreV1FnListDocumentsTest @Mock private ListDocumentsPagedResponse pagedResponse2; @Mock private ListDocumentsPage page2; - @Parameterized.Parameters(name = "readTime = {0}") - public static Collection data() { - return Arrays.asList(null, Instant.now()); + @Parameters(name = "readTime = {0}, setDatabaseOnFn = {1}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + {null, false}, {null, true}, {Instant.now(), false}, {Instant.now(), true} + }); } private ListDocumentsRequest withReadTime(ListDocumentsRequest request, Instant readTime) { @@ -127,7 +135,7 @@ public void endToEnd() throws Exception { when(stub.listDocumentsPagedCallable()).thenReturn(callable); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); RpcQosOptions options = RpcQosOptions.defaultOptions(); when(ff.getRpcQos(any())) .thenReturn(FirestoreStatefulComponentFactory.INSTANCE.getRpcQos(options)); @@ -139,7 +147,7 @@ public void endToEnd() throws Exception { when(processContext.element()).thenReturn(request1); - ListDocumentsFn fn = new ListDocumentsFn(clock, ff, options, readTime); + ListDocumentsFn fn = getFnWithParameters(); runFunction(fn); @@ -150,7 +158,7 @@ public void endToEnd() throws Exception { @Override public void resumeFromLastReadValue() throws Exception { - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); when(ff.getRpcQos(any())).thenReturn(rpcQos); when(rpcQos.newReadAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); @@ -231,7 +239,7 @@ protected ListDocumentsPage computeNext() { when(stub.listDocumentsPagedCallable()).thenReturn(callable); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); ArgumentCaptor responses = ArgumentCaptor.forClass(ListDocumentsResponse.class); @@ -283,4 +291,12 @@ protected ListDocumentsFn getFn( RpcQosOptions rpcQosOptions) { return new ListDocumentsFn(clock, firestoreStatefulComponentFactory, rpcQosOptions); } + + private ListDocumentsFn getFnWithParameters() { + if (setDatabaseOnFn) { + return new ListDocumentsFn(clock, ff, rpcQosOptions, readTime, projectId, "(default)"); + } else { + return new ListDocumentsFn(clock, ff, rpcQosOptions, readTime); + } + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java index d6c69fbd96b2..20f728bab73a 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java @@ -47,6 +47,8 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; @@ -58,6 +60,9 @@ public final class FirestoreV1FnPartitionQueryTest @Parameterized.Parameter public Instant readTime; + @Parameter(1) + public boolean setDatabaseOnFn; + @Rule public MockitoRule rule = MockitoJUnit.rule(); @Mock private UnaryCallable callable; @@ -66,9 +71,12 @@ public final class FirestoreV1FnPartitionQueryTest @Mock private PartitionQueryPagedResponse pagedResponse2; @Mock private PartitionQueryPage page2; - @Parameterized.Parameters(name = "readTime = {0}") - public static Collection data() { - return Arrays.asList(null, Instant.now()); + @Parameters(name = "readTime = {0}, setDatabaseOnFn = {1}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + {null, false}, {null, true}, {Instant.now(), false}, {Instant.now(), true} + }); } private PartitionQueryRequest withReadTime(PartitionQueryRequest request, Instant readTime) { @@ -101,7 +109,7 @@ public void endToEnd() throws Exception { when(stub.partitionQueryPagedCallable()).thenReturn(callable); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); RpcQosOptions options = RpcQosOptions.defaultOptions(); when(ff.getRpcQos(any())) .thenReturn(FirestoreStatefulComponentFactory.INSTANCE.getRpcQos(options)); @@ -113,7 +121,7 @@ public void endToEnd() throws Exception { when(processContext.element()).thenReturn(request1); - PartitionQueryFn fn = new PartitionQueryFn(clock, ff, options, readTime); + PartitionQueryFn fn = getFnWithParameters(); runFunction(fn); @@ -136,7 +144,7 @@ public void endToEnd_emptyCursors() throws Exception { when(stub.partitionQueryPagedCallable()).thenReturn(callable); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); RpcQosOptions options = RpcQosOptions.defaultOptions(); when(ff.getRpcQos(any())) .thenReturn(FirestoreStatefulComponentFactory.INSTANCE.getRpcQos(options)); @@ -148,7 +156,7 @@ public void endToEnd_emptyCursors() throws Exception { when(processContext.element()).thenReturn(request1); - PartitionQueryFn fn = new PartitionQueryFn(clock, ff, options, readTime); + PartitionQueryFn fn = getFnWithParameters(); runFunction(fn); @@ -159,7 +167,7 @@ public void endToEnd_emptyCursors() throws Exception { @Override public void resumeFromLastReadValue() throws Exception { - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); when(ff.getRpcQos(any())).thenReturn(rpcQos); when(rpcQos.newReadAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); @@ -230,7 +238,7 @@ protected PartitionQueryPage computeNext() { when(stub.partitionQueryPagedCallable()).thenReturn(callable); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); ArgumentCaptor responses = ArgumentCaptor.forClass(PartitionQueryPair.class); @@ -283,4 +291,12 @@ protected PartitionQueryFn getFn( RpcQosOptions rpcQosOptions) { return new PartitionQueryFn(clock, firestoreStatefulComponentFactory, rpcQosOptions); } + + private PartitionQueryFn getFnWithParameters() { + if (setDatabaseOnFn) { + return new PartitionQueryFn(clock, ff, rpcQosOptions, readTime, projectId, "(default)"); + } else { + return new PartitionQueryFn(clock, ff, rpcQosOptions, readTime); + } + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnRunQueryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnRunQueryTest.java index 02e5f9743eaa..78dad6faeaea 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnRunQueryTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnRunQueryTest.java @@ -59,6 +59,8 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; @@ -70,15 +72,21 @@ public final class FirestoreV1FnRunQueryTest @Parameterized.Parameter public Instant readTime; + @Parameter(1) + public boolean setDatabaseOnFn; + @Rule public MockitoRule rule = MockitoJUnit.rule(); @Mock private ServerStreamingCallable callable; @Mock private ServerStream responseStream; @Mock private ServerStream retryResponseStream; - @Parameterized.Parameters(name = "readTime = {0}") - public static Collection data() { - return Arrays.asList(null, Instant.now()); + @Parameters(name = "readTime = {0}, setDatabaseOnFn = {1}") + public static Collection data() { + return Arrays.asList( + new Object[][] { + {null, false}, {null, true}, {Instant.now(), false}, {Instant.now(), true} + }); } private RunQueryRequest withReadTime(RunQueryRequest request, Instant readTime) { @@ -100,7 +108,7 @@ public void endToEnd() throws Exception { when(stub.runQueryCallable()).thenReturn(callable); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); RpcQosOptions options = RpcQosOptions.defaultOptions(); when(ff.getRpcQos(any())) .thenReturn(FirestoreStatefulComponentFactory.INSTANCE.getRpcQos(options)); @@ -112,7 +120,7 @@ public void endToEnd() throws Exception { when(processContext.element()).thenReturn(testData.request); - RunQueryFn fn = new RunQueryFn(clock, ff, options, readTime); + RunQueryFn fn = getFnWithParameters(); runFunction(fn); @@ -242,7 +250,7 @@ protected RunQueryResponse computeNext() { when(stub.runQueryCallable()).thenReturn(callable); - when(ff.getFirestoreStub(any())).thenReturn(stub); + when(ff.getFirestoreStub(any(), any(), any())).thenReturn(stub); when(ff.getRpcQos(any())).thenReturn(rpcQos); when(rpcQos.newReadAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); @@ -302,6 +310,14 @@ protected RunQueryFn getFn( return new RunQueryFn(clock, firestoreStatefulComponentFactory, rpcQosOptions); } + private RunQueryFn getFnWithParameters() { + if (setDatabaseOnFn) { + return new RunQueryFn(clock, ff, rpcQosOptions, readTime, projectId, "(default)"); + } else { + return new RunQueryFn(clock, ff, rpcQosOptions, readTime); + } + } + private static final class TestData { static final FieldReference FILTER_FIELD_PATH = diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java index 14344b105b35..e0776927db0f 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java @@ -163,6 +163,8 @@ public final void listCollections() throws Exception { FirestoreIO.v1() .read() .listCollectionIds() + .withProjectId(project) + .withDatabaseId(databaseId) .withRpcQosOptions(RPC_QOS_OPTIONS) .build()); @@ -178,6 +180,8 @@ public final void listCollections() throws Exception { FirestoreIO.v1() .read() .listCollectionIds() + .withProjectId(project) + .withDatabaseId(databaseId) .withReadTime(readTime) .withRpcQosOptions(RPC_QOS_OPTIONS) .build()); @@ -209,7 +213,13 @@ public final void listDocuments() throws Exception { .apply(Create.of("a")) .apply(getListDocumentsPTransform(testName.getMethodName())) .apply( - FirestoreIO.v1().read().listDocuments().withRpcQosOptions(RPC_QOS_OPTIONS).build()) + FirestoreIO.v1() + .read() + .listDocuments() + .withProjectId(project) + .withDatabaseId(databaseId) + .withRpcQosOptions(RPC_QOS_OPTIONS) + .build()) .apply(ParDo.of(new DocumentToName())); PAssert.that(listDocumentPaths).containsInAnyOrder(allDocumentPaths); @@ -224,6 +234,8 @@ public final void listDocuments() throws Exception { FirestoreIO.v1() .read() .listDocuments() + .withProjectId(project) + .withDatabaseId(databaseId) .withReadTime(readTime) .withRpcQosOptions(RPC_QOS_OPTIONS) .build()) @@ -260,7 +272,14 @@ public final void runQuery() throws Exception { testPipeline .apply(Create.of(collectionId)) .apply(getRunQueryPTransform(testName.getMethodName())) - .apply(FirestoreIO.v1().read().runQuery().withRpcQosOptions(RPC_QOS_OPTIONS).build()) + .apply( + FirestoreIO.v1() + .read() + .runQuery() + .withProjectId(project) + .withDatabaseId(databaseId) + .withRpcQosOptions(RPC_QOS_OPTIONS) + .build()) .apply(ParDo.of(new RunQueryResponseToDocument())) .apply(ParDo.of(new DocumentToName())); @@ -276,6 +295,8 @@ public final void runQuery() throws Exception { FirestoreIO.v1() .read() .runQuery() + .withProjectId(project) + .withDatabaseId(databaseId) .withReadTime(readTime) .withRpcQosOptions(RPC_QOS_OPTIONS) .build()) @@ -318,8 +339,21 @@ public final void partitionQuery() throws Exception { testPipeline .apply(Create.of(collectionGroupId)) .apply(getPartitionQueryPTransform(testName.getMethodName(), partitionCount)) - .apply(FirestoreIO.v1().read().partitionQuery().withNameOnlyQuery().build()) - .apply(FirestoreIO.v1().read().runQuery().build()) + .apply( + FirestoreIO.v1() + .read() + .partitionQuery() + .withProjectId(project) + .withDatabaseId(databaseId) + .withNameOnlyQuery() + .build()) + .apply( + FirestoreIO.v1() + .read() + .runQuery() + .withProjectId(project) + .withDatabaseId(databaseId) + .build()) .apply(ParDo.of(new RunQueryResponseToDocument())) .apply(ParDo.of(new DocumentToName())); @@ -335,10 +369,19 @@ public final void partitionQuery() throws Exception { FirestoreIO.v1() .read() .partitionQuery() + .withProjectId(project) + .withDatabaseId(databaseId) .withReadTime(readTime) .withNameOnlyQuery() .build()) - .apply(FirestoreIO.v1().read().runQuery().withReadTime(readTime).build()) + .apply( + FirestoreIO.v1() + .read() + .runQuery() + .withProjectId(project) + .withDatabaseId(databaseId) + .withReadTime(readTime) + .build()) .apply(ParDo.of(new RunQueryResponseToDocument())) .apply(ParDo.of(new DocumentToName())); @@ -381,6 +424,8 @@ public final void batchGet() throws Exception { FirestoreIO.v1() .read() .batchGetDocuments() + .withProjectId(project) + .withDatabaseId(databaseId) .withRpcQosOptions(RPC_QOS_OPTIONS) .build()) .apply(Filter.by(BatchGetDocumentsResponse::hasFound)) @@ -399,6 +444,8 @@ public final void batchGet() throws Exception { FirestoreIO.v1() .read() .batchGetDocuments() + .withProjectId(project) + .withDatabaseId(databaseId) .withReadTime(readTime) .withRpcQosOptions(RPC_QOS_OPTIONS) .build())