20
20
import org .apache .flink .runtime .OperatorIDPair ;
21
21
import org .apache .flink .runtime .checkpoint .InflightDataRescalingDescriptor .InflightDataGateOrPartitionRescalingDescriptor ;
22
22
import org .apache .flink .runtime .checkpoint .InflightDataRescalingDescriptor .InflightDataGateOrPartitionRescalingDescriptor .MappingType ;
23
+ import org .apache .flink .runtime .checkpoint .channel .InputChannelInfo ;
24
+ import org .apache .flink .runtime .checkpoint .channel .ResultSubpartitionInfo ;
23
25
import org .apache .flink .runtime .executiongraph .ExecutionJobVertex ;
24
26
import org .apache .flink .runtime .executiongraph .IntermediateResult ;
25
27
import org .apache .flink .runtime .io .network .api .writer .SubtaskStateMapper ;
28
30
import org .apache .flink .runtime .jobgraph .OperatorInstanceID ;
29
31
import org .apache .flink .runtime .state .InputChannelStateHandle ;
30
32
import org .apache .flink .runtime .state .KeyedStateHandle ;
33
+ import org .apache .flink .runtime .state .MergedInputChannelStateHandle ;
34
+ import org .apache .flink .runtime .state .MergedResultSubpartitionStateHandle ;
31
35
import org .apache .flink .runtime .state .OperatorStateHandle ;
32
36
import org .apache .flink .runtime .state .ResultSubpartitionStateHandle ;
33
37
import org .apache .flink .runtime .state .StateObject ;
40
44
import javax .annotation .Nullable ;
41
45
42
46
import java .util .Arrays ;
47
+ import java .util .Collection ;
43
48
import java .util .HashMap ;
44
49
import java .util .List ;
45
50
import java .util .Map ;
48
53
import java .util .Set ;
49
54
import java .util .function .BiFunction ;
50
55
import java .util .function .Function ;
56
+ import java .util .stream .Collectors ;
51
57
import java .util .stream .IntStream ;
52
58
53
59
import static java .util .Collections .emptySet ;
@@ -67,12 +73,16 @@ class TaskStateAssignment {
67
73
final Map <OperatorID , OperatorState > oldState ;
68
74
final boolean hasNonFinishedState ;
69
75
final boolean isFullyFinished ;
70
- final boolean hasInputState ;
71
- final boolean hasOutputState ;
72
76
final int newParallelism ;
73
77
final OperatorID inputOperatorID ;
74
78
final OperatorID outputOperatorID ;
75
79
80
+ /** The InputGate set that containing input buffer state. */
81
+ private final Set <Integer > inputStateGates ;
82
+
83
+ /** The ResultPartition set that containing input buffer state. */
84
+ private final Set <Integer > outputStatePartitions ;
85
+
76
86
final Map <OperatorInstanceID , List <OperatorStateHandle >> subManagedOperatorState ;
77
87
final Map <OperatorInstanceID , List <OperatorStateHandle >> subRawOperatorState ;
78
88
final Map <OperatorInstanceID , List <KeyedStateHandle >> subManagedKeyedState ;
@@ -131,12 +141,63 @@ public TaskStateAssignment(
131
141
outputOperatorID = operatorIDs .get (0 ).getGeneratedOperatorID ();
132
142
inputOperatorID = operatorIDs .get (operatorIDs .size () - 1 ).getGeneratedOperatorID ();
133
143
134
- hasInputState =
135
- oldState .get (inputOperatorID ).getStates ().stream ()
136
- .anyMatch (subState -> !subState .getInputChannelState ().isEmpty ());
137
- hasOutputState =
138
- oldState .get (outputOperatorID ).getStates ().stream ()
139
- .anyMatch (subState -> !subState .getResultSubpartitionState ().isEmpty ());
144
+ inputStateGates = extractInputStateGates (oldState .get (inputOperatorID ));
145
+ outputStatePartitions = extractOutputStatePartitions (oldState .get (outputOperatorID ));
146
+ }
147
+
148
+ private static Set <Integer > extractInputStateGates (OperatorState operatorState ) {
149
+ return operatorState .getStates ().stream ()
150
+ .map (OperatorSubtaskState ::getInputChannelState )
151
+ .flatMap (Collection ::stream )
152
+ .flatMapToInt (
153
+ handle -> {
154
+ if (handle instanceof InputChannelStateHandle ) {
155
+ return IntStream .of (
156
+ ((InputChannelStateHandle ) handle ).getInfo ().getGateIdx ());
157
+ } else if (handle instanceof MergedInputChannelStateHandle ) {
158
+ return ((MergedInputChannelStateHandle ) handle )
159
+ .getInfos ().stream ().mapToInt (InputChannelInfo ::getGateIdx );
160
+ } else {
161
+ throw new IllegalStateException (
162
+ "Invalid input channel state : " + handle .getClass ());
163
+ }
164
+ })
165
+ .distinct ()
166
+ .boxed ()
167
+ .collect (Collectors .toSet ());
168
+ }
169
+
170
+ private static Set <Integer > extractOutputStatePartitions (OperatorState operatorState ) {
171
+ return operatorState .getStates ().stream ()
172
+ .map (OperatorSubtaskState ::getResultSubpartitionState )
173
+ .flatMap (Collection ::stream )
174
+ .flatMapToInt (
175
+ handle -> {
176
+ if (handle instanceof ResultSubpartitionStateHandle ) {
177
+ return IntStream .of (
178
+ ((ResultSubpartitionStateHandle ) handle )
179
+ .getInfo ()
180
+ .getPartitionIdx ());
181
+ } else if (handle instanceof MergedResultSubpartitionStateHandle ) {
182
+ return ((MergedResultSubpartitionStateHandle ) handle )
183
+ .getInfos ().stream ()
184
+ .mapToInt (ResultSubpartitionInfo ::getPartitionIdx );
185
+ } else {
186
+ throw new IllegalStateException (
187
+ "Invalid output channel state : " + handle .getClass ());
188
+ }
189
+ })
190
+ .distinct ()
191
+ .boxed ()
192
+ .collect (Collectors .toSet ());
193
+ }
194
+
195
+ public boolean hasInputState () {
196
+ return !inputStateGates .isEmpty ();
197
+ }
198
+
199
+ public boolean hasOutputState () {
200
+ return !outputStatePartitions .isEmpty ();
140
201
}
141
202
142
203
public TaskStateAssignment [] getDownstreamAssignments () {
@@ -212,7 +273,7 @@ public boolean hasUpstreamOutputStates() {
212
273
if (hasUpstreamOutputStates == null ) {
213
274
hasUpstreamOutputStates =
214
275
Arrays .stream (getUpstreamAssignments ())
215
- .anyMatch (assignment -> assignment . hasOutputState );
276
+ .anyMatch (TaskStateAssignment :: hasOutputState );
216
277
}
217
278
return hasUpstreamOutputStates ;
218
279
}
@@ -221,7 +282,7 @@ public boolean hasDownstreamInputStates() {
221
282
if (hasDownstreamInputStates == null ) {
222
283
hasDownstreamInputStates =
223
284
Arrays .stream (getDownstreamAssignments ())
224
- .anyMatch (assignment -> assignment . hasInputState );
285
+ .anyMatch (TaskStateAssignment :: hasInputState );
225
286
}
226
287
return hasDownstreamInputStates ;
227
288
}
0 commit comments