Skip to content

Commit

Permalink
API, AWS: Add RetryableInputStream and use that in S3InputStream
Browse files Browse the repository at this point in the history
Co-authored-by: Jack Ye <[email protected]>
Co-authored-by: Xiaoxuan Li <[email protected]>
  • Loading branch information
3 people committed Aug 22, 2024
1 parent 257b1d7 commit 28513fc
Show file tree
Hide file tree
Showing 6 changed files with 422 additions and 21 deletions.
34 changes: 22 additions & 12 deletions aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.iceberg.io.FileIOMetricsContext;
import org.apache.iceberg.io.IOUtil;
import org.apache.iceberg.io.RangeReadable;
import org.apache.iceberg.io.RetryableInputStream;
import org.apache.iceberg.io.SeekableInputStream;
import org.apache.iceberg.metrics.Counter;
import org.apache.iceberg.metrics.MetricsContext;
Expand Down Expand Up @@ -92,13 +93,13 @@ public void seek(long newPos) {
public int read() throws IOException {
Preconditions.checkState(!closed, "Cannot read: already closed");
positionStream();

int bytesRead = stream.read();
pos += 1;
next += 1;
readBytes.increment();
readOperations.increment();

return stream.read();
return bytesRead;
}

@Override
Expand Down Expand Up @@ -139,7 +140,11 @@ private InputStream readRange(String range) {

S3RequestUtil.configureEncryption(s3FileIOProperties, requestBuilder);

return s3.getObject(requestBuilder.build(), ResponseTransformer.toInputStream());
stream =
RetryableInputStream.builderFor(
() -> s3.getObject(requestBuilder.build(), ResponseTransformer.toInputStream()))
.build();
return stream;
}

@Override
Expand Down Expand Up @@ -178,18 +183,23 @@ private void positionStream() throws IOException {
}

private void openStream() throws IOException {
GetObjectRequest.Builder requestBuilder =
GetObjectRequest.builder()
.bucket(location.bucket())
.key(location.key())
.range(String.format("bytes=%s-", pos));

S3RequestUtil.configureEncryption(s3FileIOProperties, requestBuilder);

closeStream();

try {
stream = s3.getObject(requestBuilder.build(), ResponseTransformer.toInputStream());
stream =
RetryableInputStream.builderFor(
rangeStart -> {
GetObjectRequest.Builder requestBuilder =
GetObjectRequest.builder()
.bucket(location.bucket())
.key(location.key())
.range(String.format("bytes=%s-", rangeStart));
S3RequestUtil.configureEncryption(s3FileIOProperties, requestBuilder);
return s3.getObject(
requestBuilder.build(), ResponseTransformer.toInputStream());
},
() -> pos)
.build();
} catch (NoSuchKeyException e) {
throw new NotFoundException(e, "Location does not exist: %s", location);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.aws.s3;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;

import java.io.IOException;
import java.io.InputStream;
import java.net.SocketTimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import javax.net.ssl.SSLException;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.core.sync.ResponseTransformer;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
import software.amazon.awssdk.services.s3.model.CreateBucketResponse;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectResponse;

public class TestFuzzyS3InputStream extends TestS3InputStream {

private static final int DATA_SIZE = 100;
private static final int SEEK_SIZE = 4;
private static final int SEEK_NEW_POSITION = 25;

@ParameterizedTest
@MethodSource("retryableExceptions")
public void testReadWithFuzzyStreamRetrySucceed(IOException exception) throws Exception {
testRead(fuzzyStreamClient(new AtomicInteger(3), exception), DATA_SIZE);
}

@ParameterizedTest
@MethodSource("retryableExceptions")
public void testReadWithFuzzyStreamExhaustedRetries(IOException exception) {
assertThatThrownBy(
() -> testRead(fuzzyStreamClient(new AtomicInteger(5), exception), DATA_SIZE))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
}

@ParameterizedTest
@MethodSource("nonRetryableExceptions")
public void testReadWithFuzzyStreamNonRetryableException(IOException exception) {
assertThatThrownBy(
() -> testRead(fuzzyStreamClient(new AtomicInteger(3), exception), DATA_SIZE))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
}

@Override
protected void testRead(S3Client s3, int dataSize) throws Exception {
testRead(s3, DATA_SIZE, 4, SEEK_SIZE, SEEK_NEW_POSITION);
}

private static Stream<Arguments> retryableExceptions() {
return Stream.of(
Arguments.of(
new SocketTimeoutException("socket timeout exception"),
new SSLException("some ssl exception")));
}

private static Stream<Arguments> nonRetryableExceptions() {
return Stream.of(Arguments.of(new IOException("some generic non-retryable IO exception")));
}

private S3ClientWrapper fuzzyStreamClient(AtomicInteger counter, IOException failure) {
S3ClientWrapper fuzzyClient = spy(new S3ClientWrapper(s3Client()));
doAnswer(
invocation ->
new FuzzyResponseInputStream(invocation.callRealMethod(), counter, failure))
.when(fuzzyClient)
.getObject(any(GetObjectRequest.class), any(ResponseTransformer.class));
return fuzzyClient;
}

/** Wrapper for S3 client, used to mock the final class DefaultS3Client */
public static class S3ClientWrapper implements S3Client {

private final S3Client delegate;

public S3ClientWrapper(S3Client delegate) {
this.delegate = delegate;
}

@Override
public String serviceName() {
return delegate.serviceName();
}

@Override
public void close() {
delegate.close();
}

@Override
public <ReturnT> ReturnT getObject(
GetObjectRequest getObjectRequest,
ResponseTransformer<GetObjectResponse, ReturnT> responseTransformer)
throws AwsServiceException, SdkClientException {
return delegate.getObject(getObjectRequest, responseTransformer);
}

@Override
public HeadObjectResponse headObject(HeadObjectRequest headObjectRequest)
throws AwsServiceException, SdkClientException {
return delegate.headObject(headObjectRequest);
}

@Override
public PutObjectResponse putObject(PutObjectRequest putObjectRequest, RequestBody requestBody)
throws AwsServiceException, SdkClientException {
return delegate.putObject(putObjectRequest, requestBody);
}

@Override
public CreateBucketResponse createBucket(CreateBucketRequest createBucketRequest)
throws AwsServiceException, SdkClientException {
return delegate.createBucket(createBucketRequest);
}
}

static class FuzzyResponseInputStream extends InputStream {

private final ResponseInputStream<GetObjectResponse> delegate;
private final AtomicInteger counter;
private final int round;
private final IOException exception;

FuzzyResponseInputStream(
Object invocationResponse, AtomicInteger counter, IOException exception) {
this.delegate = (ResponseInputStream<GetObjectResponse>) invocationResponse;
this.counter = counter;
this.round = counter.get();
this.exception = exception;
}

private void checkCounter() throws IOException {
// for every round of n invocations, only the last call succeeds
if (counter.decrementAndGet() == 0) {
counter.set(round);
} else {
throw exception;
}
}

@Override
public int read() throws IOException {
checkCounter();
return delegate.read();
}

@Override
public int read(byte[] b) throws IOException {
checkCounter();
return delegate.read(b);
}

@Override
public int read(byte[] b, int off, int len) throws IOException {
checkCounter();
return delegate.read(b, off, len);
}

@Override
public long skip(long n) throws IOException {
return delegate.skip(n);
}

@Override
public int available() throws IOException {
return delegate.available();
}

@Override
public void close() throws IOException {
delegate.close();
}

@Override
public synchronized void mark(int readlimit) {
delegate.mark(readlimit);
}

@Override
public synchronized void reset() throws IOException {
delegate.reset();
}

@Override
public boolean markSupported() {
return delegate.markSupported();
}
}
}
38 changes: 29 additions & 9 deletions aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.iceberg.io.IOUtil;
import org.apache.iceberg.io.RangeReadable;
import org.apache.iceberg.io.SeekableInputStream;
import org.apache.iceberg.metrics.MetricsContext;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -54,21 +55,26 @@ public void before() {

@Test
public void testRead() throws Exception {
testRead(s3, 10 * 1024 * 1024);
}

protected void testRead(S3Client s3Client, int dataSize) throws Exception {
testRead(s3Client, 10 * 1024 * 1024, 1024, 1024, 2 * 1024 * 1024);
}

protected void testRead(
S3Client s3Client, int dataSize, int readSize, int seekSize, int seekNewStreamPosition)
throws Exception {
S3URI uri = new S3URI("s3://bucket/path/to/read.dat");
int dataSize = 1024 * 1024 * 10;
byte[] data = randomData(dataSize);

writeS3Data(uri, data);

try (SeekableInputStream in = new S3InputStream(s3, uri)) {
int readSize = 1024;
byte[] actual = new byte[readSize];

try (SeekableInputStream in = new S3InputStream(s3Client, uri)) {
readAndCheck(in, in.getPos(), readSize, data, false);
readAndCheck(in, in.getPos(), readSize, data, true);

// Seek forward in current stream
int seekSize = 1024;
readAndCheck(in, in.getPos() + seekSize, readSize, data, false);
readAndCheck(in, in.getPos() + seekSize, readSize, data, true);

Expand All @@ -77,7 +83,6 @@ public void testRead() throws Exception {
readAndCheck(in, in.getPos(), readSize, data, false);

// Seek with new stream
long seekNewStreamPosition = 2 * 1024 * 1024;
readAndCheck(in, in.getPos() + seekNewStreamPosition, readSize, data, true);
readAndCheck(in, in.getPos() + seekNewStreamPosition, readSize, data, false);

Expand Down Expand Up @@ -111,6 +116,11 @@ private void readAndCheck(

@Test
public void testRangeRead() throws Exception {
testRangeRead(s3, new S3FileIOProperties());
}

protected void testRangeRead(S3Client s3Client, S3FileIOProperties awsProperties)
throws Exception {
S3URI uri = new S3URI("s3://bucket/path/to/range-read.dat");
int dataSize = 1024 * 1024 * 10;
byte[] expected = randomData(dataSize);
Expand All @@ -122,7 +132,8 @@ public void testRangeRead() throws Exception {

writeS3Data(uri, expected);

try (RangeReadable in = new S3InputStream(s3, uri)) {
try (RangeReadable in =
new S3InputStream(s3Client, uri, awsProperties, MetricsContext.nullMetrics())) {
// first 1k
position = 0;
offset = 0;
Expand Down Expand Up @@ -163,12 +174,17 @@ public void testClose() throws Exception {

@Test
public void testSeek() throws Exception {
testSeek(s3, new S3FileIOProperties());
}

protected void testSeek(S3Client s3Client, S3FileIOProperties awsProperties) throws Exception {
S3URI uri = new S3URI("s3://bucket/path/to/seek.dat");
byte[] expected = randomData(1024 * 1024);

writeS3Data(uri, expected);

try (SeekableInputStream in = new S3InputStream(s3, uri)) {
try (SeekableInputStream in =
new S3InputStream(s3Client, uri, awsProperties, MetricsContext.nullMetrics())) {
in.seek(expected.length / 2);
byte[] actual = new byte[expected.length / 2];
IOUtil.readFully(in, actual, 0, expected.length / 2);
Expand Down Expand Up @@ -200,4 +216,8 @@ private void createBucket(String bucketName) {
// don't do anything
}
}

protected S3Client s3Client() {
return s3;
}
}
Loading

0 comments on commit 28513fc

Please sign in to comment.