Skip to content

Commit

Permalink
[FLINK-34502][autoscaler] Support calculating network memory for forw…
Browse files Browse the repository at this point in the history
…ard and rescale edge
  • Loading branch information
1996fanrui committed Feb 23, 2024
1 parent 1f3425d commit 2436c0c
Show file tree
Hide file tree
Showing 15 changed files with 187 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ private void computeTargetDataRate(
}
out.put(CATCH_UP_DATA_RATE, EvaluatedScalingMetric.of(catchUpInputRate));
} else {
var inputs = topology.get(vertex).getInputs();
var inputs = topology.get(vertex).getInputs().keySet();
double sumAvgTargetRate = 0;
double sumCatchUpDataRate = 0;
for (var inputVertex : inputs) {
Expand Down Expand Up @@ -531,7 +531,7 @@ protected static double computeEdgeDataRate(
JobVertexID from,
JobVertexID to) {

var toVertexInputs = topology.get(to).getInputs();
var toVertexInputs = topology.get(to).getInputs().keySet();
// Case 1: Downstream vertex has single input (from) so we can use the most reliable num
// records in
if (toVertexInputs.size() == 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public JobTopology(VertexInfo... vertexInfo) {

public JobTopology(Set<VertexInfo> vertexInfo) {

Map<JobVertexID, Set<JobVertexID>> vertexOutputs = new HashMap<>();
Map<JobVertexID, Map<JobVertexID, String>> vertexOutputs = new HashMap<>();
vertexInfos =
ImmutableMap.copyOf(
vertexInfo.stream().collect(Collectors.toMap(VertexInfo::getId, v -> v)));
Expand All @@ -72,13 +72,13 @@ public JobTopology(Set<VertexInfo> vertexInfo) {
info -> {
var vertexId = info.getId();

vertexOutputs.computeIfAbsent(vertexId, id -> new HashSet<>());
vertexOutputs.computeIfAbsent(vertexId, id -> new HashMap<>());
info.getInputs()
.forEach(
inputId ->
(inputId, shipStrategy) ->
vertexOutputs
.computeIfAbsent(inputId, id -> new HashSet<>())
.add(vertexId));
.computeIfAbsent(inputId, id -> new HashMap<>())
.put(vertexId, shipStrategy));
if (info.isFinished()) {
finishedVertices.add(vertexId);
}
Expand All @@ -105,7 +105,8 @@ private List<JobVertexID> returnVerticesInTopologicalOrder() {
List<JobVertexID> sorted = new ArrayList<>(vertexInfos.size());

Map<JobVertexID, List<JobVertexID>> remainingInputs = new HashMap<>(vertexInfos.size());
vertexInfos.forEach((id, v) -> remainingInputs.put(id, new ArrayList<>(v.getInputs())));
vertexInfos.forEach(
(id, v) -> remainingInputs.put(id, new ArrayList<>(v.getInputs().keySet())));

while (!remainingInputs.isEmpty()) {
List<JobVertexID> verticesWithZeroIndegree = new ArrayList<>();
Expand All @@ -122,6 +123,7 @@ private List<JobVertexID> returnVerticesInTopologicalOrder() {
vertexInfos
.get(v)
.getOutputs()
.keySet()
.forEach(o -> remainingInputs.get(o).remove(v));
});

Expand All @@ -143,20 +145,22 @@ public static JobTopology fromJsonPlan(

for (JsonNode node : nodes) {
var vertexId = JobVertexID.fromHexString(node.get("id").asText());
var inputList = new HashSet<JobVertexID>();
var inputs = new HashMap<JobVertexID, String>();
var ioMetrics = metrics.get(vertexId);
var finished = finishedVertices.contains(vertexId);
vertexInfo.add(
new VertexInfo(
vertexId,
inputList,
inputs,
node.get("parallelism").asInt(),
maxParallelismMap.get(vertexId),
finished,
finished ? IOMetrics.FINISHED_METRICS : ioMetrics));
if (node.has("inputs")) {
for (JsonNode input : node.get("inputs")) {
inputList.add(JobVertexID.fromHexString(input.get("id").asText()));
inputs.put(
JobVertexID.fromHexString(input.get("id").asText()),
input.get("ship_strategy").asText());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,19 @@

import lombok.Data;

import java.util.Set;
import java.util.Map;

/** Job vertex information. */
@Data
public class VertexInfo {

private final JobVertexID id;

private final Set<JobVertexID> inputs;
// All input vertices and the ship_strategy
private final Map<JobVertexID, String> inputs;

private Set<JobVertexID> outputs;
// All output vertices and the ship_strategy
private Map<JobVertexID, String> outputs;

private final int parallelism;

Expand All @@ -46,7 +48,7 @@ public class VertexInfo {

public VertexInfo(
JobVertexID id,
Set<JobVertexID> inputs,
Map<JobVertexID, String> inputs,
int parallelism,
int maxParallelism,
boolean finished,
Expand All @@ -63,7 +65,7 @@ public VertexInfo(
@VisibleForTesting
public VertexInfo(
JobVertexID id,
Set<JobVertexID> inputs,
Map<JobVertexID, String> inputs,
int parallelism,
int maxParallelism,
IOMetrics ioMetrics) {
Expand All @@ -72,7 +74,7 @@ public VertexInfo(

@VisibleForTesting
public VertexInfo(
JobVertexID id, Set<JobVertexID> inputs, int parallelism, int maxParallelism) {
JobVertexID id, Map<JobVertexID, String> inputs, int parallelism, int maxParallelism) {
this(id, inputs, parallelism, maxParallelism, null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.flink.autoscaler.tuning;

import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.autoscaler.JobAutoScalerContext;
import org.apache.flink.autoscaler.ScalingSummary;
import org.apache.flink.autoscaler.config.AutoScalerOptions;
Expand Down Expand Up @@ -243,28 +244,36 @@ private static MemorySize adjustNetworkMemory(
Configuration config,
MemoryBudget memBudget) {

final long buffersPerChannel =
final int buffersPerChannel =
config.get(NettyShuffleEnvironmentOptions.NETWORK_BUFFERS_PER_CHANNEL);
final long floatingBuffers =
final int floatingBuffers =
config.get(NettyShuffleEnvironmentOptions.NETWORK_EXTRA_BUFFERS_PER_GATE);
final long memorySegmentBytes =
config.get(TaskManagerOptions.MEMORY_SEGMENT_SIZE).getBytes();

long maxNetworkMemory = 0;
for (VertexInfo vertexInfo : jobTopology.getVertexInfos().values()) {
// Add max amount of memory for each input gate
for (JobVertexID input : vertexInfo.getInputs()) {
int inputParallelism = updatedParallelisms.get(input);
for (Map.Entry<JobVertexID, String> inputEntry : vertexInfo.getInputs().entrySet()) {
maxNetworkMemory +=
(inputParallelism * buffersPerChannel + floatingBuffers)
calculateNetworkSegmentNumber(
updatedParallelisms.get(vertexInfo.getId()),
updatedParallelisms.get(inputEntry.getKey()),
inputEntry.getValue(),
buffersPerChannel,
floatingBuffers)
* memorySegmentBytes;
}
// Add max amount of memory for each output gate
// Usually, there is just one output per task
for (JobVertexID output : vertexInfo.getOutputs()) {
int downstreamParallelism = updatedParallelisms.get(output);
for (Map.Entry<JobVertexID, String> outputEntry : vertexInfo.getOutputs().entrySet()) {
maxNetworkMemory +=
(downstreamParallelism * buffersPerChannel + floatingBuffers)
calculateNetworkSegmentNumber(
updatedParallelisms.get(vertexInfo.getId()),
updatedParallelisms.get(outputEntry.getKey()),
outputEntry.getValue(),
buffersPerChannel,
floatingBuffers)
* memorySegmentBytes;
}
}
Expand All @@ -277,6 +286,30 @@ private static MemorySize adjustNetworkMemory(
return new MemorySize(memBudget.budget(maxNetworkMemory));
}

/**
* Calculate how many network segment current vertex needs.
*
* @param currentVertexParallelism The parallelism of current vertex.
* @param otherVertexParallelism The parallelism of other vertex.
*/
@VisibleForTesting
static int calculateNetworkSegmentNumber(
int currentVertexParallelism,
int otherVertexParallelism,
String shipStrategy,
int buffersPerChannel,
int floatingBuffers) {
if (currentVertexParallelism == otherVertexParallelism && "FORWARD".equals(shipStrategy)) {
return buffersPerChannel + floatingBuffers;
} else if ("FORWARD".equals(shipStrategy) || "RESCALE".equals(shipStrategy)) {
final int channelCount =
(int) Math.ceil(1.0d * otherVertexParallelism / currentVertexParallelism);
return channelCount * buffersPerChannel + floatingBuffers;
} else {
return otherVertexParallelism * buffersPerChannel + floatingBuffers;
}
}

private static MemorySize getUsage(
ScalingMetric scalingMetric, Map<ScalingMetric, EvaluatedScalingMetric> globalMetrics) {
MemorySize heapUsed = new MemorySize((long) globalMetrics.get(scalingMetric).getAverage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.Set;
import java.util.Map;
import java.util.SortedMap;

import static org.apache.flink.autoscaler.JobAutoScalerImpl.AUTOSCALER_ERROR;
Expand Down Expand Up @@ -78,9 +78,13 @@ public void setup() {
metricsCollector =
new TestingMetricsCollector<>(
new JobTopology(
new VertexInfo(source1, Set.of(), 1, 720, new IOMetrics(0, 0, 0)),
new VertexInfo(source1, Map.of(), 1, 720, new IOMetrics(0, 0, 0)),
new VertexInfo(
sink, Set.of(source1), 1, 720, new IOMetrics(0, 0, 0))));
sink,
Map.of(source1, "REBALANCE"),
1,
720,
new IOMetrics(0, 0, 0))));

var defaultConf = context.getConfiguration();
defaultConf.set(AutoScalerOptions.AUTOSCALER_ENABLED, true);
Expand Down Expand Up @@ -152,8 +156,8 @@ public void test() throws Exception {
// Update topology to reflect updated parallelisms
metricsCollector.setJobTopology(
new JobTopology(
new VertexInfo(source1, Set.of(), 4, 24),
new VertexInfo(sink, Set.of(source1), 4, 720)));
new VertexInfo(source1, Map.of(), 4, 24),
new VertexInfo(sink, Map.of(source1, "REBALANCE"), 4, 720)));

metricsCollector.updateMetrics(
source1,
Expand Down Expand Up @@ -234,8 +238,8 @@ public void test() throws Exception {
metricsCollector.setJobUpdateTs(now);
metricsCollector.setJobTopology(
new JobTopology(
new VertexInfo(source1, Set.of(), 2, 24),
new VertexInfo(sink, Set.of(source1), 2, 720)));
new VertexInfo(source1, Map.of(), 2, 24),
new VertexInfo(sink, Map.of(source1, "REBALANCE"), 2, 720)));

/* Test stability while processing backlog. */

Expand Down Expand Up @@ -356,8 +360,8 @@ public void shouldTrackRestartDurationCorrectly() throws Exception {
context = context.toBuilder().jobStatus(JobStatus.RUNNING).build();
metricsCollector.setJobTopology(
new JobTopology(
new VertexInfo(source1, Set.of(), 4, 720),
new VertexInfo(sink, Set.of(source1), 4, 720)));
new VertexInfo(source1, Map.of(), 4, 720),
new VertexInfo(sink, Map.of(source1, "REBALANCE"), 4, 720)));

var expectedEndTime = Instant.ofEpochMilli(10);
metricsCollector.setJobUpdateTs(expectedEndTime);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import static java.util.Map.entry;
Expand Down Expand Up @@ -87,7 +86,7 @@ public void setup() {
@Test
void testMetricReporting() throws Exception {
JobVertexID jobVertexID = new JobVertexID();
JobTopology jobTopology = new JobTopology(new VertexInfo(jobVertexID, Set.of(), 1, 10));
JobTopology jobTopology = new JobTopology(new VertexInfo(jobVertexID, Map.of(), 1, 10));

var metricsCollector =
new TestingMetricsCollector<JobID, JobAutoScalerContext<JobID>>(jobTopology);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;

import static org.apache.flink.autoscaler.TestingAutoscalerUtils.createDefaultJobAutoScalerContext;
Expand Down Expand Up @@ -89,11 +88,16 @@ public void setup() {

topology =
new JobTopology(
new VertexInfo(source1, Set.of(), 2, 720, new IOMetrics(0, 0, 0)),
new VertexInfo(source2, Set.of(), 2, 720, new IOMetrics(0, 0, 0)),
new VertexInfo(source1, Map.of(), 2, 720, new IOMetrics(0, 0, 0)),
new VertexInfo(source2, Map.of(), 2, 720, new IOMetrics(0, 0, 0)),
new VertexInfo(
map, Set.of(source1, source2), 12, 720, new IOMetrics(0, 0, 0)),
new VertexInfo(sink, Set.of(map), 8, 24, new IOMetrics(0, 0, 0)));
map,
Map.of(source1, "REBALANCE", source2, "REBALANCE"),
12,
720,
new IOMetrics(0, 0, 0)),
new VertexInfo(
sink, Map.of(map, "REBALANCE"), 8, 24, new IOMetrics(0, 0, 0)));

metricsCollector = new TestingMetricsCollector<>(topology);

Expand Down Expand Up @@ -335,7 +339,7 @@ public void testClearHistoryOnTopoChange() throws Exception {

@Test
public void testTolerateAbsenceOfPendingRecordsMetric() throws Exception {
var topology = new JobTopology(new VertexInfo(source1, Set.of(), 5, 720));
var topology = new JobTopology(new VertexInfo(source1, Map.of(), 5, 720));

metricsCollector = new TestingMetricsCollector(topology);
metricsCollector.setJobUpdateTs(startTime);
Expand Down Expand Up @@ -408,9 +412,9 @@ public void testFinishedVertexMetricsCollection() throws Exception {
var finished = new JobVertexID();
var topology =
new JobTopology(
new VertexInfo(s1, Set.of(), 10, 720),
new VertexInfo(s1, Map.of(), 10, 720),
new VertexInfo(
finished, Set.of(), 10, 720, true, IOMetrics.FINISHED_METRICS));
finished, Map.of(), 10, 720, true, IOMetrics.FINISHED_METRICS));

metricsCollector = new TestingMetricsCollector(topology);
metricsCollector.setJobUpdateTs(startTime);
Expand Down Expand Up @@ -445,7 +449,7 @@ public void testFinishedVertexMetricsCollection() throws Exception {
@Test
public void testObservedTprCollection() throws Exception {
var source = new JobVertexID();
var topology = new JobTopology(new VertexInfo(source, Set.of(), 10, 720));
var topology = new JobTopology(new VertexInfo(source, Map.of(), 10, 720));

metricsCollector = new TestingMetricsCollector(topology);
metricsCollector.setJobUpdateTs(startTime);
Expand Down Expand Up @@ -521,7 +525,7 @@ public void testObservedTprCollection() throws Exception {
@Test
public void testMetricCollectionDuringStabilization() throws Exception {
var source = new JobVertexID();
var topology = new JobTopology(new VertexInfo(source, Set.of(), 10, 720));
var topology = new JobTopology(new VertexInfo(source, Map.of(), 10, 720));

metricsCollector = new TestingMetricsCollector(topology);
metricsCollector.setJobUpdateTs(startTime);
Expand Down Expand Up @@ -596,7 +600,7 @@ protected Map<JobVertexID, Map<String, FlinkMetric>> queryFilteredMetricNames(

@Test
public void testScaleDownWithZeroProcessingRate() throws Exception {
var topology = new JobTopology(new VertexInfo(source1, Set.of(), 2, 720));
var topology = new JobTopology(new VertexInfo(source1, Map.of(), 2, 720));

metricsCollector = new TestingMetricsCollector<>(topology);
metricsCollector.setJobUpdateTs(startTime);
Expand Down
Loading

0 comments on commit 2436c0c

Please sign in to comment.