Skip to content

Commit c33e974

Browse files
committed
[FLINK-38267][checkpoint] Refactor hasInputState and hasOutputState related logic in TaskStateAssignment
1 parent 250ab88 commit c33e974

File tree

2 files changed

+73
-12
lines changed

2 files changed

+73
-12
lines changed

flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ public static <T extends StateObject> void reDistributePartitionableStates(
361361
public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment) {
362362
// FLINK-31963: We can skip this phase if there is no output state AND downstream has no
363363
// input states
364-
if (!assignment.hasOutputState && !assignment.hasDownstreamInputStates()) {
364+
if (!assignment.hasOutputState() && !assignment.hasDownstreamInputStates()) {
365365
return;
366366
}
367367

@@ -410,7 +410,7 @@ public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment)
410410
public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment) {
411411
// FLINK-31963: We can skip this phase only if there is no input state AND upstream has no
412412
// output states
413-
if (!stateAssignment.hasInputState && !stateAssignment.hasUpstreamOutputStates()) {
413+
if (!stateAssignment.hasInputState() && !stateAssignment.hasUpstreamOutputStates()) {
414414
return;
415415
}
416416

flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import org.apache.flink.runtime.OperatorIDPair;
2121
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
2222
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType;
23+
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
24+
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
2325
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
2426
import org.apache.flink.runtime.executiongraph.IntermediateResult;
2527
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
@@ -28,6 +30,8 @@
2830
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
2931
import org.apache.flink.runtime.state.InputChannelStateHandle;
3032
import org.apache.flink.runtime.state.KeyedStateHandle;
33+
import org.apache.flink.runtime.state.MergedInputChannelStateHandle;
34+
import org.apache.flink.runtime.state.MergedResultSubpartitionStateHandle;
3135
import org.apache.flink.runtime.state.OperatorStateHandle;
3236
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
3337
import org.apache.flink.runtime.state.StateObject;
@@ -40,6 +44,7 @@
4044
import javax.annotation.Nullable;
4145

4246
import java.util.Arrays;
47+
import java.util.Collection;
4348
import java.util.HashMap;
4449
import java.util.List;
4550
import java.util.Map;
@@ -48,6 +53,7 @@
4853
import java.util.Set;
4954
import java.util.function.BiFunction;
5055
import java.util.function.Function;
56+
import java.util.stream.Collectors;
5157
import java.util.stream.IntStream;
5258

5359
import static java.util.Collections.emptySet;
@@ -67,12 +73,16 @@ class TaskStateAssignment {
6773
final Map<OperatorID, OperatorState> oldState;
6874
final boolean hasNonFinishedState;
6975
final boolean isFullyFinished;
70-
final boolean hasInputState;
71-
final boolean hasOutputState;
7276
final int newParallelism;
7377
final OperatorID inputOperatorID;
7478
final OperatorID outputOperatorID;
7579

80+
/** The set of InputGate indices that contain input buffer state. */
81+
private final Set<Integer> inputStateGates;
82+
83+
/** The set of ResultPartition indices that contain output buffer state. */
84+
private final Set<Integer> outputStatePartitions;
85+
7686
final Map<OperatorInstanceID, List<OperatorStateHandle>> subManagedOperatorState;
7787
final Map<OperatorInstanceID, List<OperatorStateHandle>> subRawOperatorState;
7888
final Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState;
@@ -131,12 +141,63 @@ public TaskStateAssignment(
131141
outputOperatorID = operatorIDs.get(0).getGeneratedOperatorID();
132142
inputOperatorID = operatorIDs.get(operatorIDs.size() - 1).getGeneratedOperatorID();
133143

134-
hasInputState =
135-
oldState.get(inputOperatorID).getStates().stream()
136-
.anyMatch(subState -> !subState.getInputChannelState().isEmpty());
137-
hasOutputState =
138-
oldState.get(outputOperatorID).getStates().stream()
139-
.anyMatch(subState -> !subState.getResultSubpartitionState().isEmpty());
144+
inputStateGates = extractInputStateGates(oldState.get(inputOperatorID));
145+
outputStatePartitions = extractOutputStatePartitions(oldState.get(outputOperatorID));
146+
}
147+
148+
private static Set<Integer> extractInputStateGates(OperatorState operatorState) {
149+
return operatorState.getStates().stream()
150+
.map(OperatorSubtaskState::getInputChannelState)
151+
.flatMap(Collection::stream)
152+
.flatMapToInt(
153+
handle -> {
154+
if (handle instanceof InputChannelStateHandle) {
155+
return IntStream.of(
156+
((InputChannelStateHandle) handle).getInfo().getGateIdx());
157+
} else if (handle instanceof MergedInputChannelStateHandle) {
158+
return ((MergedInputChannelStateHandle) handle)
159+
.getInfos().stream().mapToInt(InputChannelInfo::getGateIdx);
160+
} else {
161+
throw new IllegalStateException(
162+
"Invalid input channel state : " + handle.getClass());
163+
}
164+
})
165+
.distinct()
166+
.boxed()
167+
.collect(Collectors.toSet());
168+
}
169+
170+
private static Set<Integer> extractOutputStatePartitions(OperatorState operatorState) {
171+
return operatorState.getStates().stream()
172+
.map(OperatorSubtaskState::getResultSubpartitionState)
173+
.flatMap(Collection::stream)
174+
.flatMapToInt(
175+
handle -> {
176+
if (handle instanceof ResultSubpartitionStateHandle) {
177+
return IntStream.of(
178+
((ResultSubpartitionStateHandle) handle)
179+
.getInfo()
180+
.getPartitionIdx());
181+
} else if (handle instanceof MergedResultSubpartitionStateHandle) {
182+
return ((MergedResultSubpartitionStateHandle) handle)
183+
.getInfos().stream()
184+
.mapToInt(ResultSubpartitionInfo::getPartitionIdx);
185+
} else {
186+
throw new IllegalStateException(
187+
"Invalid output channel state : " + handle.getClass());
188+
}
189+
})
190+
.distinct()
191+
.boxed()
192+
.collect(Collectors.toSet());
193+
}
194+
195+
public boolean hasInputState() {
196+
return !inputStateGates.isEmpty();
197+
}
198+
199+
public boolean hasOutputState() {
200+
return !outputStatePartitions.isEmpty();
140201
}
141202

142203
public TaskStateAssignment[] getDownstreamAssignments() {
@@ -212,7 +273,7 @@ public boolean hasUpstreamOutputStates() {
212273
if (hasUpstreamOutputStates == null) {
213274
hasUpstreamOutputStates =
214275
Arrays.stream(getUpstreamAssignments())
215-
.anyMatch(assignment -> assignment.hasOutputState);
276+
.anyMatch(TaskStateAssignment::hasOutputState);
216277
}
217278
return hasUpstreamOutputStates;
218279
}
@@ -221,7 +282,7 @@ public boolean hasDownstreamInputStates() {
221282
if (hasDownstreamInputStates == null) {
222283
hasDownstreamInputStates =
223284
Arrays.stream(getDownstreamAssignments())
224-
.anyMatch(assignment -> assignment.hasInputState);
285+
.anyMatch(TaskStateAssignment::hasInputState);
225286
}
226287
return hasDownstreamInputStates;
227288
}

0 commit comments

Comments
 (0)