Skip to content

Commit 4f49de1

Browse files
committed
[FLINK-38267][checkpoint] Only call channel state rescaling logic for exchange with channel state to avoid UnsupportedOperationException
1 parent c33e974 commit 4f49de1

File tree

8 files changed

+910
-10
lines changed

8 files changed

+910
-10
lines changed

flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ public InflightDataRescalingDescriptor(
4545
}
4646

4747
public int[] getOldSubtaskIndexes(int gateOrPartitionIndex) {
48-
return gateOrPartitionDescriptors[gateOrPartitionIndex].oldSubtaskIndexes;
48+
return gateOrPartitionDescriptors[gateOrPartitionIndex].getOldSubtaskInstances();
4949
}
5050

5151
public RescaleMappings getChannelMapping(int gateOrPartitionIndex) {
52-
return gateOrPartitionDescriptors[gateOrPartitionIndex].rescaledChannelsMappings;
52+
return gateOrPartitionDescriptors[gateOrPartitionIndex].getRescaleMappings();
5353
}
5454

5555
public boolean isAmbiguous(int gateOrPartitionIndex, int oldSubtaskIndex) {
@@ -112,6 +112,28 @@ public String toString() {
112112
*/
113113
public static class InflightDataGateOrPartitionRescalingDescriptor implements Serializable {
114114

115+
public static final InflightDataGateOrPartitionRescalingDescriptor NO_STATE =
116+
new InflightDataGateOrPartitionRescalingDescriptor(
117+
new int[0],
118+
RescaleMappings.identity(0, 0),
119+
java.util.Collections.emptySet(),
120+
MappingType.IDENTITY) {
121+
122+
private static final long serialVersionUID = 1L;
123+
124+
@Override
125+
public int[] getOldSubtaskInstances() {
126+
throw new UnsupportedOperationException(
127+
"Cannot get old subtasks from a descriptor that represents no state.");
128+
}
129+
130+
@Override
131+
public RescaleMappings getRescaleMappings() {
132+
throw new UnsupportedOperationException(
133+
"Cannot get rescale mappings from a descriptor that represents no state.");
134+
}
135+
};
136+
115137
private static final long serialVersionUID = 1L;
116138

117139
/** Set when several operator instances are merged into one. */
@@ -145,6 +167,14 @@ public InflightDataGateOrPartitionRescalingDescriptor(
145167
this.mappingType = mappingType;
146168
}
147169

170+
public int[] getOldSubtaskInstances() {
171+
return oldSubtaskIndexes;
172+
}
173+
174+
public RescaleMappings getRescaleMappings() {
175+
return rescaledChannelsMappings;
176+
}
177+
148178
public boolean isIdentity() {
149179
return mappingType == MappingType.IDENTITY;
150180
}

flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,9 @@ public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment)
386386
// Parallelism of this vertex changed, distribute ResultSubpartitionStateHandle
387387
// according to output mapping.
388388
for (int partitionIndex = 0; partitionIndex < outputs.size(); partitionIndex++) {
389+
if (!assignment.hasInFlightDataForResultPartition(partitionIndex)) {
390+
continue;
391+
}
389392
final List<List<ResultSubpartitionStateHandle>> partitionState =
390393
outputs.size() == 1
391394
? outputOperatorState
@@ -466,6 +469,9 @@ public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment)
466469
// subtask 0 recovers data from old subtask 0 + 1 and subtask 1 recovers data from old
467470
// subtask 1 + 2
468471
for (int gateIndex = 0; gateIndex < inputs.size(); gateIndex++) {
472+
if (!stateAssignment.hasInFlightDataForInputGate(gateIndex)) {
473+
continue;
474+
}
469475
final RescaleMappings mapping =
470476
stateAssignment.getInputMapping(gateIndex).getRescaleMappings();
471477

flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
252252
return assignment.getOutputMapping(assignmentIndex, recompute);
253253
},
254254
inputSubtaskMappings,
255-
this::getInputMapping))
255+
this::getInputMapping,
256+
true))
256257
.setOutputRescalingDescriptor(
257258
createRescalingDescriptor(
258259
instanceID,
@@ -265,7 +266,8 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
265266
return assignment.getInputMapping(assignmentIndex, recompute);
266267
},
267268
outputSubtaskMappings,
268-
this::getOutputMapping))
269+
this::getOutputMapping,
270+
false))
269271
.build();
270272
}
271273

