Skip to content

Commit ead8d6f

Browse files
authored
GH-44363: [C#] Handle Flight data with zero batches (#45315)
### Rationale for this change

See #44363. This improves compatibility with other Flight implementations and means user code works with empty data without needing to treat it as a special case to work around this limitation.

### What changes are included in this PR?

* Adds new async overloads of `FlightClient.StartPut` that immediately send the schema, before any data batches are sent.
* Updates the test server to send the schema on `DoGet` even when there are no data batches.
* Enables the `primitive_no_batches` test case for C# Flight.

### Are these changes tested?

Yes, using a new unit test and with the integration tests.

### Are there any user-facing changes?

Yes. New overloads of the `FlightClient.StartPut` method have been added that are async and accept a `Schema` parameter, and ensure the schema is sent when no data batches are sent.

* GitHub Issue: #44363

Authored-by: Adam Reeve <[email protected]>
Signed-off-by: Curt Hagenlocher <[email protected]>
1 parent ba79a48 commit ead8d6f

File tree

7 files changed

+109
-20
lines changed

7 files changed

+109
-20
lines changed

csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs

+55
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,39 @@ public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Met
9898
flightInfoResult.Dispose);
9999
}
100100

101+
/// <summary>
/// Begins a Flight Put call with no deadline and no cancellation token.
/// </summary>
/// <param name="flightDescriptor">Descriptor for the data to be put</param>
/// <param name="headers">gRPC headers to send with the request</param>
/// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> object used to write data batches and receive responses</returns>
public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers = null)
    => StartPut(flightDescriptor, headers, deadline: null, cancellationToken: CancellationToken.None);
105111

112+
/// <summary>
/// Begins a Flight Put call for data with a known schema, with no deadline and no cancellation token.
/// </summary>
/// <param name="flightDescriptor">Descriptor for the data to be put</param>
/// <param name="schema">The schema of the data</param>
/// <param name="headers">gRPC headers to send with the request</param>
/// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> object used to write data batches and receive responses</returns>
/// <remarks>Using this method rather than a StartPut overload that doesn't accept a schema
/// means that the schema is sent even if no data batches are sent</remarks>
public Task<FlightRecordBatchDuplexStreamingCall> StartPut(FlightDescriptor flightDescriptor, Schema schema, Metadata headers = null)
    => StartPut(flightDescriptor, schema, headers, deadline: null, cancellationToken: CancellationToken.None);
125+
126+
/// <summary>
127+
/// Start a Flight Put request.
128+
/// </summary>
129+
/// <param name="flightDescriptor">Descriptor for the data to be put</param>
130+
/// <param name="headers">gRPC headers to send with the request</param>
131+
/// <param name="deadline">Optional deadline. The request will be cancelled if this deadline is reached.</param>
132+
/// <param name="cancellationToken">Optional token for cancelling the request</param>
133+
/// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> object used to write data batches and receive responses</returns>
106134
public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
107135
{
108136
var channels = _client.DoPut(headers, deadline, cancellationToken);
@@ -117,6 +145,33 @@ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDesc
117145
channels.Dispose);
118146
}
119147

