Skip to content

Add validations for upload in S3 multipart client #6282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AmazonS3-6522f77.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "Amazon S3",
"contributor": "",
"description": "Add additional validations for multipart upload operations in the Java multipart S3 client."
}
6 changes: 6 additions & 0 deletions bom-internal/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,12 @@
<version>${rxjava.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.reactivex.rxjava3</groupId>
<artifactId>rxjava</artifactId>
<version>${rxjava3.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<artifactId>commons-lang3</artifactId>
<groupId>org.apache.commons</groupId>
Expand Down
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@
<org.eclipse.jdt.version>3.10.0</org.eclipse.jdt.version>
<org.eclipse.text.version>3.5.101</org.eclipse.text.version>
<rxjava.version>2.2.21</rxjava.version>
<rxjava3.version>3.1.5</rxjava3.version>
<commons-codec.verion>1.17.1</commons-codec.verion>
<jmh.version>1.37</jmh.version>
<awscrt.version>0.38.1</awscrt.version>
Expand Down
5 changes: 5 additions & 0 deletions services/s3/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -230,5 +230,10 @@
<artifactId>jimfs</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.reactivex.rxjava3</groupId>
<artifactId>rxjava</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@

package software.amazon.awssdk.services.s3.internal.multipart;

import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.contentLengthMismatchForPart;
import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.partNumMismatch;
import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.JAVA_PROGRESS_LISTENER;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -32,6 +35,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.listener.PublisherListener;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
Expand All @@ -54,10 +58,10 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber<
private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false);
private final AtomicInteger partNumber = new AtomicInteger(1);
private final MultipartUploadHelper multipartUploadHelper;
private final long contentLength;
private final long totalSize;
private final long partSize;
private final int partCount;
private final int numExistingParts;
private final int expectedNumParts;
private final int existingNumParts;
private final String uploadId;
private final Collection<CompletableFuture<CompletedPart>> futures = new ConcurrentLinkedQueue<>();
private final PutObjectRequest putObjectRequest;
Expand All @@ -77,25 +81,21 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber<
KnownContentLengthAsyncRequestBodySubscriber(MpuRequestContext mpuRequestContext,
CompletableFuture<PutObjectResponse> returnFuture,
MultipartUploadHelper multipartUploadHelper) {
this.contentLength = mpuRequestContext.contentLength();
this.totalSize = mpuRequestContext.contentLength();
this.partSize = mpuRequestContext.partSize();
this.partCount = determinePartCount(contentLength, partSize);
this.expectedNumParts = mpuRequestContext.expectedNumParts();
this.putObjectRequest = mpuRequestContext.request().left();
this.returnFuture = returnFuture;
this.uploadId = mpuRequestContext.uploadId();
this.existingParts = mpuRequestContext.existingParts() == null ? new HashMap<>() : mpuRequestContext.existingParts();
this.numExistingParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted());
this.completedParts = new AtomicReferenceArray<>(partCount);
this.existingNumParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted());
this.completedParts = new AtomicReferenceArray<>(expectedNumParts);
this.multipartUploadHelper = multipartUploadHelper;
this.progressListener = putObjectRequest.overrideConfiguration().map(c -> c.executionAttributes()
.getAttribute(JAVA_PROGRESS_LISTENER))
.orElseGet(PublisherListener::noOp);
}

/**
 * Computes how many parts are needed to cover {@code contentLength} bytes when each part holds {@code partSize} bytes.
 *
 * <p>The division is done in floating point and rounded up, so a trailing partial part counts as one part. The result
 * is then narrowed to {@code int}.
 *
 * @param contentLength total number of bytes to split
 * @param partSize size of one part in bytes
 * @return the rounded-up quotient of {@code contentLength / partSize}
 */
private int determinePartCount(long contentLength, long partSize) {
return (int) Math.ceil(contentLength / (double) partSize);
}