@@ -314,7 +316,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
314316
TaskStateAssignment[] connectedAssignments,
315317
BiFunction<TaskStateAssignment, Boolean, SubtasksRescaleMapping> mappingRetriever,
316318
Map<Integer, SubtasksRescaleMapping> subtaskGateOrPartitionMappings,
317-
Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator) {
319+
Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator,
320+
boolean isInput) {
318321
if (!expectedOperatorID.equals(instanceID.getOperatorId())) {
319322
return InflightDataRescalingDescriptor.NO_RESCALE;
320323
}
@@ -337,7 +340,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
337340
assignment -> mappingRetriever.apply(assignment, true),
338341
subtaskGateOrPartitionMappings,
339342
subtaskMappingCalculator,
340-
rescaledChannelsMappings);
343+
rescaledChannelsMappings,
344+
isInput);
341345

342346
if (Arrays.stream(gateOrPartitionDescriptors)
343347
.allMatch(InflightDataGateOrPartitionRescalingDescriptor::isIdentity)) {
@@ -356,10 +360,14 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
356360
Function<TaskStateAssignment, SubtasksRescaleMapping> mappingCalculator,
357361
Map<Integer, SubtasksRescaleMapping> subtaskGateOrPartitionMappings,
358362
Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator,
359-
SubtasksRescaleMapping[] rescaledChannelsMappings) {
363+
SubtasksRescaleMapping[] rescaledChannelsMappings,
364+
boolean isInput) {
360365
return IntStream.range(0, rescaledChannelsMappings.length)
361366
.mapToObj(
362367
partition -> {
368+
if (!hasInFlightData(isInput, partition)) {
369+
return InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
370+
}
363371
TaskStateAssignment connectedAssignment =
364372
connectedAssignments[partition];
365373
SubtasksRescaleMapping rescaleMapping =
@@ -381,6 +389,14 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
381389
.toArray(InflightDataGateOrPartitionRescalingDescriptor[]::new);
382390
}
383391

392+
private boolean hasInFlightData(boolean isInput, int gateOrPartitionIndex) {
393+
if (isInput) {
394+
return hasInFlightDataForInputGate(gateOrPartitionIndex);
395+
} else {
396+
return hasInFlightDataForResultPartition(gateOrPartitionIndex);
397+
}
398+
}
399+
384400
private InflightDataGateOrPartitionRescalingDescriptor
385401
getInflightDataGateOrPartitionRescalingDescriptor(
386402
OperatorInstanceID instanceID,
@@ -479,6 +495,50 @@ public SubtasksRescaleMapping getInputMapping(int gateIndex) {
479495
checkSubtaskMapping(oldMapping, mapping, mapper.isAmbiguous()));
480496
}
481497

498+
public boolean hasInFlightDataForInputGate(int gateIndex) {
499+
// Check own input state for this gate
500+
if (inputStateGates.contains(gateIndex)) {
501+
return true;
502+
}
503+
504+
// Check upstream output state for this gate
505+
TaskStateAssignment upstreamAssignment = getUpstreamAssignments()[gateIndex];
506+
if (upstreamAssignment != null && upstreamAssignment.hasOutputState()) {
507+
IntermediateResult inputResult = executionJobVertex.getInputs().get(gateIndex);
508+
int partitionIndex =
509+
Arrays.asList(inputResult.getProducer().getProducedDataSets())
510+
.indexOf(inputResult);
511+
512+
if (partitionIndex != -1) {
513+
return upstreamAssignment.outputStatePartitions.contains(partitionIndex);
514+
}
515+
}
516+
517+
return false;
518+
}
519+
520+
public boolean hasInFlightDataForResultPartition(int partitionIndex) {
521+
// Check own output state for this partition
522+
if (outputStatePartitions.contains(partitionIndex)) {
523+
return true;
524+
}
525+
526+
// Check downstream input state for this partition
527+
TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIndex];
528+
529+
if (downstreamAssignment != null && downstreamAssignment.hasInputState()) {
530+
IntermediateResult producedResult =
531+
executionJobVertex.getProducedDataSets()[partitionIndex];
532+
int gateIndex =
533+
downstreamAssignment.executionJobVertex.getInputs().indexOf(producedResult);
534+
535+
if (gateIndex != -1) {
536+
return downstreamAssignment.inputStateGates.contains(gateIndex);
537+
}
538+
}
539+
return false;
540+
}
541+
482542
@Override
483543
public String toString() {
484544
return "TaskStateAssignment for " + executionJobVertex.getName();
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.runtime.checkpoint;
20+
21+
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
22+
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType;
23+
24+
import org.junit.jupiter.api.Test;
25+
26+
import java.util.Arrays;
27+
import java.util.Collections;
28+
29+
import static org.assertj.core.api.Assertions.assertThat;
30+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
31+
32+
/** Tests for {@link InflightDataRescalingDescriptor}. */
33+
class InflightDataRescalingDescriptorTest {
34+
35+
@Test
36+
void testNoStateDescriptorThrowsOnGetOldSubtaskInstances() {
37+
InflightDataGateOrPartitionRescalingDescriptor noStateDescriptor =
38+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
39+
40+
assertThatThrownBy(noStateDescriptor::getOldSubtaskInstances)
41+
.isInstanceOf(UnsupportedOperationException.class)
42+
.hasMessageContaining(
43+
"Cannot get old subtasks from a descriptor that represents no state");
44+
}
45+
46+
@Test
47+
void testNoStateDescriptorThrowsOnGetRescaleMappings() {
48+
InflightDataGateOrPartitionRescalingDescriptor noStateDescriptor =
49+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
50+
51+
assertThatThrownBy(noStateDescriptor::getRescaleMappings)
52+
.isInstanceOf(UnsupportedOperationException.class)
53+
.hasMessageContaining(
54+
"Cannot get rescale mappings from a descriptor that represents no state");
55+
}
56+
57+
@Test
58+
void testNoStateDescriptorIsIdentity() {
59+
InflightDataGateOrPartitionRescalingDescriptor noStateDescriptor =
60+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
61+
62+
assertThat(noStateDescriptor.isIdentity()).isTrue();
63+
}
64+
65+
@Test
66+
void testRegularDescriptorDoesNotThrow() {
67+
int[] oldSubtasks = new int[] {0, 1, 2};
68+
RescaleMappings mappings =
69+
RescaleMappings.of(Arrays.stream(new int[][] {{0}, {1}, {2}}), 3);
70+
71+
InflightDataGateOrPartitionRescalingDescriptor descriptor =
72+
new InflightDataGateOrPartitionRescalingDescriptor(
73+
oldSubtasks, mappings, Collections.emptySet(), MappingType.RESCALING);
74+
75+
// Should not throw
76+
assertThat(descriptor.getOldSubtaskInstances()).isEqualTo(oldSubtasks);
77+
assertThat(descriptor.getRescaleMappings()).isEqualTo(mappings);
78+
assertThat(descriptor.isIdentity()).isFalse();
79+
}
80+
81+
@Test
82+
void testIdentityDescriptor() {
83+
int[] oldSubtasks = new int[] {0};
84+
RescaleMappings mappings = RescaleMappings.identity(1, 1);
85+
86+
InflightDataGateOrPartitionRescalingDescriptor descriptor =
87+
new InflightDataGateOrPartitionRescalingDescriptor(
88+
oldSubtasks, mappings, Collections.emptySet(), MappingType.IDENTITY);
89+
90+
assertThat(descriptor.isIdentity()).isTrue();
91+
assertThat(descriptor.getOldSubtaskInstances()).isEqualTo(oldSubtasks);
92+
assertThat(descriptor.getRescaleMappings()).isEqualTo(mappings);
93+
}
94+
95+
@Test
96+
void testInflightDataRescalingDescriptorWithNoStateDescriptor() {
97+
// Create a descriptor array with NO_STATE descriptor
98+
InflightDataGateOrPartitionRescalingDescriptor[] descriptors =
99+
new InflightDataGateOrPartitionRescalingDescriptor[] {
100+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE,
101+
new InflightDataGateOrPartitionRescalingDescriptor(
102+
new int[] {0, 1},
103+
RescaleMappings.of(Arrays.stream(new int[][] {{0}, {1}}), 2),
104+
Collections.emptySet(),
105+
MappingType.RESCALING)
106+
};
107+
108+
InflightDataRescalingDescriptor rescalingDescriptor =
109+
new InflightDataRescalingDescriptor(descriptors);
110+
111+
// First gate/partition has NO_STATE
112+
assertThatThrownBy(() -> rescalingDescriptor.getOldSubtaskIndexes(0))
113+
.isInstanceOf(UnsupportedOperationException.class);
114+
assertThatThrownBy(() -> rescalingDescriptor.getChannelMapping(0))
115+
.isInstanceOf(UnsupportedOperationException.class);
116+
117+
// Second gate/partition has normal state
118+
assertThat(rescalingDescriptor.getOldSubtaskIndexes(1)).isEqualTo(new int[] {0, 1});
119+
assertThat(rescalingDescriptor.getChannelMapping(1)).isNotNull();
120+
}
121+
}

0 commit comments

Comments
 (0)