From 74846a0a36695ea2cd834878cc0225a284f91810 Mon Sep 17 00:00:00 2001
From: Y Ethan Guo
Date: Thu, 27 Feb 2025 16:15:56 -0800
Subject: [PATCH] [WIP] Support data skipping based on column stats in Hudi connector

---
 .../trino/plugin/hudi/HudiSplitManager.java   |  60 +++-
 .../hudi/query/HudiFileSkippingManager.java   | 319 ++++++++++++++++++
 2 files changed, 378 insertions(+), 1 deletion(-)
 create mode 100644 plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/query/HudiFileSkippingManager.java

diff --git a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiSplitManager.java b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiSplitManager.java
index 3ef66c66f5cd..8dadd85b0db5 100644
--- a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiSplitManager.java
+++ b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiSplitManager.java
@@ -13,14 +13,19 @@
  */
 package io.trino.plugin.hudi;
 
+import com.google.common.base.Verify;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.inject.Inject;
 import io.trino.filesystem.TrinoFileSystemFactory;
 import io.trino.metastore.HiveMetastore;
+import io.trino.metastore.HivePartition;
 import io.trino.metastore.Table;
 import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorSplitSource;
 import io.trino.plugin.hive.HiveColumnHandle;
 import io.trino.plugin.hive.HiveTransactionHandle;
+import io.trino.plugin.hudi.query.HudiFileSkippingManager;
+import io.trino.plugin.hudi.storage.TrinoStorageConfiguration;
 import io.trino.spi.connector.ConnectorSession;
 import io.trino.spi.connector.ConnectorSplitManager;
 import io.trino.spi.connector.ConnectorSplitSource;
@@ -28,12 +33,17 @@
 import io.trino.spi.connector.ConnectorTransactionHandle;
 import io.trino.spi.connector.Constraint;
 import io.trino.spi.connector.DynamicFilter;
+import io.trino.spi.connector.FixedSplitSource;
 import io.trino.spi.connector.TableNotFoundException;
 import io.trino.spi.security.ConnectorIdentity;
 import io.trino.spi.type.TypeManager;
+import org.apache.hudi.common.engine.HoodieLocalEngineContext;
+import org.apache.hudi.common.model.HoodieTableQueryType;
+import org.apache.hudi.common.table.HoodieTableMetaClient;
 
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.function.BiFunction;
@@ -44,6 +54,7 @@
 import static io.trino.plugin.hive.util.HiveUtil.getPartitionKeyColumnHandles;
 import static io.trino.plugin.hudi.HudiSessionProperties.getMaxOutstandingSplits;
 import static io.trino.plugin.hudi.HudiSessionProperties.getMaxSplitsPerSecond;
+import static io.trino.plugin.hudi.HudiSessionProperties.isHudiMetadataTableEnabled;
 import static io.trino.plugin.hudi.partition.HiveHudiPartitionInfo.NON_PARTITION;
 import static io.trino.spi.connector.SchemaTableName.schemaTableName;
 import static java.util.Objects.requireNonNull;
@@ -85,11 +96,40 @@ public ConnectorSplitSource getSplits(
         HiveMetastore metastore = metastoreProvider.apply(session.getIdentity(), (HiveTransactionHandle) transaction);
         Table table = metastore.getTable(hudiTableHandle.getSchemaName(), hudiTableHandle.getTableName())
                 .orElseThrow(() -> new TableNotFoundException(schemaTableName(hudiTableHandle.getSchemaName(), hudiTableHandle.getTableName())));
-
         List<HiveColumnHandle> partitionColumns = getPartitionKeyColumnHandles(table, typeManager);
         Map<String, HiveColumnHandle> partitionColumnHandles = partitionColumns.stream()
                 .collect(toImmutableMap(HiveColumnHandle::getName, identity()));
         List<String> partitions = getPartitions(metastore, hudiTableHandle, partitionColumns);
+        boolean enableMetadataTable = isHudiMetadataTableEnabled(session);
+
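+        // With the Hudi metadata table enabled, file listing and data skipping are served from
+        // the metadata table, and the matching splits are materialized eagerly below.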
+        if (enableMetadataTable) {
+            Optional<Table> hiveTableOpt = metastore.getTable(table.getDatabaseName(), table.getTableName());
+            Verify.verify(hiveTableOpt.isPresent());
+            HoodieTableMetaClient metaClient = HoodieTableMetaClient
+                    .builder()
+                    .setBasePath(hudiTableHandle.getBasePath())
+                    .build();
+            HoodieLocalEngineContext engineContext = new HoodieLocalEngineContext(new TrinoStorageConfiguration());
+            HudiFileSkippingManager hudiFileSkippingManager = new HudiFileSkippingManager(
+                    partitions,
+                    // TODO(yihua): make this configurable
+                    "/tmp",
+                    engineContext,
+                    metaClient,
+                    HoodieTableQueryType.SNAPSHOT,
+                    Optional.empty());
+            ImmutableList.Builder<HudiSplit> splitsBuilder = ImmutableList.builder();
+            Map<String, HivePartition> hudiPartitionMap = getHudiPartitions(hiveTableOpt.get(), hudiTableHandle, partitions);
+            hudiFileSkippingManager.listQueryFiles(hudiTableHandle.getTupleDomain())
+                    .entrySet()
+                    .stream()
+                    .flatMap(entry -> entry.getValue().stream().map(fileSlice -> createHudiSplit(table, fileSlice, timestamp, hudiPartitionMap.get(entry.getKey()), splitWeightProvider)))
+                    .filter(Optional::isPresent)
+                    .map(Optional::get)
+                    .forEach(splitsBuilder::add);
+            List<HudiSplit> splitsList = splitsBuilder.build();
+            return splitsList.isEmpty() ? new FixedSplitSource(ImmutableList.of()) : new FixedSplitSource(splitsList);
+        }
 
         HudiSplitSource splitSource = new HudiSplitSource(
                 session,
@@ -106,6 +146,24 @@ public ConnectorSplitSource getSplits(
         return new ClassLoaderSafeConnectorSplitSource(splitSource, HudiSplitManager.class.getClassLoader());
     }
 
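+    /**
+     * Maps each Hudi partition name to a {@code HivePartition} carrying the extracted partition
+     * values; a single empty partition name denotes a non-partitioned table.
+     */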
+    private Map<String, HivePartition> getHudiPartitions(Table table, HudiTableHandle tableHandle, List<String> partitions)
+    {
+        List<String> partitionColumnNames = table.getPartitionColumns().stream().map(f -> f.getName()).collect(Collectors.toList());
+
+        Map<String, List<String>> partitionMap = HudiPartitionManager
+                .getPartitions(partitionColumnNames, partitions);
+        if (partitions.size() == 1 && partitions.get(0).isEmpty()) {
+            // non-partitioned
+            return ImmutableMap.of(partitions.get(0), new HivePartition(partitions.get(0), ImmutableList.of(), tableHandle.getConstraintColumns()));
+        }
+        ImmutableMap.Builder<String, HivePartition> builder = ImmutableMap.builder();
+        partitionMap.entrySet().stream().map(entry -> {
+            List<String> partitionValues = HudiPartitionManager.extractPartitionValues(entry.getKey(), Optional.of(partitionColumnNames));
+            return new HivePartition(entry.getKey(), partitionValues, entry.getValue(), table.getStorage(), fromDataColumns(table.getDataColumns()));
+        }).forEach(p -> builder.put(p.getName(), p));
+        return builder.build();
+    }
+
     private static List<String> getPartitions(HiveMetastore metastore, HudiTableHandle table, List<HiveColumnHandle> partitionColumns)
     {
         if (partitionColumns.isEmpty()) {
diff --git a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/query/HudiFileSkippingManager.java b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/query/HudiFileSkippingManager.java
new file mode 100644
index 000000000000..44df7c1fc677
--- /dev/null
+++ b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/query/HudiFileSkippingManager.java
@@ -0,0 +1,319 @@
+package io.trino.plugin.hudi.query;
+
+import com.google.common.collect.ImmutableList;
+import io.airlift.log.Logger;
+import io.airlift.slice.Slice;
+import io.airlift.slice.Slices;
+import io.trino.parquet.predicate.TupleDomainParquetPredicate;
+import io.trino.plugin.hive.HiveColumnHandle;
+import io.trino.plugin.hudi.HudiPredicates;
+import io.trino.spi.connector.ColumnHandle;
+import io.trino.spi.predicate.Domain;
+import io.trino.spi.predicate.Range;
+import io.trino.spi.predicate.TupleDomain;
+import io.trino.spi.predicate.ValueSet;
+import io.trino.spi.type.Type;
+import io.trino.spi.type.VarcharType;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.hudi.avro.model.HoodieMetadataColumnStats;
+import org.apache.hudi.common.config.HoodieCommonConfig;
+import org.apache.hudi.common.config.HoodieMetadataConfig;
+import org.apache.hudi.common.engine.HoodieEngineContext;
+import org.apache.hudi.common.model.BaseFile;
+import org.apache.hudi.common.model.FileSlice;
+import org.apache.hudi.common.model.HoodieTableQueryType;
+import org.apache.hudi.common.table.HoodieTableMetaClient;
+import org.apache.hudi.common.table.timeline.HoodieInstant;
+import org.apache.hudi.common.table.timeline.HoodieTimeline;
+import org.apache.hudi.common.table.view.FileSystemViewManager;
+import org.apache.hudi.common.table.view.FileSystemViewStorageConfig;
+import org.apache.hudi.common.table.view.SyncableFileSystemView;
+import org.apache.hudi.common.util.collection.Pair;
+import org.apache.hudi.common.util.hash.ColumnIndexID;
+import org.apache.hudi.metadata.HoodieTableMetadata;
+import org.apache.hudi.metadata.HoodieTableMetadataUtil;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+import static io.trino.parquet.predicate.PredicateUtils.isStatisticsOverflow;
+import static io.trino.spi.type.BigintType.BIGINT;
+import static io.trino.spi.type.BooleanType.BOOLEAN;
+import static io.trino.spi.type.DateType.DATE;
+import static io.trino.spi.type.DoubleType.DOUBLE;
+import static io.trino.spi.type.IntegerType.INTEGER;
+import static io.trino.spi.type.RealType.REAL;
+import static io.trino.spi.type.SmallintType.SMALLINT;
+import static io.trino.spi.type.TinyintType.TINYINT;
+import static java.lang.Float.floatToRawIntBits;
+import static java.util.Objects.requireNonNull;
+
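+/**
+ * Prunes file slices using column statistics stored in the column_stats partition of the
+ * Hudi metadata table: per-file min/max/null-count statistics of the filtered columns are
+ * converted to Trino {@link Domain}s and intersected with the query predicate, and files
+ * whose intersection is empty are skipped.
+ */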
+public class HudiFileSkippingManager
+{
+    private static final Logger log = Logger.get(HudiFileSkippingManager.class);
+
+    private final HoodieTableQueryType queryType;
+    private final Optional<String> specifiedQueryInstant;
+    private final HoodieTableMetaClient metaClient;
+    private final HoodieTableMetadata metadataTable;
+
+    private final Map<String, List<FileSlice>> allInputFileSlices;
+
+    public HudiFileSkippingManager(
+            List<String> partitions,
+            String spillableDir,
+            HoodieEngineContext engineContext,
+            HoodieTableMetaClient metaClient,
+            HoodieTableQueryType queryType,
+            Optional<String> specifiedQueryInstant)
+    {
+        requireNonNull(partitions, "partitions is null");
+        requireNonNull(spillableDir, "spillableDir is null");
+        requireNonNull(engineContext, "engineContext is null");
+        this.queryType = requireNonNull(queryType, "queryType is null");
+        this.specifiedQueryInstant = requireNonNull(specifiedQueryInstant, "specifiedQueryInstant is null");
+        this.metaClient = requireNonNull(metaClient, "metaClient is null");
+
+        HoodieMetadataConfig metadataConfig = HoodieMetadataConfig.newBuilder().enable(true).build();
+        this.metadataTable = HoodieTableMetadata.create(
+                engineContext, metaClient.getStorage(), metadataConfig, metaClient.getBasePathV2().toString(), true);
+        this.allInputFileSlices = prepareAllInputFileSlices(partitions, engineContext, metadataConfig, spillableDir);
+    }
+
+    private Map<String, List<FileSlice>> prepareAllInputFileSlices(
+            List<String> partitions,
+            HoodieEngineContext engineContext,
+            HoodieMetadataConfig metadataConfig,
+            String spillableDir)
+    {
+        long startTime = System.currentTimeMillis();
+        HoodieTimeline activeTimeline = metaClient.reloadActiveTimeline();
+        Optional<HoodieInstant> latestInstant = activeTimeline.lastInstant().toJavaOptional();
+        // Build the file system view backed by the metadata table, spilling to spillableDir if needed
+        SyncableFileSystemView fileSystemView = FileSystemViewManager
+                .createViewManager(engineContext,
+                        FileSystemViewStorageConfig.newBuilder().withBaseStoreDir(spillableDir).build(),
+                        HoodieCommonConfig.newBuilder().build(),
+                        e -> metadataTable)
+                .getFileSystemView(metaClient);
+        Optional<String> queryInstant = specifiedQueryInstant.isPresent() ?
+                specifiedQueryInstant : latestInstant.map(HoodieInstant::getTimestamp);
+
+        Map<String, List<FileSlice>> allInputFileSlices = engineContext
+                .mapToPair(
+                        partitions,
+                        partitionPath -> Pair.of(
+                                partitionPath,
+                                getLatestFileSlices(partitionPath, fileSystemView, queryInstant)),
+                        partitions.size());
+
+        long duration = System.currentTimeMillis() - startTime;
+        log.debug("Prepared file slices for table %s in %d ms", metaClient.getTableConfig().getTableName(), duration);
+        return allInputFileSlices;
+    }
+
+    private List<FileSlice> getLatestFileSlices(
+            String partitionPath,
+            SyncableFileSystemView fileSystemView,
+            Optional<String> queryInstant)
+    {
+        return queryInstant
+                .map(instant ->
+                        fileSystemView.getLatestMergedFileSlicesBeforeOrOn(partitionPath, instant))
+                .orElseGet(() -> fileSystemView.getLatestFileSlices(partitionPath))
+                .collect(Collectors.toList());
+    }
+
+    public Map<String, List<FileSlice>> listQueryFiles(TupleDomain<ColumnHandle> tupleDomain)
+    {
+        // Prune the input file slices using column stats from the metadata table
+        Map<String, List<FileSlice>> candidateFileSlices = allInputFileSlices;
+        try {
+            if (!tupleDomain.isAll()) {
+                candidateFileSlices = lookupCandidateFilesInMetadataTable(candidateFileSlices, tupleDomain);
+            }
+        }
+        catch (Exception e) {
+            // Data skipping is best-effort: log the failure and fall back to scanning all files
+            log.warn(e, "Failed to apply data skipping for table %s, falling back to a full file listing", metaClient.getBasePathV2());
+            candidateFileSlices = allInputFileSlices;
+        }
+        if (log.isDebugEnabled()) {
+            int candidateFileSize = candidateFileSlices.values().stream().mapToInt(List::size).sum();
+            int totalFiles = allInputFileSlices.values().stream().mapToInt(List::size).sum();
+            double skippingPercent = totalFiles == 0 ? 0.0d : (totalFiles - candidateFileSize) / (totalFiles + 0.0d);
+            log.debug("Total files: %s; candidate files after data skipping: %s; skipping percent: %s",
+                    totalFiles,
+                    candidateFileSize,
+                    skippingPercent);
+        }
+        return candidateFileSlices;
+    }
+
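+    /**
+     * Fetches column statistics for the filtered columns from the metadata table, keyed by base
+     * file name, and keeps only the file slices whose statistics may match the predicate.
+     */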
+    private Map<String, List<FileSlice>> lookupCandidateFilesInMetadataTable(
+            Map<String, List<FileSlice>> inputFileSlices,
+            TupleDomain<ColumnHandle> tupleDomain)
+    {
+        // Only predicates on regular (non-partition) columns can be evaluated against column stats
+        TupleDomain<HiveColumnHandle> regularTupleDomain = HudiPredicates.from(tupleDomain).getRegularColumnPredicates();
+        TupleDomain<String> regularColumnPredicates = regularTupleDomain.transformKeys(HiveColumnHandle::getName);
+        if (regularColumnPredicates.isAll() || !regularColumnPredicates.getDomains().isPresent()) {
+            return inputFileSlices;
+        }
+        List<String> regularColumns = regularColumnPredicates
+                .getDomains().get().keySet().stream().collect(Collectors.toList());
+        // Look up column stats records by the encoded column name prefixes
+        List<String> encodedTargetColumnNames = regularColumns
+                .stream()
+                .map(col -> new ColumnIndexID(col).asBase64EncodedString()).collect(Collectors.toList());
+        Map<String, List<HoodieMetadataColumnStats>> statsByFileName = metadataTable.getRecordsByKeyPrefixes(
+                        encodedTargetColumnNames,
+                        HoodieTableMetadataUtil.PARTITION_NAME_COLUMN_STATS, true)
+                .collectAsList()
+                .stream()
+                .filter(f -> f.getData().getColumnStatMetadata().isPresent())
+                .map(f -> f.getData().getColumnStatMetadata().get())
+                .collect(Collectors.groupingBy(HoodieMetadataColumnStats::getFileName));
+
+        // Keep only the file slices whose stats may match the predicate
+        return inputFileSlices
+                .entrySet()
+                .stream()
+                .collect(Collectors
+                        .toMap(Map.Entry::getKey, entry -> entry
+                                .getValue()
+                                .stream()
+                                .filter(fileSlice -> pruneFiles(fileSlice, statsByFileName, regularColumnPredicates, regularColumns))
+                                .collect(Collectors.toList())));
+    }
+
+    private boolean pruneFiles(
+            FileSlice fileSlice,
+            Map<String, List<HoodieMetadataColumnStats>> statsByFileName,
+            TupleDomain<String> regularColumnPredicates,
+            List<String> regularColumns)
+    {
+        String fileSliceName = fileSlice.getBaseFile().map(BaseFile::getFileName).orElse("");
+        // No stats found for this file: keep it
+        if (!statsByFileName.containsKey(fileSliceName)) {
+            return true;
+        }
+        List<HoodieMetadataColumnStats> stats = statsByFileName.get(fileSliceName);
+        return evaluateStatisticPredicate(regularColumnPredicates, stats, regularColumns);
+    }
+
+    private boolean evaluateStatisticPredicate(
+            TupleDomain<String> regularColumnPredicates,
+            List<HoodieMetadataColumnStats> stats,
+            List<String> regularColumns)
+    {
+        if (regularColumnPredicates.isNone() || !regularColumnPredicates.getDomains().isPresent()) {
+            return true;
+        }
+        for (String regularColumn : regularColumns) {
+            Domain columnPredicate = regularColumnPredicates.getDomains().get().get(regularColumn);
+            Optional<HoodieMetadataColumnStats> currentColumnStats = stats
+                    .stream().filter(s -> s.getColumnName().equals(regularColumn)).findFirst();
+            // Without stats for a column, the file cannot be pruned on that column
+            if (currentColumnStats.isPresent()) {
+                Domain domain = getDomain(regularColumn, columnPredicate.getType(), currentColumnStats.get());
+                if (columnPredicate.intersect(domain).isNone()) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    private static Domain getDomain(String colName, Type type, HoodieMetadataColumnStats statistics)
+    {
+        if (statistics == null) {
+            return Domain.all(type);
+        }
+        boolean hasNullValue = statistics.getNullCount() != 0L;
+        boolean hasNonNullValue = statistics.getValueCount() - statistics.getNullCount() > 0;
+        if (!hasNonNullValue || statistics.getMaxValue() == null || statistics.getMinValue() == null) {
+            return Domain.create(ValueSet.all(type), hasNullValue);
+        }
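+        // Column stats store min/max as Avro wrapper records with a single value field,
+        // so unwrap field 0 before converting to a Trino domain.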
+        if (!(statistics.getMinValue() instanceof GenericRecord) ||
+                !(statistics.getMaxValue() instanceof GenericRecord)) {
+            return Domain.all(type);
+        }
+        return getDomain(colName, type, ((GenericRecord) statistics.getMinValue()).get(0),
+                ((GenericRecord) statistics.getMaxValue()).get(0), hasNullValue);
+    }
+
+    /**
+     * Get a domain for the range defined by {@code minimum} and {@code maximum}, both of the
+     * given {@code type}, including null if {@code hasNullValue} is set.
+     */
+    private static Domain getDomain(String colName, Type type, Object minimum, Object maximum, boolean hasNullValue)
+    {
+        try {
+            if (type.equals(BOOLEAN)) {
+                boolean hasTrueValue = (boolean) minimum || (boolean) maximum;
+                boolean hasFalseValue = !(boolean) minimum || !(boolean) maximum;
+                if (hasTrueValue && hasFalseValue) {
+                    return Domain.all(type);
+                }
+                if (hasTrueValue) {
+                    return Domain.create(ValueSet.of(type, true), hasNullValue);
+                }
+                if (hasFalseValue) {
+                    return Domain.create(ValueSet.of(type, false), hasNullValue);
+                }
+                // No other case, since the all-null case is handled earlier
+            }
+
+            if (type.equals(BIGINT) || type.equals(TINYINT) || type.equals(SMALLINT)
+                    || type.equals(INTEGER) || type.equals(DATE)) {
+                long minValue = TupleDomainParquetPredicate.asLong(minimum);
+                long maxValue = TupleDomainParquetPredicate.asLong(maximum);
+                if (isStatisticsOverflow(type, minValue, maxValue)) {
+                    return Domain.create(ValueSet.all(type), hasNullValue);
+                }
+                return ofMinMax(type, minValue, maxValue, hasNullValue);
+            }
+
+            if (type.equals(REAL)) {
+                Float minValue = (Float) minimum;
+                Float maxValue = (Float) maximum;
+                if (minValue.isNaN() || maxValue.isNaN()) {
+                    return Domain.create(ValueSet.all(type), hasNullValue);
+                }
+                return ofMinMax(type, (long) floatToRawIntBits(minValue), (long) floatToRawIntBits(maxValue), hasNullValue);
+            }
+
+            if (type.equals(DOUBLE)) {
+                Double minValue = (Double) minimum;
+                Double maxValue = (Double) maximum;
+                if (minValue.isNaN() || maxValue.isNaN()) {
+                    return Domain.create(ValueSet.all(type), hasNullValue);
+                }
+                return ofMinMax(type, minValue, maxValue, hasNullValue);
+            }
+
+            if (type.equals(VarcharType.VARCHAR)) {
+                Slice min = Slices.utf8Slice((String) minimum);
+                Slice max = Slices.utf8Slice((String) maximum);
+                return ofMinMax(type, min, max, hasNullValue);
+            }
+            return Domain.create(ValueSet.all(type), hasNullValue);
+        }
+        catch (Exception e) {
+            log.warn(e, "Failed to create domain for column %s of type %s", colName, type);
+            return Domain.create(ValueSet.all(type), hasNullValue);
+        }
+    }
+
+    private static Domain ofMinMax(Type type, Object min, Object max, boolean hasNullValue)
+    {
+        Range range = Range.range(type, min, true, max, true);
+        ValueSet vs = ValueSet.ofRanges(ImmutableList.of(range));
+        return Domain.create(vs, hasNullValue);
+    }
+}