Skip to content

Commit

Permalink
Implement listener-based memory tracking
Browse files Browse the repository at this point in the history
Co-authored-by: Ioannis Panagiotas <[email protected]>
  • Loading branch information
vnickolov and IoannisPanagiotas committed Nov 18, 2024
1 parent ffe58b9 commit cec8a0b
Show file tree
Hide file tree
Showing 33 changed files with 454 additions and 284 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ public synchronized <CONFIGURATION extends AlgoBaseConfig> void assertAlgorithmC
) throws IllegalStateException {

try {

var memoryRequirement = MemoryRequirement.create(
estimationFactory,
graphStore,
Expand All @@ -79,31 +78,26 @@ public synchronized <CONFIGURATION extends AlgoBaseConfig> void assertAlgorithmC
useMaxMemoryEstimation
);

// TODO: I don't really like this, think how we can refactor it better.
try {
var bytesToReserve = memoryRequirement.requiredMemory();
if (configuration.sudo()) {
memoryTracker.track(configuration.jobId(), bytesToReserve);
return;
}
var availableMemory = memoryTracker.availableMemory();
if (bytesToReserve > availableMemory) {
throw new MemoryReservationExceededException(bytesToReserve, availableMemory);
}
var bytesToReserve = memoryRequirement.requiredMemory();
if (configuration.sudo()) {
memoryTracker.track(configuration.jobId(), bytesToReserve);
} catch (MemoryReservationExceededException e) {
var message = StringFormatting.formatWithLocale(
"Memory required to run %s (%db) exceeds available memory (%db)",
label,
e.bytesRequired(),
e.bytesAvailable()
);
return;
}

throw new IllegalStateException(message);
memoryTracker.tryToTrack(configuration.jobId(), bytesToReserve);

}
} catch (MemoryEstimationNotImplementedException e) {
log.info("Memory usage estimate not available for " + label + ", skipping guard");
} catch (MemoryReservationExceededException e) {
var message = StringFormatting.formatWithLocale(
"Memory required to run %s (%db) exceeds available memory (%db)",
label,
e.bytesRequired(),
e.bytesAvailable()
);

throw new IllegalStateException(message);

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.neo4j.gds.applications.services.GraphDimensionFactory;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryTracker;
Expand Down Expand Up @@ -52,7 +53,7 @@ void shouldAllowExecution() {
null,
graphDimensionFactory,
false,
new MemoryTracker(42)
new MemoryTracker(42, Log.noOpLog())
);

var graphStore = mock(GraphStore.class);
Expand Down Expand Up @@ -81,7 +82,7 @@ void shouldGuardExecutionUsingMinimumEstimate() {
null,
graphDimensionFactory,
false,
new MemoryTracker(42)
new MemoryTracker(42, Log.noOpLog())
);

var graphStore = mock(GraphStore.class);
Expand Down Expand Up @@ -116,7 +117,7 @@ void shouldGuardExecutionUsingMaximumEstimate() {
null,
graphDimensionFactory,
true,
new MemoryTracker(42)
new MemoryTracker(42, Log.noOpLog())
);

var graphStore = mock(GraphStore.class);
Expand Down Expand Up @@ -151,7 +152,7 @@ void shouldRespectSudoFlag() {
null,
graphDimensionFactory,
false,
new MemoryTracker(42)
new MemoryTracker(42, Log.noOpLog())
);

var graphStore = mock(GraphStore.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,31 @@
package org.neo4j.gds.applications.graphstorecatalog;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.neo4j.gds.BaseTest;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.compat.GraphDatabaseApiProxy;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.core.ImmutableGraphDimensions;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.utils.progress.JobId;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryTracker;
import org.neo4j.gds.mem.MemoryTree;
import org.neo4j.gds.mem.MemoryTreeWithDimensions;

import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

Expand All @@ -43,11 +55,13 @@ void shouldPassOnSufficientMemory() {
var dimensions = GraphDimensions.builder().nodeCount(1000).build();
var memoryTree = MemoryTree.empty();

assertThatNoException().isThrownBy(() -> new MemoryUsageValidator(Log.noOpLog(), GraphDatabaseApiProxy.dependencyResolver(db))
assertThatNoException().isThrownBy(() -> new MemoryUsageValidator(new MemoryTracker(10000000, Log.noOpLog()),
false,
Log.noOpLog()
)
.tryValidateMemoryUsage(
TestConfig.empty(),
(config) -> new MemoryTreeWithDimensions(memoryTree, dimensions),
() -> 10000000
(config) -> new MemoryTreeWithDimensions(memoryTree, dimensions)
));
}

Expand All @@ -56,11 +70,13 @@ void shouldFailOnInsufficientMemory() {
var dimensions = GraphDimensions.builder().nodeCount(1000).build();
var memoryTree = new TestTree("test", MemoryRange.of(42));

assertThatThrownBy(() -> new MemoryUsageValidator(Log.noOpLog(), GraphDatabaseApiProxy.dependencyResolver(db))
assertThatThrownBy(() -> new MemoryUsageValidator(new MemoryTracker(21, Log.noOpLog()),
false,
Log.noOpLog()
)
.tryValidateMemoryUsage(
TestConfig.empty(),
(config) -> new MemoryTreeWithDimensions(memoryTree, dimensions),
() -> 21
(config) -> new MemoryTreeWithDimensions(memoryTree, dimensions)
))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("Procedure was blocked since minimum estimated memory (42 Bytes) exceeds current free memory (21 Bytes).");
Expand All @@ -71,11 +87,13 @@ void shouldNotFailOnInsufficientMemoryIfInSudoMode() {
var dimensions = GraphDimensions.builder().nodeCount(1000).build();
var memoryTree = new TestTree("test", MemoryRange.of(42));

assertThatNoException().isThrownBy(() -> new MemoryUsageValidator(Log.noOpLog(), GraphDatabaseApiProxy.dependencyResolver(db))
assertThatNoException().isThrownBy(() -> new MemoryUsageValidator(new MemoryTracker(21, Log.noOpLog()),
false,
Log.noOpLog()
)
.tryValidateMemoryUsage(
TestConfig.of(CypherMapWrapper.empty().withBoolean("sudo", true)),
(config) -> new MemoryTreeWithDimensions(memoryTree, dimensions),
() -> 21
(config) -> new MemoryTreeWithDimensions(memoryTree, dimensions)
));
}

Expand All @@ -84,20 +102,106 @@ void shouldLogWhenFailing() {
var log = mock(Log.class);
var dimensions = GraphDimensions.builder().nodeCount(1000).build();
var memoryTree = new TestTree("test", MemoryRange.of(42));
var memoryUsageValidator = new MemoryUsageValidator(log, GraphDatabaseApiProxy.dependencyResolver(db));
try {
memoryUsageValidator.tryValidateMemoryUsage(
var memoryUsageValidator = new MemoryUsageValidator(
new MemoryTracker(21, log), false, log
);

assertThatIllegalStateException().isThrownBy(
() -> memoryUsageValidator.tryValidateMemoryUsage(
TestConfig.of(CypherMapWrapper.empty()),
(config -> new MemoryTreeWithDimensions(memoryTree, dimensions)),
() -> 21
);
} catch (IllegalStateException ex) {
// do nothing
}
(config -> new MemoryTreeWithDimensions(memoryTree, dimensions))
)
);

verify(log).info("Procedure was blocked since minimum estimated memory (42 Bytes) exceeds current free memory (21 Bytes).");
}

private static final GraphDimensions TEST_DIMENSIONS = ImmutableGraphDimensions
.builder()
.nodeCount(100)
.relCountUpperBound(1000)
.build();

static Stream<Arguments> input() {
var fixedMemory = MemoryEstimations.builder().fixed("foobar", 1337);
var memoryRange = MemoryEstimations
.builder()
.rangePerGraphDimension("foobar", (dimensions, concurrency) -> MemoryRange.of(42, 1337));
return Stream.of(
Arguments.of(fixedMemory.build(), false),
Arguments.of(fixedMemory.build(), true),
Arguments.of(memoryRange.build(), false),
Arguments.of(memoryRange.build(), true)
);
}

@ParameterizedTest
@MethodSource("input")
void doesNotThrow(MemoryEstimation estimation, boolean useMaxMemoryUsage) {
var memoryTrackerMock = mock(MemoryTracker.class);
var memoryUsageValidator = new MemoryUsageValidator(
memoryTrackerMock,
false,
Log.noOpLog()
);
var memoryTree = estimation.estimate(TEST_DIMENSIONS, new Concurrency(1));
var memoryTreeWithDimensions = new MemoryTreeWithDimensions(memoryTree, TEST_DIMENSIONS);

assertDoesNotThrow(() -> memoryUsageValidator.validateMemoryUsage(
memoryTreeWithDimensions.memoryTree.memoryUsage(), 10_000,
useMaxMemoryUsage,
new JobId("foo"), Log.noOpLog()
));
}

@ParameterizedTest
@MethodSource("input")
void throwsOnMinUsageExceeded(MemoryEstimation estimation, boolean ignored) {
var memoryUsageValidator = new MemoryUsageValidator(
null,
false,
Log.noOpLog()
);

var memoryTree = estimation.estimate(TEST_DIMENSIONS, new Concurrency(1));
var memoryTreeWithDimensions = new MemoryTreeWithDimensions(memoryTree, TEST_DIMENSIONS);

assertThatThrownBy(() -> memoryUsageValidator.validateMemoryUsage(
memoryTreeWithDimensions.memoryTree.memoryUsage(), 1,
false,
new JobId("foo"), Log.noOpLog()
))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("Procedure was blocked since minimum estimated memory");
}

@ParameterizedTest
@MethodSource("input")
void throwsOnMaxUsageExceeded(MemoryEstimation estimation, boolean ignored) {
var memoryUsageValidator = new MemoryUsageValidator(
null,
false,
Log.noOpLog()
);

var memoryTree = estimation.estimate(TEST_DIMENSIONS, new Concurrency(1));
var memoryTreeWithDimensions = new MemoryTreeWithDimensions(memoryTree, TEST_DIMENSIONS);

assertThatThrownBy(() -> memoryUsageValidator.validateMemoryUsage(
memoryTreeWithDimensions.memoryTree.memoryUsage(), 1,
true,
new JobId("foo"), Log.noOpLog()
))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("Procedure was blocked since maximum estimated memory")
.hasMessageContaining(
"Consider resizing your Aura instance via console.neo4j.io. " +
"Alternatively, use 'sudo: true' to override the memory validation. " +
"Overriding the validation is at your own risk. " +
"The database can run out of memory and data can be lost."
);
}

@Configuration
interface TestConfig extends AlgoBaseConfig {
static TestConfig empty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,33 @@
*/
package org.neo4j.gds.applications.graphstorecatalog;

import org.neo4j.configuration.Config;
import org.neo4j.gds.api.DatabaseId;
import org.neo4j.gds.api.GraphLoaderContext;
import org.neo4j.gds.api.ImmutableGraphLoaderContext;
import org.neo4j.gds.compat.GraphDatabaseApiProxy;
import org.neo4j.gds.config.GraphProjectConfig;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.mem.MemoryTreeWithDimensions;
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
import org.neo4j.gds.core.utils.warnings.UserLogRegistryFactory;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.mem.MemoryTracker;
import org.neo4j.gds.mem.MemoryTreeWithDimensions;
import org.neo4j.gds.settings.GdsSettings;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.transaction.TransactionContext;
import org.neo4j.graphdb.GraphDatabaseService;

public class GraphProjectMemoryUsageService {
private final Log log;
private final GraphDatabaseService graphDatabaseService;
private final MemoryTracker memoryTracker;

public GraphProjectMemoryUsageService(Log log, GraphDatabaseService graphDatabaseService) {
public GraphProjectMemoryUsageService(Log log, GraphDatabaseService graphDatabaseService,
MemoryTracker memoryTracker
) {
this.log = log;
this.graphDatabaseService = graphDatabaseService;
this.memoryTracker = memoryTracker;
}

public void validateMemoryUsage(
Expand Down Expand Up @@ -94,10 +101,11 @@ public MemoryTreeWithDimensions getFictitiousEstimate(GraphProjectConfig configu
}

private MemoryUsageValidator memoryUsageValidator() {
return new MemoryUsageValidator(
log,
GraphDatabaseApiProxy.dependencyResolver(graphDatabaseService)
);
var neo4jConfig = GraphDatabaseApiProxy.dependencyResolver(graphDatabaseService)
.resolveDependency(Config.class);
var useMaxMemoryEstimation = neo4jConfig.get(GdsSettings.validateUsingMaxMemoryEstimation());

return new MemoryUsageValidator(memoryTracker, useMaxMemoryEstimation, log);
}

private GraphLoaderContext graphLoaderContext(
Expand Down
Loading

0 comments on commit cec8a0b

Please sign in to comment.