public S3ResumeToken pause() {
isPaused = true;

Expand All @@ -119,8 +119,8 @@ public S3ResumeToken pause() {
return S3ResumeToken.builder()
.uploadId(uploadId)
.partSize(partSize)
.totalNumParts((long) partCount)
.numPartsCompleted(numPartsCompleted + numExistingParts)
.totalNumParts((long) expectedNumParts)
.numPartsCompleted(numPartsCompleted + existingNumParts)
.build();
}

Expand All @@ -145,21 +145,32 @@ public void onSubscribe(Subscription s) {

@Override
public void onNext(AsyncRequestBody asyncRequestBody) {
if (isPaused) {
if (isPaused || isDone) {
return;
}

if (existingParts.containsKey(partNumber.get())) {
partNumber.getAndIncrement();
int currentPartNum = partNumber.getAndIncrement();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please help me understand why we previously called containsKey on the result of get(), whereas now we increment first and then do the containsKey check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why we did that earlier, but the reason I changed it is to avoid another AtomicInteger get() call (a micro performance optimization).

if (existingParts.containsKey(currentPartNum)) {
asyncRequestBody.subscribe(new CancelledSubscriber<>());
subscription.request(1);
asyncRequestBody.contentLength().ifPresent(progressListener::subscriberOnNext);
return;
}

Optional<SdkClientException> sdkClientException = validatePart(asyncRequestBody, currentPartNum);
if (sdkClientException.isPresent()) {
multipartUploadHelper.failRequestsElegantly(futures,
sdkClientException.get(),
uploadId,
returnFuture,
putObjectRequest);
subscription.cancel();
return;
}

asyncRequestBodyInFlight.incrementAndGet();
UploadPartRequest uploadRequest = SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest,
partNumber.getAndIncrement(),
currentPartNum,
uploadId);

Consumer<CompletedPart> completedPartConsumer = completedPart -> completedParts.set(completedPart.partNumber() - 1,
Expand All @@ -179,6 +190,39 @@ public void onNext(AsyncRequestBody asyncRequestBody) {
subscription.request(1);
}

/**
 * Checks a part emitted by the upstream publisher before it is uploaded.
 *
 * <p>The checks run in this order: the part must report a content length, its part number must not exceed
 * {@code expectedNumParts}, the part numbered {@code expectedNumParts} must pass
 * {@link #validateLastPartSize(Long)}, and every earlier part must be exactly {@code partSize} bytes.
 *
 * @param asyncRequestBody the part body whose reported content length is inspected
 * @param currentPartNum the part number assigned to this body (numbering starts at 1)
 * @return the exception describing the first failed check, or an empty {@code Optional} when the part is valid
 */
private Optional<SdkClientException> validatePart(AsyncRequestBody asyncRequestBody, int currentPartNum) {
    // Read the content length once so the presence check and the value come from the same Optional.
    Optional<Long> reportedLength = asyncRequestBody.contentLength();
    if (!reportedLength.isPresent()) {
        return Optional.of(MultipartUploadHelper.contentLengthMissingForPart(currentPartNum));
    }

    // Unbox up front so the size comparison below is a plain primitive comparison.
    long currentPartSize = reportedLength.get();

    if (currentPartNum > expectedNumParts) {
        return Optional.of(partNumMismatch(expectedNumParts, currentPartNum));
    }

    // The final part is allowed to be shorter than partSize, so it has its own size rule.
    if (currentPartNum == expectedNumParts) {
        return validateLastPartSize(currentPartSize);
    }

    if (currentPartSize != partSize) {
        return Optional.of(contentLengthMismatchForPart(partSize, currentPartSize));
    }
    return Optional.empty();
}

/**
 * Verifies that the final part has exactly the number of bytes left over after all full-size parts.
 *
 * <p>The required size is {@code totalSize % partSize}, or {@code partSize} when the total divides evenly.
 *
 * @param currentPartSize the content length reported by the final part
 * @return an exception when the size differs from the required size, otherwise an empty {@code Optional}
 */
private Optional<SdkClientException> validateLastPartSize(Long currentPartSize) {
    long leftover = totalSize % partSize;

    long requiredSize;
    if (leftover == 0) {
        // The total is an exact multiple of the part size, so the final part is a full part.
        requiredSize = partSize;
    } else {
        requiredSize = leftover;
    }

    if (currentPartSize == requiredSize) {
        return Optional.empty();
    }

    String message = "Content length of the last part must be equal to the expected last part size. Expected: "
                     + requiredSize
                     + ", Actual: " + currentPartSize;
    return Optional.of(SdkClientException.create(message));
}

/**
 * Decides whether the caller should run the failure handling.
 *
 * <p>The {@code failureActionInitiated} flag is flipped by the first caller regardless of the pause state, so at most
 * one caller can ever get {@code true}.
 *
 * @return {@code true} only for the caller that flipped the flag while the upload is not paused
 */
private boolean shouldFailRequest() {
    boolean firstToFail = failureActionInitiated.compareAndSet(false, true);
    return firstToFail && !isPaused;
}
Expand All @@ -187,6 +231,7 @@ private boolean shouldFailRequest() {
public void onError(Throwable t) {
    log.debug(() -> "Received onError ", t);

    // Only the caller that flips the flag runs the failure handling; every later call returns here.
    boolean failureAlreadyHandled = !failureActionInitiated.compareAndSet(false, true);
    if (failureAlreadyHandled) {
        return;
    }

    // Mark the subscriber as done before failing the request; onNext returns early once isDone is set.
    isDone = true;
    multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest);
}
Expand All @@ -203,6 +248,7 @@ public void onComplete() {
private void completeMultipartUploadIfFinished(int requestsInFlight) {
if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) {
CompletedPart[] parts;

if (existingParts.isEmpty()) {
parts =
IntStream.range(0, completedParts.length())
Expand All @@ -212,15 +258,23 @@ private void completeMultipartUploadIfFinished(int requestsInFlight) {
// List of CompletedParts needs to be in ascending order
parts = mergeCompletedParts();
}

int actualNumParts = partNumber.get() - 1;
if (actualNumParts != expectedNumParts) {
SdkClientException exception = partNumMismatch(expectedNumParts, actualNumParts);
multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest);
return;
}

completeMpuFuture = multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest,
contentLength);
totalSize);
}
}

private CompletedPart[] mergeCompletedParts() {
CompletedPart[] merged = new CompletedPart[partCount];
CompletedPart[] merged = new CompletedPart[expectedNumParts];
int currPart = 1;
while (currPart < partCount + 1) {
while (currPart < expectedNumParts + 1) {
CompletedPart completedPart = existingParts.containsKey(currPart) ? existingParts.get(currPart) :
completedParts.get(currPart - 1);
merged[currPart - 1] = completedPart;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.utils.Pair;
import software.amazon.awssdk.utils.Validate;

@SdkInternalApi
public class MpuRequestContext {
Expand All @@ -32,6 +33,7 @@ public class MpuRequestContext {
private final Long numPartsCompleted;
private final String uploadId;
private final Map<Integer, CompletedPart> existingParts;
private final int expectedNumParts;

protected MpuRequestContext(Builder builder) {
this.request = builder.request;
Expand All @@ -40,6 +42,8 @@ protected MpuRequestContext(Builder builder) {
this.uploadId = builder.uploadId;
this.existingParts = builder.existingParts;
this.numPartsCompleted = builder.numPartsCompleted;
this.expectedNumParts = Validate.paramNotNull(builder.expectedNumParts,
"expectedNumParts");
}

public static Builder builder() {
Expand All @@ -56,9 +60,13 @@ public boolean equals(Object o) {
}
MpuRequestContext that = (MpuRequestContext) o;

return Objects.equals(request, that.request) && Objects.equals(contentLength, that.contentLength)
&& Objects.equals(partSize, that.partSize) && Objects.equals(numPartsCompleted, that.numPartsCompleted)
&& Objects.equals(uploadId, that.uploadId) && Objects.equals(existingParts, that.existingParts);
return expectedNumParts == that.expectedNumParts
&& Objects.equals(request, that.request)
&& Objects.equals(contentLength, that.contentLength)
&& Objects.equals(partSize, that.partSize)
&& Objects.equals(numPartsCompleted, that.numPartsCompleted)
&& Objects.equals(uploadId, that.uploadId)
&& Objects.equals(existingParts, that.existingParts);
}

@Override
Expand All @@ -69,6 +77,7 @@ public int hashCode() {
result = 31 * result + (contentLength != null ? contentLength.hashCode() : 0);
result = 31 * result + (partSize != null ? partSize.hashCode() : 0);
result = 31 * result + (numPartsCompleted != null ? numPartsCompleted.hashCode() : 0);
result = 31 * result + expectedNumParts;
return result;
}

Expand All @@ -92,6 +101,10 @@ public String uploadId() {
return uploadId;
}

/**
 * Returns the {@code expectedNumParts} value this context was built with.
 *
 * @return the value stored by the constructor from the builder
 */
public int expectedNumParts() {
return expectedNumParts;
}

/**
 * Returns the existing-parts map exactly as it was given to the builder.
 *
 * <p>The constructor stores the builder's map without copying or null-checking it, so this can return {@code null}
 * and the returned map is the same instance the builder received.
 *
 * @return the map keyed by {@code Integer} holding {@code CompletedPart} values, possibly {@code null}
 */
public Map<Integer, CompletedPart> existingParts() {
return existingParts;
}
Expand All @@ -103,6 +116,7 @@ public static final class Builder {
private Long numPartsCompleted;
private String uploadId;
private Map<Integer, CompletedPart> existingParts;
private Integer expectedNumParts;

private Builder() {
}
Expand Down Expand Up @@ -137,6 +151,11 @@ public Builder existingParts(Map<Integer, CompletedPart> existingParts) {
return this;
}

/**
 * Sets the {@code expectedNumParts} value for the context being built.
 *
 * <p>NOTE(review): the {@code MpuRequestContext} constructor passes this value through
 * {@code Validate.paramNotNull}, so it presumably must be set before {@code build()} is called — confirm against
 * {@code Validate}.
 *
 * @param expectedNumParts the value to store on this builder
 * @return this builder
 */
public Builder expectedNumParts(Integer expectedNumParts) {
this.expectedNumParts = expectedNumParts;
return this;
}

/**
 * Creates a new {@code MpuRequestContext} from the values currently held by this builder.
 *
 * @return the newly constructed context
 */
public MpuRequestContext build() {
return new MpuRequestContext(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.listener.PublisherListener;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.CompletedPart;
Expand All @@ -47,18 +48,15 @@ public final class MultipartUploadHelper {
private static final Logger log = Logger.loggerFor(MultipartUploadHelper.class);

private final S3AsyncClient s3AsyncClient;
private final long partSizeInBytes;
private final GenericMultipartHelper<PutObjectRequest, PutObjectResponse> genericMultipartHelper;

private final long maxMemoryUsageInBytes;
private final long multipartUploadThresholdInBytes;

public MultipartUploadHelper(S3AsyncClient s3AsyncClient,
long partSizeInBytes,
long multipartUploadThresholdInBytes,
long maxMemoryUsageInBytes) {
this.s3AsyncClient = s3AsyncClient;
this.partSizeInBytes = partSizeInBytes;
this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient,
SdkPojoConversionUtils::toAbortMultipartUploadRequest,
SdkPojoConversionUtils::toPutObjectResponse);
Expand Down Expand Up @@ -123,11 +121,18 @@ void failRequestsElegantly(Collection<CompletableFuture<CompletedPart>> futures,
String uploadId,
CompletableFuture<PutObjectResponse> returnFuture,
PutObjectRequest putObjectRequest) {
genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t);
if (uploadId != null) {
genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest));

try {
genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t);
if (uploadId != null) {
genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest));
}
cancelingOtherOngoingRequests(futures, t);
} catch (Throwable throwable) {
returnFuture.completeExceptionally(SdkClientException.create("Unexpected error occurred while handling the upstream "
+ "exception.", throwable));
}
cancelingOtherOngoingRequests(futures, t);

}

static void cancelingOtherOngoingRequests(Collection<CompletableFuture<CompletedPart>> futures, Throwable t) {
Expand All @@ -152,4 +157,22 @@ void uploadInOneChunk(PutObjectRequest putObjectRequest,
CompletableFutureUtils.forwardExceptionTo(returnFuture, putObjectResponseCompletableFuture);
CompletableFutureUtils.forwardResultTo(putObjectResponseCompletableFuture, returnFuture);
}

/**
 * Builds the exception reported when a part's {@code AsyncRequestBody} does not carry a content length.
 *
 * @param currentPartNum the part number to include in the message
 * @return a new {@code SdkClientException} naming the affected part
 */
static SdkClientException contentLengthMissingForPart(int currentPartNum) {
    String message = "Content length is missing on the AsyncRequestBody for part number " + currentPartNum;
    return SdkClientException.create(message);
}

/**
 * Builds the exception reported when a part's content length does not fit the configured part size.
 *
 * <p>NOTE(review): the message says the length "must not be greater than" the part size, but
 * {@code KnownContentLengthAsyncRequestBodySubscriber.validatePart} raises this for any size that is not equal to the
 * part size, including smaller ones. Confirm against the other callers whether the wording should change.
 *
 * @param expected the part size the caller required
 * @param actual the content length the part reported
 * @return a new {@code SdkClientException} carrying both values
 */
static SdkClientException contentLengthMismatchForPart(long expected, long actual) {
    String template = "Content length must not be greater than part size. Expected: %d, Actual: %d";
    String message = String.format(template, expected, actual);
    return SdkClientException.create(message);
}

/**
 * Builds the exception reported when the number of parts produced differs from the number expected.
 *
 * @param expectedNumParts the number of parts the upload was expected to have
 * @param actualNumParts the number of parts (or the part number) actually observed
 * @return a new {@code SdkClientException} carrying both counts
 */
static SdkClientException partNumMismatch(int expectedNumParts, int actualNumParts) {
    String template = "The number of parts divided is not equal to the expected number of parts. "
                      + "Expected: %d, Actual: %d";
    String message = String.format(template, expectedNumParts, actualNumParts);
    return SdkClientException.create(message);
}
}
Loading
Loading