Skip to content

Commit

Permalink
Bump BWC Version to 2.18 and Fix Bugs (#1311) (#1312)
Browse files Browse the repository at this point in the history
This PR includes the following updates and bug fixes:

* Bump BWC Version to 2.18: Updated BWC version to 2.18 since the 2.17 branch has been cut.
* Fix Confidence Value Exceeding 1 in RCF: Addressed a bug in RCF where the confidence value could exceed 1. Implemented a check to cap the confidence value at 1, preventing invalid confidence scores.
* Correct Parameter Assignment in GetAnomalyDetectorTransportAction: Fixed an issue where parameter assignments within a method did not affect external variables due to Java's pass-by-value nature.
* Fixed a bug in ResultProcessor where we were supposed to check whether the number of sent messages equals the number of received messages before starting imputation. However, the sent message count was mistakenly based on the number of pages rather than the actual number of messages.
* Fixed a bug where we mistakenly used the total reserved memory bytes as the memory size per entity in PriorityCache.

Testing done:
* added test cases for the buggy scenarios
* manual e2e testing

Signed-off-by: Kaituo Li <[email protected]>
  • Loading branch information
kaituo authored Sep 14, 2024
1 parent 1c0f51e commit c518a7c
Show file tree
Hide file tree
Showing 18 changed files with 540 additions and 39 deletions.
4 changes: 1 addition & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -696,9 +696,8 @@ List<String> jacocoExclusions = [

// TODO: add test coverage (kaituo)
'org.opensearch.forecast.*',
'org.opensearch.ad.transport.ADHCImputeNodeResponse',
'org.opensearch.ad.transport.GetAnomalyDetectorTransportAction',
'org.opensearch.ad.ml.ADColdStart',
'org.opensearch.ad.transport.ADHCImputeNodesResponse',
'org.opensearch.timeseries.transport.BooleanNodeResponse',
'org.opensearch.timeseries.ml.TimeSeriesSingleStreamCheckpointDao',
'org.opensearch.timeseries.transport.JobRequest',
Expand All @@ -713,7 +712,6 @@ List<String> jacocoExclusions = [
'org.opensearch.timeseries.transport.ResultBulkTransportAction',
'org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler',
'org.opensearch.timeseries.transport.handler.ResultIndexingHandler',
'org.opensearch.ad.transport.ADHCImputeNodeResponse',
'org.opensearch.timeseries.ml.Sample',
'org.opensearch.timeseries.ratelimit.FeatureRequest',
'org.opensearch.ad.transport.ADHCImputeNodeRequest',
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/ad/model/AnomalyResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ public static AnomalyResult fromRawTRCFResult(
taskId,
rcfScore,
Math.max(0, grade),
confidence,
Math.min(1, confidence),
featureData,
dataStartTime,
dataEndTime,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ public GetAnomalyDetectorTransportAction(
}

@Override
protected void fillInHistoricalTaskforBwc(Map<String, ADTask> tasks, Optional<ADTask> historicalAdTask) {
protected Optional<ADTask> fillInHistoricalTaskforBwc(Map<String, ADTask> tasks) {
if (tasks.containsKey(ADTaskType.HISTORICAL.name())) {
historicalAdTask = Optional.ofNullable(tasks.get(ADTaskType.HISTORICAL.name()));
return Optional.ofNullable(tasks.get(ADTaskType.HISTORICAL.name()));
}
return Optional.empty();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public static List<ForecastResult> fromRawRCFCasterResult(
new ForecastResult(
forecasterId,
taskId,
dataQuality,
Math.min(1, dataQuality),
featureData,
dataStartTime,
dataEndTime,
Expand Down Expand Up @@ -218,7 +218,7 @@ public static List<ForecastResult> fromRawRCFCasterResult(
new ForecastResult(
forecasterId,
taskId,
dataQuality,
Math.min(1, dataQuality),
null,
dataStartTime,
dataEndTime,
Expand Down
17 changes: 4 additions & 13 deletions src/main/java/org/opensearch/timeseries/caching/PriorityCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -463,15 +463,16 @@ public Pair<List<Entity>, List<Entity>> selectUpdateCandidate(Collection<Entity>
return Pair.of(hotEntities, coldEntities);
}

private CacheBufferType computeBufferIfAbsent(Config config, String configId) {
public CacheBufferType computeBufferIfAbsent(Config config, String configId) {
CacheBufferType buffer = activeEnities.get(configId);
if (buffer == null) {
long requiredBytes = getRequiredMemory(config, config.isHighCardinality() ? hcDedicatedCacheSize : 1);
long bytesPerEntityModel = getRequiredMemoryPerEntity(config, memoryTracker, numberOfTrees);
long requiredBytes = bytesPerEntityModel * (config.isHighCardinality() ? hcDedicatedCacheSize : 1);
if (memoryTracker.canAllocateReserved(requiredBytes)) {
memoryTracker.consumeMemory(requiredBytes, true, origin);
buffer = createEmptyCacheBuffer(
config,
requiredBytes,
bytesPerEntityModel,
priorityTrackerMap
.getOrDefault(
configId,
Expand All @@ -496,16 +497,6 @@ private CacheBufferType computeBufferIfAbsent(Config config, String configId) {
return buffer;
}

/**
*
* @param config Detector config accessor
* @param numberOfEntity number of entities
* @return Memory in bytes required for hosting numberOfEntity entities
*/
private long getRequiredMemory(Config config, int numberOfEntity) {
return numberOfEntity * getRequiredMemoryPerEntity(config, memoryTracker, numberOfTrees);
}

/**
* Whether the candidate entity can replace any entity in the shared cache.
* We can have race conditions when multiple threads try to evaluate this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ public void doExecute(Task task, ActionRequest request, ActionListener<GetConfig
}
}

protected void getConfigAndJob(
public void getConfigAndJob(
String configID,
boolean returnJob,
boolean returnTask,
Expand Down Expand Up @@ -251,7 +251,7 @@ public void getExecute(GetConfigRequest request, ActionListener<GetConfigRespons
} else {
// AD needs to provides custom behavior for bwc, while forecasting can inherit
// the empty implementation
fillInHistoricalTaskforBwc(tasks, historicalTask);
historicalTask = fillInHistoricalTaskforBwc(tasks);
}
}
getConfigAndJob(configID, returnJob, returnTask, realtimeTask, historicalTask, listener);
Expand Down Expand Up @@ -357,7 +357,9 @@ public void onFailure(Exception e) {
};
}

protected void fillInHistoricalTaskforBwc(Map<String, TaskClass> tasks, Optional<TaskClass> historicalAdTask) {}
protected Optional<TaskClass> fillInHistoricalTaskforBwc(Map<String, TaskClass> tasks) {
return Optional.empty();
}

protected void getExecuteProfile(
GetConfigRequest request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,6 @@ public void onResponse(CompositeRetriever.Page entityFeatures) {
pageIterator.next(this);
}
if (entityFeatures != null && false == entityFeatures.isEmpty()) {
sentOutPages.incrementAndGet();

LOG
.info(
"Sending an HC request to process data from timestamp {} to {} for config {}",
Expand Down Expand Up @@ -285,6 +283,7 @@ public void onResponse(CompositeRetriever.Page entityFeatures) {
final AtomicReference<Exception> failure = new AtomicReference<>();

node2Entities.stream().forEach(nodeEntity -> {
sentOutPages.incrementAndGet();
DiscoveryNode node = nodeEntity.getKey();
transportService
.sendRequest(
Expand Down Expand Up @@ -370,7 +369,15 @@ public void run() {
cancellable.get().cancel();
}
} else if (Instant.now().toEpochMilli() >= timeoutMillis) {
LOG.warn("Scheduled impute HC task is cancelled due to timeout");
LOG
.warn(
"Scheduled impute HC task is cancelled due to timeout, current epoch {}, timeout epoch {}, dataEndTime {}, sent out {}, receive {}",
Instant.now().toEpochMilli(),
timeoutMillis,
dataEndTime,
sentOutPages.get(),
receivedPages.get()
);
if (cancellable != null) {
cancellable.get().cancel();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,29 @@ protected List<Object> waitUntilTaskReachState(String detectorId, Set<String> ta
return results;
}

protected List<Object> waitUntilTaskReachNumberOfEntities(String detectorId, int categoricalValuesCount) throws InterruptedException {
List<Object> results = new ArrayList<>();
int i = 0;
ADTaskProfile adTaskProfile = null;
// Increase retryTimes if some task can't reach done state
while ((adTaskProfile == null
|| adTaskProfile.getTotalEntitiesCount() == null
|| adTaskProfile.getTotalEntitiesCount().intValue() != categoricalValuesCount) && i < MAX_RETRY_TIMES) {
try {
adTaskProfile = getADTaskProfile(detectorId);
} catch (Exception e) {
logger.error("failed to get ADTaskProfile", e);
} finally {
Thread.sleep(1000);
}
i++;
}
assertNotNull(adTaskProfile);
results.add(adTaskProfile);
results.add(i);
return results;
}

protected List<Object> waitUntilEntityCountAvailable(String detectorId) throws InterruptedException {
List<Object> results = new ArrayList<>();
int i = 0;
Expand Down
46 changes: 46 additions & 0 deletions src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayDeque;
Expand Down Expand Up @@ -62,6 +63,7 @@
import org.opensearch.threadpool.Scheduler.ScheduledCancellable;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.MemoryTracker;
import org.opensearch.timeseries.TestHelpers;
import org.opensearch.timeseries.breaker.CircuitBreakerService;
import org.opensearch.timeseries.common.exception.LimitExceededException;
import org.opensearch.timeseries.common.exception.TimeSeriesException;
Expand Down Expand Up @@ -788,4 +790,48 @@ public void testGetTotalUpdates_orElseGetBranchWithNullSamples() {
// Assert that the result is 0L
assertEquals(0L, result);
}

public void testAllocation() throws IOException {
JvmService jvmService = mock(JvmService.class);
JvmInfo info = mock(JvmInfo.class);

when(jvmService.info()).thenReturn(info);

Mem mem = mock(Mem.class);
when(mem.getHeapMax()).thenReturn(new ByteSizeValue(800_000_000L));
when(info.getMem()).thenReturn(mem);

CircuitBreakerService circuitBreaker = mock(CircuitBreakerService.class);
when(circuitBreaker.isOpen()).thenReturn(false);
MemoryTracker tracker = new MemoryTracker(jvmService, 0.1, clusterService, circuitBreaker);

dedicatedCacheSize = 10;
ADPriorityCache cache = new ADPriorityCache(
checkpoint,
dedicatedCacheSize,
AnomalyDetectorSettings.AD_CHECKPOINT_TTL,
AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES,
tracker,
TimeSeriesSettings.NUM_TREES,
clock,
clusterService,
TimeSeriesSettings.HOURLY_MAINTENANCE,
threadPool,
TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT,
Settings.EMPTY,
AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ,
checkpointWriteQueue,
checkpointMaintainQueue
);

List<String> categoryFields = Arrays.asList("category_field_1", "category_field_2");
AnomalyDetector anomalyDetector = TestHelpers.AnomalyDetectorBuilder
.newInstance(5)
.setShingleSize(8)
.setCategoryFields(categoryFields)
.build();
ADCacheBuffer buffer = cache.computeBufferIfAbsent(anomalyDetector, anomalyDetector.getId());
assertEquals(698336, buffer.getMemoryConsumptionPerModel());
assertEquals(698336 * dedicatedCacheSize, tracker.getTotalMemoryBytes());
}
}
35 changes: 28 additions & 7 deletions src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ protected TrainResult ingestTrainDataAndCreateDetector(
int trainTestSplit,
boolean useDateNanos
) throws Exception {
return ingestTrainDataAndCreateDetector(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, useDateNanos, -1);
return ingestTrainDataAndCreateDetector(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, useDateNanos, -1, true);
}

protected TrainResult ingestTrainDataAndCreateDetector(
Expand All @@ -56,7 +56,8 @@ protected TrainResult ingestTrainDataAndCreateDetector(
int numberOfEntities,
int trainTestSplit,
boolean useDateNanos,
int ingestDataSize
int ingestDataSize,
boolean relative
) throws Exception {
TrainResult trainResult = ingestTrainData(
datasetName,
Expand All @@ -67,15 +68,30 @@ protected TrainResult ingestTrainDataAndCreateDetector(
ingestDataSize
);

String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult);
String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult, relative);
String detectorId = createDetector(client(), detector);
LOG.info("Created detector {}", detectorId);
trainResult.detectorId = detectorId;

return trainResult;
}

protected String genDetector(String datasetName, int intervalMinutes, int trainTestSplit, TrainResult trainResult) {
protected String genDetector(String datasetName, int intervalMinutes, int trainTestSplit, TrainResult trainResult, boolean relative) {
// Determine threshold types and values based on the 'relative' parameter
String thresholdType1;
String thresholdType2;
double value;
if (relative) {
thresholdType1 = "actual_over_expected_ratio";
thresholdType2 = "expected_over_actual_ratio";
value = 0.3;
} else {
thresholdType1 = "actual_over_expected_margin";
thresholdType2 = "expected_over_actual_margin";
value = 3000.0;
}

// Generate the detector JSON string with the appropriate threshold types and values
String detector = String
.format(
Locale.ROOT,
Expand All @@ -87,15 +103,20 @@ protected String genDetector(String datasetName, int intervalMinutes, int trainT
+ "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}},"
+ "\"history\": %d,"
+ "\"schema_version\": 0,"
+ "\"rules\": [{\"action\": \"ignore_anomaly\", \"conditions\": [{\"feature_name\": \"feature 1\", \"threshold_type\": \"actual_over_expected_ratio\", \"operator\": \"lte\", \"value\": 0.3}, "
+ "{\"feature_name\": \"feature 1\", \"threshold_type\": \"expected_over_actual_ratio\", \"operator\": \"lte\", \"value\": 0.3}"
+ "\"rules\": [{\"action\": \"ignore_anomaly\", \"conditions\": ["
+ "{ \"feature_name\": \"feature 1\", \"threshold_type\": \"%s\", \"operator\": \"lte\", \"value\": %f }, "
+ "{ \"feature_name\": \"feature 1\", \"threshold_type\": \"%s\", \"operator\": \"lte\", \"value\": %f }"
+ "]}]"
+ "}",
datasetName,
intervalMinutes,
categoricalField,
trainResult.windowDelay.toMinutes(),
trainTestSplit - 1
trainTestSplit - 1,
thresholdType1,
value,
thresholdType2,
value
);
return detector;
}
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public void testRule() throws Exception {
(trainTestSplit + 1) * numberOfEntities
);

String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult);
String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult, true);
Map<String, Object> result = preview(detector, trainResult.firstDataTime, trainResult.finalDataTime, client());
List<Object> results = (List<Object>) XContentMapValues.extractValue(result, "anomaly_result");
assertTrue(results.size() > 100);
Expand Down
13 changes: 11 additions & 2 deletions src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import com.google.gson.JsonObject;

public class RealTimeRuleIT extends AbstractRuleTestCase {
public void testRuleWithDateNanos() throws Exception {
private void template(boolean reltive) throws Exception {
// TODO: this test case will run for a much longer time and timeout with security enabled
if (!isHttps()) {
disableResourceNotFoundFaultTolerence();
Expand All @@ -32,7 +32,8 @@ public void testRuleWithDateNanos() throws Exception {
trainTestSplit,
true,
// ingest just enough for finish the test
(trainTestSplit + 1) * numberOfEntities
(trainTestSplit + 1) * numberOfEntities,
reltive
);

startRealTimeDetector(trainResult, numberOfEntities, intervalMinutes, false);
Expand Down Expand Up @@ -90,4 +91,12 @@ public void testRuleWithDateNanos() throws Exception {
}
}
}

public void testRelativeRule() throws Exception {
template(true);
}

public void testAbsoluateRule() throws Exception {
template(false);
}
}
21 changes: 21 additions & 0 deletions src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -1275,4 +1275,25 @@ public void testNotEnoughTrainingData() throws IOException, InterruptedException
checkSemaphoreRelease();
assertTrue(modelState.getModel().isEmpty());
}

public void testTrainModelFromInvalidSamplesNotEnoughSamples() {
Deque<Sample> samples = new ArrayDeque<>();
// we have at least numMinSamples samples before executing the null check of trainModelFromDataSegments
for (int i = 0; i < numMinSamples; i++) {
samples.add(new Sample());
}

modelState = new ModelState<ThresholdedRandomCutForest>(
null,
modelId,
detectorId,
ModelManager.ModelType.TRCF.getName(),
clock,
priority,
Optional.of(entity),
samples
);
entityColdStarter.trainModelFromExistingSamples(modelState, detector, "123");
assertTrue(modelState.getModel().isEmpty());
}
}
Loading

0 comments on commit c518a7c

Please sign in to comment.