[FLINK-36989][runtime] Fix scheduler benchmark regression caused by ConsumedSubpartitionContext
noorall authored and zhuzhurk committed Jan 6, 2025
1 parent 3084561 commit 81f882b
Showing 4 changed files with 69 additions and 55 deletions.
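
The regression this commit addresses is visible in the removed lines of the first file below: every call to buildConsumedSubpartitionContext drained an iterator over the consumed partitions to rebuild a partition-ID-to-shuffle-descriptor-index HashMap, so that O(n) construction was repeated for each input gate being deployed. The fix caches the index map once inside ConsumedPartitionGroup. A minimal before/after sketch of that shape, using plain Strings and hypothetical names in place of Flink's partition IDs (an illustration, not code from this commit):

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class IndexLookupSketch {

    // Old shape: every caller rebuilds the id -> index map from scratch, so the
    // O(n) construction cost is paid again for each input gate being deployed.
    static Map<String, Integer> buildIndexMapPerCall(List<String> partitionIds) {
        Map<String, Integer> indexMap = new HashMap<>();
        for (String id : partitionIds) {
            indexMap.put(id, indexMap.size());
        }
        return indexMap;
    }

    // New shape: the group computes the map once at construction time and every
    // caller performs O(1) lookups against the cached, insertion-ordered map.
    static final class CachedGroup {
        private final Map<String, Integer> indexInOrder = new LinkedHashMap<>();

        CachedGroup(List<String> partitionIds) {
            for (int i = 0; i < partitionIds.size(); i++) {
                indexInOrder.put(partitionIds.get(i), i);
            }
        }

        int indexOf(String partitionId) {
            return indexInOrder.get(partitionId);
        }
    }

    public static void main(String[] args) {
        List<String> ids = Arrays.asList("p0", "p1", "p2", "p3");
        CachedGroup group = new CachedGroup(ids);
        System.out.println(group.indexOf("p2")); // 2, via the cached map
        System.out.println(buildIndexMapPerCall(ids).get("p2")); // 2, but rebuilt per call
    }
}
```

The committed version additionally short-circuits the common case where a task consumes every partition of the group, shown in the first hunk below.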
@@ -20,18 +20,17 @@

import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IndexRangeUtil;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;
@@ -113,34 +112,41 @@ public IndexRange getConsumedSubpartitionRange(int shuffleDescriptorIndex) {
*
* @param consumedSubpartitionGroups a mapping of consumed partition index ranges to
* subpartition ranges.
* @param consumedResultPartitions an iterator of {@link IntermediateResultPartitionID} for the
* consumed result partitions.
* @param partitions all partition ids of consumed {@link IntermediateResult}.
* @param consumedPartitionGroup partition group consumed by the task.
* @param partitionIdRetriever a function that retrieves the {@link
* IntermediateResultPartitionID} for a given index.
* @return a {@link ConsumedSubpartitionContext} instance constructed from the input parameters.
*/
public static ConsumedSubpartitionContext buildConsumedSubpartitionContext(
Map<IndexRange, IndexRange> consumedSubpartitionGroups,
Iterator<IntermediateResultPartitionID> consumedResultPartitions,
IntermediateResultPartitionID[] partitions) {
Map<IntermediateResultPartitionID, Integer> partitionIdToShuffleDescriptorIndexMap =
new HashMap<>();
while (consumedResultPartitions.hasNext()) {
IntermediateResultPartitionID partitionId = consumedResultPartitions.next();
partitionIdToShuffleDescriptorIndexMap.put(
partitionId, partitionIdToShuffleDescriptorIndexMap.size());
ConsumedPartitionGroup consumedPartitionGroup,
Function<Integer, IntermediateResultPartitionID> partitionIdRetriever) {
Map<IntermediateResultPartitionID, Integer> resultPartitionsInOrder =
consumedPartitionGroup.getResultPartitionsInOrder();
// If only one range is included and the index range size is the same as the number of
// shuffle descriptors, it means that the task will subscribe to all partitions, i.e., the
// partition range is one-to-one corresponding to the shuffle descriptors. Therefore, we can
// directly construct the ConsumedSubpartitionContext using the subpartition range.
if (consumedSubpartitionGroups.size() == 1
&& consumedSubpartitionGroups.keySet().iterator().next().size()
== resultPartitionsInOrder.size()) {
return buildConsumedSubpartitionContext(
resultPartitionsInOrder.size(),
consumedSubpartitionGroups.values().iterator().next());
}

Map<IndexRange, IndexRange> consumedShuffleDescriptorToSubpartitionRangeMap =
new LinkedHashMap<>();
for (Map.Entry<IndexRange, IndexRange> entry : consumedSubpartitionGroups.entrySet()) {
IndexRange partitionRange = entry.getKey();
IndexRange subpartitionRange = entry.getValue();
// The shuffle descriptor index is consistent with the index in resultPartitionsInOrder.
IndexRange shuffleDescriptorRange =
new IndexRange(
partitionIdToShuffleDescriptorIndexMap.get(
partitions[partitionRange.getStartIndex()]),
partitionIdToShuffleDescriptorIndexMap.get(
partitions[partitionRange.getEndIndex()]));
resultPartitionsInOrder.get(
partitionIdRetriever.apply(partitionRange.getStartIndex())),
resultPartitionsInOrder.get(
partitionIdRetriever.apply(partitionRange.getEndIndex())));
checkState(
partitionRange.size() == shuffleDescriptorRange.size()
&& !consumedShuffleDescriptorToSubpartitionRangeMap.containsKey(
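In the file above (the ConsumedSubpartitionContext factory), the reworked buildConsumedSubpartitionContext now receives the ConsumedPartitionGroup itself plus an index-to-partition-ID retriever, and short-circuits when a single consumed partition range spans the whole group. A minimal sketch of that shortcut condition, with a simplified Range class standing in for Flink's IndexRange (an illustration under that assumption, not the committed code):

```java
import java.util.Map;

final class FastPathSketch {

    /** Simplified stand-in for org.apache.flink.runtime.executiongraph.IndexRange. */
    static final class Range {
        final int start;
        final int end;

        Range(int start, int end) {
            this.start = start;
            this.end = end;
        }

        int size() {
            return end - start + 1;
        }
    }

    /**
     * True when the task subscribes to every partition of the group through one range;
     * in that case the subpartition range can be reused directly and no per-partition
     * index lookups or range translations are needed.
     */
    static boolean consumesAllPartitions(
            Map<Range, Range> consumedSubpartitionGroups, int numPartitionsInGroup) {
        return consumedSubpartitionGroups.size() == 1
                && consumedSubpartitionGroups.keySet().iterator().next().size()
                        == numPartitionsInGroup;
    }

    public static void main(String[] args) {
        Map<Range, Range> groups = Map.of(new Range(0, 3), new Range(0, 5));
        System.out.println(consumesAllPartitions(groups, 4)); // true: all 4 partitions consumed
    }
}
```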
@@ -53,7 +53,6 @@
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
@@ -151,6 +150,7 @@ private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors

IntermediateDataSetID resultId = consumedIntermediateResult.getId();
ResultPartitionType partitionType = consumedIntermediateResult.getResultType();
IntermediateResultPartition[] partitions = consumedIntermediateResult.getPartitions();

inputGates.add(
new InputGateDeploymentDescriptor(
@@ -160,10 +160,8 @@ private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors
executionVertex
.getExecutionVertexInputInfo(resultId)
.getConsumedSubpartitionGroups(),
consumedPartitionGroup.iterator(),
Arrays.stream(consumedIntermediateResult.getPartitions())
.map(IntermediateResultPartition::getPartitionId)
.toArray(IntermediateResultPartitionID[]::new)),
consumedPartitionGroup,
index -> partitions[index].getPartitionId()),
consumedPartitionGroup.size(),
getConsumedPartitionShuffleDescriptors(
consumedIntermediateResult,
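At the call site above (input gate deployment descriptor creation), the old code streamed consumedIntermediateResult.getPartitions() into a fresh IntermediateResultPartitionID[] for every input gate; the new code captures the partitions array once and passes a lambda that resolves an index on demand, so only the endpoints of each consumed range are ever looked up. A sketch of the two shapes with a placeholder Partition type (hypothetical names, not Flink's classes):

```java
import java.util.Arrays;
import java.util.function.Function;

final class RetrieverSketch {

    static final class Partition {
        final String id;

        Partition(String id) {
            this.id = id;
        }

        String getPartitionId() {
            return id;
        }
    }

    // Old shape: materialize an id array per call (an extra O(n) allocation and copy
    // for every input gate descriptor being built).
    static String[] idsByCopy(Partition[] partitions) {
        return Arrays.stream(partitions).map(Partition::getPartitionId).toArray(String[]::new);
    }

    // New shape: expose an on-demand retriever; only the indices that are actually
    // needed (the endpoints of each consumed range) get resolved.
    static Function<Integer, String> idRetriever(Partition[] partitions) {
        return index -> partitions[index].getPartitionId();
    }

    public static void main(String[] args) {
        Partition[] partitions = {new Partition("p0"), new Partition("p1"), new Partition("p2")};
        System.out.println(idRetriever(partitions).apply(2)); // p2, resolved lazily
    }
}
```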
@@ -27,7 +27,9 @@

import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import static org.apache.flink.util.Preconditions.checkArgument;
@@ -40,7 +42,10 @@
*/
public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartitionID> {

private final List<IntermediateResultPartitionID> resultPartitions;
// The key is the result partition ID, the value is the index of the result partition in the
// original construction list.
private final Map<IntermediateResultPartitionID, Integer> resultPartitionsInOrder =
new LinkedHashMap<>();

private final AtomicInteger unfinishedPartitions;

@@ -64,13 +69,15 @@ private ConsumedPartitionGroup(
this.intermediateDataSetID = resultPartitions.get(0).getIntermediateDataSetID();
this.resultPartitionType = Preconditions.checkNotNull(resultPartitionType);

// Sanity check: all the partitions in one ConsumedPartitionGroup should have the same
// IntermediateDataSetID
for (IntermediateResultPartitionID resultPartition : resultPartitions) {
for (int i = 0; i < resultPartitions.size(); i++) {
// Sanity check: all the partitions in one ConsumedPartitionGroup should have the same
// IntermediateDataSetID
IntermediateResultPartitionID resultPartition = resultPartitions.get(i);
checkArgument(
resultPartition.getIntermediateDataSetID().equals(this.intermediateDataSetID));

resultPartitionsInOrder.put(resultPartition, i);
}
this.resultPartitions = resultPartitions;

this.unfinishedPartitions = new AtomicInteger(resultPartitions.size());
}
@@ -92,15 +99,19 @@ public static ConsumedPartitionGroup fromSinglePartition(

@Override
public Iterator<IntermediateResultPartitionID> iterator() {
return resultPartitions.iterator();
return resultPartitionsInOrder.keySet().iterator();
}

public Map<IntermediateResultPartitionID, Integer> getResultPartitionsInOrder() {
return Collections.unmodifiableMap(resultPartitionsInOrder);
}

public int size() {
return resultPartitions.size();
return resultPartitionsInOrder.size();
}

public boolean isEmpty() {
return resultPartitions.isEmpty();
return resultPartitionsInOrder.isEmpty();
}

/**
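ConsumedPartitionGroup, shown above, now stores a LinkedHashMap from partition ID to its position in the original construction list instead of a plain List. Because a LinkedHashMap iterates its keys in insertion order, iterator(), size(), and isEmpty() keep their previous behavior while getResultPartitionsInOrder() adds a constant-time index lookup. A tiny standalone check of that ordering property (plain Strings stand in for IntermediateResultPartitionID):

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class InsertionOrderCheck {
    public static void main(String[] args) {
        List<String> constructionOrder = Arrays.asList("p0", "p1", "p2", "p3");

        Map<String, Integer> indexInOrder = new LinkedHashMap<>();
        for (int i = 0; i < constructionOrder.size(); i++) {
            indexInOrder.put(constructionOrder.get(i), i);
        }

        // Key iteration reproduces the construction order, so iteration-based callers
        // see the same sequence the old List-backed field provided...
        System.out.println(indexInOrder.keySet()); // [p0, p1, p2, p3]
        // ...while the index of any partition is now available in O(1).
        System.out.println(indexInOrder.get("p3")); // 3
    }
}
```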
@@ -21,6 +21,7 @@
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;

import org.junit.jupiter.api.Test;

@@ -29,6 +30,7 @@
import java.util.List;
import java.util.Map;

import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING;
import static org.assertj.core.api.Assertions.assertThat;

/** Tests for {@link ConsumedSubpartitionContext}. */
@@ -40,17 +42,13 @@ void testBuildConsumedSubpartitionContextWithGroups() {
new IndexRange(0, 1), new IndexRange(0, 2),
new IndexRange(2, 3), new IndexRange(3, 5));

List<IntermediateResultPartitionID> consumedPartitionIds = new ArrayList<>();

IntermediateResultPartitionID[] partitions = new IntermediateResultPartitionID[4];
for (int i = 0; i < partitions.length; i++) {
partitions[i] = new IntermediateResultPartitionID(new IntermediateDataSetID(), i);
consumedPartitionIds.add(partitions[i]);
}
List<IntermediateResultPartitionID> partitions = createPartitions();
ConsumedPartitionGroup consumedPartitionGroup =
ConsumedPartitionGroup.fromMultiplePartitions(4, partitions, BLOCKING);

ConsumedSubpartitionContext context =
ConsumedSubpartitionContext.buildConsumedSubpartitionContext(
consumedSubpartitionGroups, consumedPartitionIds.iterator(), partitions);
consumedSubpartitionGroups, consumedPartitionGroup, partitions::get);

assertThat(context.getNumConsumedShuffleDescriptors()).isEqualTo(4);

@@ -71,17 +69,13 @@ void testBuildConsumedSubpartitionContextWithUnorderedGroups() {
new IndexRange(3, 3), new IndexRange(1, 1),
new IndexRange(0, 0), new IndexRange(0, 1));

List<IntermediateResultPartitionID> consumedPartitionIds = new ArrayList<>();

IntermediateResultPartitionID[] partitions = new IntermediateResultPartitionID[4];
for (int i = 0; i < partitions.length; i++) {
partitions[i] = new IntermediateResultPartitionID(new IntermediateDataSetID(), i);
consumedPartitionIds.add(partitions[i]);
}
List<IntermediateResultPartitionID> partitions = createPartitions();
ConsumedPartitionGroup consumedPartitionGroup =
ConsumedPartitionGroup.fromMultiplePartitions(4, partitions, BLOCKING);

ConsumedSubpartitionContext context =
ConsumedSubpartitionContext.buildConsumedSubpartitionContext(
consumedSubpartitionGroups, consumedPartitionIds.iterator(), partitions);
consumedSubpartitionGroups, consumedPartitionGroup, partitions::get);

assertThat(context.getNumConsumedShuffleDescriptors()).isEqualTo(2);

@@ -100,17 +94,13 @@ void testBuildConsumedSubpartitionContextWithOverlapGroups() {
new IndexRange(0, 3), new IndexRange(1, 1),
new IndexRange(0, 1), new IndexRange(2, 2));

List<IntermediateResultPartitionID> consumedPartitionIds = new ArrayList<>();

IntermediateResultPartitionID[] partitions = new IntermediateResultPartitionID[4];
for (int i = 0; i < partitions.length; i++) {
partitions[i] = new IntermediateResultPartitionID(new IntermediateDataSetID(), i);
consumedPartitionIds.add(partitions[i]);
}
List<IntermediateResultPartitionID> partitions = createPartitions();
ConsumedPartitionGroup consumedPartitionGroup =
ConsumedPartitionGroup.fromMultiplePartitions(4, partitions, BLOCKING);

ConsumedSubpartitionContext context =
ConsumedSubpartitionContext.buildConsumedSubpartitionContext(
consumedSubpartitionGroups, consumedPartitionIds.iterator(), partitions);
consumedSubpartitionGroups, consumedPartitionGroup, partitions::get);

assertThat(context.getNumConsumedShuffleDescriptors()).isEqualTo(4);

@@ -144,4 +134,13 @@ void testBuildConsumedSubpartitionContextWithRange() {
IndexRange subpartitionRange = context.getConsumedSubpartitionRange(2);
assertThat(subpartitionRange).isEqualTo(consumedSubpartitionRange);
}

private static List<IntermediateResultPartitionID> createPartitions() {
List<IntermediateResultPartitionID> partitions = new ArrayList<>();
IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
for (int i = 0; i < 4; i++) {
partitions.add(new IntermediateResultPartitionID(intermediateDataSetID, i));
}
return partitions;
}
}
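
The updated tests above cover grouped, unordered, and overlapping ranges. A natural companion — hypothetical, not part of this commit, and assuming the single-group shortcut returns the same subpartition range for every shuffle descriptor index (as the range-based overload exercised in testBuildConsumedSubpartitionContextWithRange does) — would assert the fast path where one partition range spans the whole group, reusing the file's createPartitions() helper and existing imports:

```java
@Test
void testBuildConsumedSubpartitionContextWithAllPartitionsConsumed() {
    // Hypothetical test: a single group whose partition range [0, 3] spans all four
    // partitions should take the shortcut and reuse the subpartition range directly.
    Map<IndexRange, IndexRange> consumedSubpartitionGroups =
            Map.of(new IndexRange(0, 3), new IndexRange(0, 5));

    List<IntermediateResultPartitionID> partitions = createPartitions();
    ConsumedPartitionGroup consumedPartitionGroup =
            ConsumedPartitionGroup.fromMultiplePartitions(4, partitions, BLOCKING);

    ConsumedSubpartitionContext context =
            ConsumedSubpartitionContext.buildConsumedSubpartitionContext(
                    consumedSubpartitionGroups, consumedPartitionGroup, partitions::get);

    assertThat(context.getNumConsumedShuffleDescriptors()).isEqualTo(4);
    assertThat(context.getConsumedSubpartitionRange(2)).isEqualTo(new IndexRange(0, 5));
}
```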