148+
/// <summary>
/// Start a Flight Put request and send the schema before any data batches.
/// </summary>
/// <param name="flightDescriptor">Descriptor for the data to be put</param>
/// <param name="schema">The schema of the data</param>
/// <param name="headers">gRPC headers to send with the request</param>
/// <param name="deadline">Optional deadline. The request will be cancelled if this deadline is reached.</param>
/// <param name="cancellationToken">Optional token for cancelling the request</param>
/// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> object used to write data batches and receive responses</returns>
/// <remarks>Using this method rather than a StartPut overload that doesn't accept a schema
/// means that the schema is sent even if no data batches are sent.
/// If sending the schema fails, the call is disposed before the exception propagates,
/// because the caller never receives an object it could dispose itself.</remarks>
public async Task<FlightRecordBatchDuplexStreamingCall> StartPut(FlightDescriptor flightDescriptor, Schema schema, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
    var channels = _client.DoPut(headers, deadline, cancellationToken);
    var requestStream = new FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor);
    var readStream = new StreamReader<Protocol.PutResult, FlightPutResult>(channels.ResponseStream, putResult => new FlightPutResult(putResult));
    var streamingCall = new FlightRecordBatchDuplexStreamingCall(
        requestStream,
        readStream,
        channels.ResponseHeadersAsync,
        channels.GetStatus,
        channels.GetTrailers,
        channels.Dispose);

    try
    {
        await streamingCall.RequestStream.SetupStream(schema).ConfigureAwait(false);
    }
    catch
    {
        // The call has already been started; release it here since it is never returned to the caller.
        streamingCall.Dispose();
        throw;
    }

    return streamingCall;
}
174+
120175
public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse> Handshake(Metadata headers = null)
121176
{
122177
return Handshake(headers, null, CancellationToken.None);

csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs

+17-4
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,22 @@ private protected FlightRecordBatchStreamWriter(IAsyncStreamWriter<Protocol.Flig
3838
_flightDescriptor = flightDescriptor;
3939
}
4040

41-
private void SetupStream(Schema schema)
41+
/// <summary>
/// Creates the underlying Flight data stream for the given schema and sends the schema.
/// </summary>
/// <remarks>
/// Writing a RecordBatch sets the stream up automatically when that has not happened yet,
/// but calling this method before writing any data allows handling empty streams.
/// </remarks>
/// <param name="schema">The schema of data to be written to this stream</param>
/// <exception cref="InvalidOperationException">The data stream has already been set up</exception>
public async Task SetupStream(Schema schema)
{
    if (_flightDataStream == null)
    {
        var dataStream = new FlightDataStream(_clientStreamWriter, _flightDescriptor, schema);
        _flightDataStream = dataStream;
        await dataStream.SendSchema().ConfigureAwait(false);
        return;
    }

    throw new InvalidOperationException("Flight data stream is already set");
}
4558

4659
/// <summary>
/// Not implemented: both the getter and the setter throw <see cref="NotImplementedException"/>.
/// </summary>
public WriteOptions WriteOptions
{
    get => throw new NotImplementedException();
    set => throw new NotImplementedException();
}
@@ -50,14 +63,14 @@ public Task WriteAsync(RecordBatch message)
5063
return WriteAsync(message, default);
5164
}
5265

53-
public Task WriteAsync(RecordBatch message, ByteString applicationMetadata)
66+
/// <summary>
/// Writes a record batch, together with application metadata, to the Flight data stream.
/// </summary>
/// <remarks>
/// If the data stream has not been set up yet, it is set up first using the schema of <paramref name="message"/>.
/// </remarks>
/// <param name="message">The record batch to write</param>
/// <param name="applicationMetadata">Application metadata to pass along with the batch</param>
public async Task WriteAsync(RecordBatch message, ByteString applicationMetadata)
{
    if (_flightDataStream == null)
    {
        await SetupStream(message.Schema).ConfigureAwait(false);
    }

    // Library code: do not resume on the captured context, matching the SetupStream await above.
    await _flightDataStream.Write(message, applicationMetadata).ConfigureAwait(false);
}
6275

6376
protected virtual void Dispose(bool disposing)

csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public FlightDataStream(IAsyncStreamWriter<Protocol.FlightData> clientStreamWrit
4444
_flightDescriptor = flightDescriptor;
4545
}
4646

47-
private async Task SendSchema()
47+
public async Task SendSchema()
4848
{
4949
_currentFlightData = new Protocol.FlightData();
5050

csharp/test/Apache.Arrow.Flight.IntegrationTest/Scenarios/JsonTestScenario.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public async Task RunClient(int serverPort)
7676
var batches = jsonFile.Batches.Select(batch => batch.ToArrow(schema, dictionaries)).ToArray();
7777

7878
// 1. Put the data to the server.
79-
await UploadBatches(client, descriptor, batches).ConfigureAwait(false);
79+
await UploadBatches(client, descriptor, schema, batches).ConfigureAwait(false);
8080

8181
// 2. Get the ticket for the data.
8282
var info = await client.GetInfo(descriptor).ConfigureAwait(false);
@@ -112,9 +112,10 @@ public async Task RunClient(int serverPort)
112112
}
113113
}
114114

115-
private static async Task UploadBatches(FlightClient client, FlightDescriptor descriptor, RecordBatch[] batches)
115+
private static async Task UploadBatches(
116+
FlightClient client, FlightDescriptor descriptor, Schema schema, RecordBatch[] batches)
116117
{
117-
using var putCall = client.StartPut(descriptor);
118+
using var putCall = await client.StartPut(descriptor, schema);
118119
using var writer = putCall.RequestStream;
119120

120121
try

csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStr
5151

5252
if(_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder))
5353
{
54+
await responseStream.SetupStream(flightHolder.GetFlightInfo().Schema);
55+
5456
var batches = flightHolder.GetRecordBatches();
5557

56-
5758
foreach(var batch in batches)
5859
{
5960
await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata);

csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs

+29-7
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ private RecordBatch CreateTestBatch(int startValue, int length)
5757
return batchBuilder.Build();
5858
}
5959

