Skip to content

Commit

Permalink
Add support for WITH SESSION clause
Browse files Browse the repository at this point in the history
Co-Authored-By: Mateusz "Serafin" Gajewski <[email protected]>
  • Loading branch information
ebyhr and wendigo committed Sep 18, 2024
1 parent a429c24 commit 308303f
Show file tree
Hide file tree
Showing 25 changed files with 408 additions and 8 deletions.
11 changes: 10 additions & 1 deletion core/trino-main/src/main/java/io/trino/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static io.trino.SystemSessionProperties.TIME_ZONE_ID;
import static io.trino.client.ProtocolHeaders.TRINO_HEADERS;
import static io.trino.spi.StandardErrorCode.CATALOG_NOT_FOUND;
import static io.trino.spi.StandardErrorCode.NOT_FOUND;
import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey;
import static io.trino.sql.SqlPath.EMPTY_PATH;
import static io.trino.util.Failures.checkCondition;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -416,6 +418,11 @@ public Session withDefaultProperties(Map<String, String> systemPropertyDefaults,
.putAll(catalogEntry.getValue());
}

return withProperties(systemProperties, catalogProperties);
}

public Session withProperties(Map<String, String> systemProperties, Map<String, Map<String, String>> catalogProperties)
{
return new Session(
queryId,
querySpan,
Expand All @@ -428,7 +435,9 @@ public Session withDefaultProperties(Map<String, String> systemPropertyDefaults,
schema,
path,
traceToken,
timeZoneKey,
Optional.ofNullable(systemProperties.get(TIME_ZONE_ID))
.map(TimeZoneKey::getTimeZoneKey)
.orElse(timeZoneKey),
locale,
remoteUserAddress,
userAgent,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import io.trino.server.protocol.Slug;
import io.trino.spi.TrinoException;
import io.trino.spi.resourcegroups.ResourceGroupId;
import io.trino.sql.SessionSpecificationEvaluator;
import io.trino.sql.tree.Statement;
import io.trino.transaction.TransactionId;
import io.trino.transaction.TransactionManager;
Expand All @@ -59,6 +60,7 @@ public class LocalDispatchQueryFactory
private final TransactionManager transactionManager;
private final AccessControl accessControl;
private final Metadata metadata;
private final SessionSpecificationEvaluator sessionSpecificationEvaluator;
private final QueryMonitor queryMonitor;
private final LocationFactory locationFactory;

Expand All @@ -77,6 +79,7 @@ public LocalDispatchQueryFactory(
QueryManager queryManager,
QueryManagerConfig queryManagerConfig,
TransactionManager transactionManager,
SessionSpecificationEvaluator sessionSpecificationEvaluator,
AccessControl accessControl,
Metadata metadata,
QueryMonitor queryMonitor,
Expand All @@ -92,6 +95,7 @@ public LocalDispatchQueryFactory(
this.transactionManager = requireNonNull(transactionManager, "transactionManager is null");
this.accessControl = requireNonNull(accessControl, "accessControl is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.sessionSpecificationEvaluator = requireNonNull(sessionSpecificationEvaluator, "sessionSpecificationEvaluator is null");
this.queryMonitor = requireNonNull(queryMonitor, "queryMonitor is null");
this.locationFactory = requireNonNull(locationFactory, "locationFactory is null");
this.executionFactories = requireNonNull(executionFactories, "executionFactories is null");
Expand Down Expand Up @@ -132,6 +136,7 @@ public DispatchQuery createDispatchQuery(
planOptimizersStatsCollector,
getQueryType(preparedQuery.getStatement()),
faultTolerantExecutionExchangeEncryptionEnabled,
Optional.of(sessionSpecificationEvaluator.getSessionSpecificationApplier(preparedQuery)),
version);

// It is important that `queryCreatedEvent` is called here. Moving it past the `executor.submit` below
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import io.trino.spi.resourcegroups.ResourceGroupId;
import io.trino.spi.security.SelectedRole;
import io.trino.spi.type.Type;
import io.trino.sql.SessionSpecificationEvaluator.SessionSpecificationsApplier;
import io.trino.sql.analyzer.Output;
import io.trino.sql.planner.PlanFragment;
import io.trino.tracing.TrinoAttributes;
Expand Down Expand Up @@ -243,6 +244,7 @@ public static QueryStateMachine begin(
PlanOptimizersStatsCollector queryStatsCollector,
Optional<QueryType> queryType,
boolean faultTolerantExecutionExchangeEncryptionEnabled,
Optional<SessionSpecificationsApplier> sessionSpecificationsApplier,
NodeVersion version)
{
return beginWithTicker(
Expand All @@ -262,6 +264,7 @@ public static QueryStateMachine begin(
queryStatsCollector,
queryType,
faultTolerantExecutionExchangeEncryptionEnabled,
sessionSpecificationsApplier,
version);
}

Expand All @@ -282,6 +285,7 @@ static QueryStateMachine beginWithTicker(
PlanOptimizersStatsCollector queryStatsCollector,
Optional<QueryType> queryType,
boolean faultTolerantExecutionExchangeEncryptionEnabled,
Optional<SessionSpecificationsApplier> sessionSpecificationsApplier,
NodeVersion version)
{
// if there is an existing transaction, activate it
Expand All @@ -308,6 +312,11 @@ static QueryStateMachine beginWithTicker(
session = session.withExchangeEncryption(serializeAesEncryptionKey(createRandomAesEncryptionKey()));
}

// Apply WITH SESSION specifications which require transaction to be started to resolve catalog handles
if (sessionSpecificationsApplier.isPresent()) {
session = sessionSpecificationsApplier.orElseThrow().apply(session);
}

Span querySpan = session.getQuerySpan();

querySpan.setAttribute(TrinoAttributes.QUERY_TYPE, queryType.map(Enum::name).orElse("UNKNOWN"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
import io.trino.server.ui.WorkerResource;
import io.trino.spi.VersionEmbedder;
import io.trino.sql.PlannerContext;
import io.trino.sql.SessionSpecificationEvaluator;
import io.trino.sql.analyzer.AnalyzerFactory;
import io.trino.sql.analyzer.QueryExplainerFactory;
import io.trino.sql.planner.OptimizerStatsMBeanExporter;
Expand Down Expand Up @@ -210,6 +211,8 @@ protected void setup(Binder binder)

// dispatcher
binder.bind(DispatchManager.class).in(Scopes.SINGLETON);
// WITH SESSION interpreter
binder.bind(SessionSpecificationEvaluator.class).in(Scopes.SINGLETON);
// export under the old name, for backwards compatibility
newExporter(binder).export(DispatchManager.class).as(generator -> generator.generatedNameOf(QueryManager.class));
binder.bind(FailedDispatchQueryFactory.class).in(Scopes.SINGLETON);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Table;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.execution.QueryPreparer.PreparedQuery;
import io.trino.metadata.SessionPropertyManager;
import io.trino.security.AccessControl;
import io.trino.security.SecurityContext;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.session.PropertyMetadata;
import io.trino.spi.type.Type;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.Parameter;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.Query;
import io.trino.sql.tree.SessionSpecification;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

import static com.google.common.base.Preconditions.checkState;
import static io.trino.execution.ParameterExtractor.bindParameters;
import static io.trino.metadata.MetadataUtil.getRequiredCatalogHandle;
import static io.trino.metadata.SessionPropertyManager.evaluatePropertyValue;
import static io.trino.metadata.SessionPropertyManager.serializeSessionProperty;
import static io.trino.spi.StandardErrorCode.CATALOG_NOT_FOUND;
import static io.trino.spi.StandardErrorCode.INVALID_SESSION_PROPERTY;
import static io.trino.sql.analyzer.SemanticExceptions.semanticException;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class SessionSpecificationEvaluator
{
private final PlannerContext plannerContext;
private final AccessControl accessControl;
private final SessionPropertyManager sessionPropertyManager;

@Inject
public SessionSpecificationEvaluator(PlannerContext plannerContext, AccessControl accessControl, SessionPropertyManager sessionPropertyManager)
{
this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
this.accessControl = requireNonNull(accessControl, "accessControl is null");
this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null");
}

public SessionSpecificationsApplier getSessionSpecificationApplier(PreparedQuery preparedQuery)
{
if (!(preparedQuery.getStatement() instanceof Query queryStatement)) {
return session -> session;
}
return session -> prepareSession(session, queryStatement.getSessionProperties(), bindParameters(preparedQuery.getStatement(), preparedQuery.getParameters()));
}

public Session prepareSession(Session session, List<SessionSpecification> specifications, Map<NodeRef<Parameter>, Expression> parameters)
{
ResolvedSessionSpecifications resolvedSessionSpecifications = resolve(session, parameters, specifications);
return overrideProperties(session, resolvedSessionSpecifications);
}

public ResolvedSessionSpecifications resolve(Session session, Map<NodeRef<Parameter>, Expression> parameters, List<SessionSpecification> specifications)
{
ImmutableMap.Builder<String, String> sessionProperties = ImmutableMap.builder();
Table<String, String, String> catalogProperties = HashBasedTable.create();
Set<QualifiedName> seenPropertyNames = new HashSet<>();

for (SessionSpecification specification : specifications) {
List<String> nameParts = specification.getName().getParts();

if (!seenPropertyNames.add(specification.getName())) {
throw semanticException(INVALID_SESSION_PROPERTY, specification, "Session property %s already set", specification.getName());
}

if (nameParts.size() == 1) {
Optional<PropertyMetadata<?>> systemSessionPropertyMetadata = sessionPropertyManager.getSystemSessionPropertyMetadata(nameParts.getFirst());
if (systemSessionPropertyMetadata.isEmpty()) {
throw semanticException(INVALID_SESSION_PROPERTY, specification, "Session property %s does not exist", specification.getName());
}
sessionProperties.put(nameParts.getFirst(), toSessionValue(session, parameters, specification, systemSessionPropertyMetadata.get()));
}

if (nameParts.size() == 2) {
CatalogHandle catalogHandle = getRequiredCatalogHandle(plannerContext.getMetadata(), session, specification, nameParts.getFirst());
Optional<PropertyMetadata<?>> connectorSessionPropertyMetadata = sessionPropertyManager.getConnectorSessionPropertyMetadata(catalogHandle, nameParts.getLast());
if (connectorSessionPropertyMetadata.isEmpty()) {
throw semanticException(INVALID_SESSION_PROPERTY, specification, "Session property %s does not exist", specification.getName());
}
catalogProperties.put(nameParts.get(0), nameParts.get(1), toSessionValue(session, parameters, specification, connectorSessionPropertyMetadata.get()));
}
}

return new ResolvedSessionSpecifications(sessionProperties.buildOrThrow(), catalogProperties.rowMap());
}

public Session overrideProperties(Session session, ResolvedSessionSpecifications resolvedSessionSpecifications)
{
requireNonNull(resolvedSessionSpecifications, "resolvedSessionSpecifications is null");

validateSystemProperties(session, resolvedSessionSpecifications.systemProperties());

// Catalog session properties were already evaluated so we need to evaluate overrides
if (session.getTransactionId().isPresent()) {
validateCatalogProperties(session, resolvedSessionSpecifications.catalogProperties());
}

// NOTE: properties are validated before calling overrideProperties
Map<String, String> systemProperties = new HashMap<>();
systemProperties.putAll(session.getSystemProperties());
systemProperties.putAll(resolvedSessionSpecifications.systemProperties());

Map<String, Map<String, String>> catalogProperties = new HashMap<>(session.getCatalogProperties());
for (Map.Entry<String, Map<String, String>> catalogEntry : resolvedSessionSpecifications.catalogProperties().entrySet()) {
catalogProperties.computeIfAbsent(catalogEntry.getKey(), id -> new HashMap<>())
.putAll(catalogEntry.getValue());
}

return session.withProperties(systemProperties, catalogProperties);
}

private String toSessionValue(Session session, Map<NodeRef<Parameter>, Expression> parameters, SessionSpecification specification, PropertyMetadata<?> propertyMetadata)
{
Type type = propertyMetadata.getSqlType();
Object objectValue;

try {
objectValue = evaluatePropertyValue(specification.getValue(), type, session, plannerContext, accessControl, parameters);
}
catch (TrinoException e) {
throw new TrinoException(
INVALID_SESSION_PROPERTY,
format("Unable to set session property '%s' to '%s': %s", specification.getName(), specification.getValue(), e.getRawMessage()));
}

String value = serializeSessionProperty(type, objectValue);
// verify the SQL value can be decoded by the property
try {
propertyMetadata.decode(objectValue);
}
catch (RuntimeException e) {
throw semanticException(INVALID_SESSION_PROPERTY, specification, "%s", e.getMessage());
}

return value;
}

private void validateSystemProperties(Session session, Map<String, String> systemProperties)
{
for (Map.Entry<String, String> property : systemProperties.entrySet()) {
// verify permissions
accessControl.checkCanSetSystemSessionProperty(session.getIdentity(), session.getQueryId(), property.getKey());
// validate session property value
sessionPropertyManager.validateSystemSessionProperty(property.getKey(), property.getValue());
}
}

private void validateCatalogProperties(Session session, Map<String, Map<String, String>> catalogsProperties)
{
checkState(session.getTransactionId().isPresent(), "Not in transaction");
for (Map.Entry<String, Map<String, String>> catalogProperties : catalogsProperties.entrySet()) {
CatalogHandle catalogHandle = plannerContext.getMetadata().getCatalogHandle(session, catalogProperties.getKey())
.orElseThrow(() -> new TrinoException(CATALOG_NOT_FOUND, "Catalog '%s' not found".formatted(catalogProperties.getKey())));

for (Map.Entry<String, String> catalogProperty : catalogProperties.getValue().entrySet()) {
// verify permissions
accessControl.checkCanSetCatalogSessionProperty(new SecurityContext(session.getRequiredTransactionId(), session.getIdentity(), session.getQueryId(), session.getStart()), catalogProperties.getKey(), catalogProperty.getKey());
// validate catalog session property value
sessionPropertyManager.validateCatalogSessionProperty(catalogProperties.getKey(), catalogHandle, catalogProperty.getKey(), catalogProperty.getValue());
}
}
}

public record ResolvedSessionSpecifications(Map<String, String> systemProperties, Map<String, Map<String, String>> catalogProperties)
{
public ResolvedSessionSpecifications
{
systemProperties = ImmutableMap.copyOf(requireNonNull(systemProperties, "systemProperties is null"));
catalogProperties = ImmutableMap.copyOf(requireNonNull(catalogProperties, "catalogProperties is null"));
}
}

@FunctionalInterface
public interface SessionSpecificationsApplier
extends Function<Session, Session>
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ public void testSubmittedForDispatchedQuery()
createPlanOptimizersStatsCollector(),
Optional.of(QueryType.DATA_DEFINITION),
true,
Optional.empty(),
new NodeVersion("test"));
QueryMonitor queryMonitor = new QueryMonitor(
JsonCodec.jsonCodec(StageInfo.class),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ private static QueryStateMachine stateMachine(TransactionManager transactionMana
createPlanOptimizersStatsCollector(),
Optional.empty(),
true,
Optional.empty(),
new NodeVersion("test"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ private QueryStateMachine stateMachine(TransactionManager transactionManager, Me
createPlanOptimizersStatsCollector(),
Optional.empty(),
true,
Optional.empty(),
new NodeVersion("test"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ private QueryStateMachine createQueryStateMachine(String query, Session session,
createPlanOptimizersStatsCollector(),
Optional.empty(),
true,
Optional.empty(),
new NodeVersion("test"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public void setUp()
createPlanOptimizersStatsCollector(),
Optional.empty(),
true,
Optional.empty(),
new NodeVersion("test"));

this.queryRunner = queryRunner;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ private Set<String> executeDeallocate(String statementName, String sqlString, Se
createPlanOptimizersStatsCollector(),
Optional.empty(),
true,
Optional.empty(),
new NodeVersion("test"));
Deallocate deallocate = new Deallocate(new NodeLocation(1, 1), new Identifier(statementName));
new DeallocateTask().execute(deallocate, stateMachine, emptyList(), WarningCollector.NOOP);
Expand Down
Loading

0 comments on commit 308303f

Please sign in to comment.