60+
/// <summary>
/// Asserts that the flight store holds a flight for the descriptor and returns that flight's schema.
/// </summary>
private Schema GetStoreSchema(FlightDescriptor flightDescriptor)
{
    Assert.Contains(flightDescriptor, (IReadOnlyDictionary<FlightDescriptor, FlightHolder>)_flightStore.Flights);

    return _flightStore.Flights[flightDescriptor].GetFlightInfo().Schema;
}
6067

6168
private IEnumerable<RecordBatchWithMetadata> GetStoreBatch(FlightDescriptor flightDescriptor)
6269
{
@@ -88,7 +95,7 @@ public async Task TestPutSingleRecordBatch()
8895
var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test");
8996
var expectedBatch = CreateTestBatch(0, 100);
9097

91-
var putStream = _flightClient.StartPut(flightDescriptor);
98+
var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch.Schema);
9299
await putStream.RequestStream.WriteAsync(expectedBatch);
93100
await putStream.RequestStream.CompleteAsync();
94101
var putResults = await putStream.ResponseStream.ToListAsync();
@@ -108,7 +115,7 @@ public async Task TestPutTwoRecordBatches()
108115
var expectedBatch1 = CreateTestBatch(0, 100);
109116
var expectedBatch2 = CreateTestBatch(0, 100);
110117

111-
var putStream = _flightClient.StartPut(flightDescriptor);
118+
var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch1.Schema);
112119
await putStream.RequestStream.WriteAsync(expectedBatch1);
113120
await putStream.RequestStream.WriteAsync(expectedBatch2);
114121
await putStream.RequestStream.CompleteAsync();
@@ -123,6 +130,23 @@ public async Task TestPutTwoRecordBatches()
123130
ArrowReaderVerifier.CompareBatches(expectedBatch2, actualBatches[1].RecordBatch);
124131
}
125132

133+
[Fact]
public async Task TestPutZeroRecordBatches()
{
    // Only the schema of this batch is used; no batch is ever written to the stream.
    var expectedSchema = CreateTestBatch(0, 1).Schema;
    var descriptor = FlightDescriptor.CreatePathDescriptor("test");

    var call = await _flightClient.StartPut(descriptor, expectedSchema);
    await call.RequestStream.CompleteAsync();
    var results = await call.ResponseStream.ToListAsync();
    Assert.Empty(results);

    // The schema must be stored even though no batches were sent.
    var storedSchema = GetStoreSchema(descriptor);
    SchemaComparer.Compare(expectedSchema, storedSchema);
}
149+
126150
[Fact]
127151
public async Task TestGetRecordBatchWithDelayedSchema()
128152
{
@@ -230,7 +254,7 @@ public async Task TestPutWithMetadata()
230254
var expectedBatch = CreateTestBatch(0, 100);
231255
var expectedMetadata = ByteString.CopyFromUtf8("test metadata");
232256

233-
var putStream = _flightClient.StartPut(flightDescriptor);
257+
var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch.Schema);
234258
await putStream.RequestStream.WriteAsync(expectedBatch, expectedMetadata);
235259
await putStream.RequestStream.CompleteAsync();
236260
var putResults = await putStream.ResponseStream.ToListAsync();
@@ -471,8 +495,7 @@ public async Task EnsureCallRaisesDeadlineExceeded()
471495
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
472496
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
473497

474-
var putStream = _flightClient.StartPut(flightDescriptor, null, deadline);
475-
exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
498+
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.StartPut(flightDescriptor, batch.Schema, null, deadline));
476499
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
477500

478501
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, deadline));
@@ -514,8 +537,7 @@ public async Task EnsureCallRaisesRequestCancelled()
514537
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
515538
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
516539

517-
var putStream = _flightClient.StartPut(flightDescriptor, null, null, cts.Token);
518-
exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
540+
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.StartPut(flightDescriptor, batch.Schema, null, null, cts.Token));
519541
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
520542

521543
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, null, cts.Token));

dev/archery/archery/integration/datagen.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1890,10 +1890,7 @@ def _temp_path():
18901890
return
18911891

18921892
file_objs = [
1893-
generate_primitive_case([], name='primitive_no_batches')
1894-
# TODO(https://github.com/apache/arrow/issues/44363)
1895-
.skip_format(SKIP_FLIGHT, 'C#'),
1896-
1893+
generate_primitive_case([], name='primitive_no_batches'),
18971894
generate_primitive_case([17, 20], name='primitive'),
18981895
generate_primitive_case([0, 0, 0], name='primitive_zerolength'),
18991896

0 commit comments

Comments
 (0)