Skip to content

Commit

Permalink
fixup! Add support for WITH SESSION clause
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Sep 18, 2024
1 parent 1fa8c2f commit 644c66b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ public SessionSpecificationsApplier getSessionSpecificationApplier(PreparedQuery
return session -> prepareSession(session, queryStatement.getSessionProperties(), bindParameters(preparedQuery.getStatement(), preparedQuery.getParameters()));
}

public Session prepareSession(Session session, List<SessionSpecification> specifications, Map<NodeRef<Parameter>, Expression> parameters)
private Session prepareSession(Session session, List<SessionSpecification> specifications, Map<NodeRef<Parameter>, Expression> parameters)
{
ResolvedSessionSpecifications resolvedSessionSpecifications = resolve(session, parameters, specifications);
return overrideProperties(session, resolvedSessionSpecifications);
}

public ResolvedSessionSpecifications resolve(Session session, Map<NodeRef<Parameter>, Expression> parameters, List<SessionSpecification> specifications)
private ResolvedSessionSpecifications resolve(Session session, Map<NodeRef<Parameter>, Expression> parameters, List<SessionSpecification> specifications)
{
ImmutableMap.Builder<String, String> sessionProperties = ImmutableMap.builder();
Table<String, String, String> catalogProperties = HashBasedTable.create();
Expand All @@ -100,21 +100,26 @@ public ResolvedSessionSpecifications resolve(Session session, Map<NodeRef<Parame
}
sessionProperties.put(nameParts.getFirst(), toSessionValue(session, parameters, specification, systemSessionPropertyMetadata.get()));
}
else if (nameParts.size() == 2) {
String catalogName = nameParts.getFirst();
String propertyName = nameParts.getLast();

if (nameParts.size() == 2) {
CatalogHandle catalogHandle = getRequiredCatalogHandle(plannerContext.getMetadata(), session, specification, nameParts.getFirst());
Optional<PropertyMetadata<?>> connectorSessionPropertyMetadata = sessionPropertyManager.getConnectorSessionPropertyMetadata(catalogHandle, nameParts.getLast());
CatalogHandle catalogHandle = getRequiredCatalogHandle(plannerContext.getMetadata(), session, specification, catalogName);
Optional<PropertyMetadata<?>> connectorSessionPropertyMetadata = sessionPropertyManager.getConnectorSessionPropertyMetadata(catalogHandle, propertyName);
if (connectorSessionPropertyMetadata.isEmpty()) {
throw semanticException(INVALID_SESSION_PROPERTY, specification, "Session property %s does not exist", specification.getName());
}
catalogProperties.put(nameParts.get(0), nameParts.get(1), toSessionValue(session, parameters, specification, connectorSessionPropertyMetadata.get()));
catalogProperties.put(catalogName, propertyName, toSessionValue(session, parameters, specification, connectorSessionPropertyMetadata.get()));
}
else {
throw semanticException(INVALID_SESSION_PROPERTY, specification, "Invalid session property '%s'", specification.getName());
}
}

return new ResolvedSessionSpecifications(sessionProperties.buildOrThrow(), catalogProperties.rowMap());
}

public Session overrideProperties(Session session, ResolvedSessionSpecifications resolvedSessionSpecifications)
private Session overrideProperties(Session session, ResolvedSessionSpecifications resolvedSessionSpecifications)
{
requireNonNull(resolvedSessionSpecifications, "resolvedSessionSpecifications is null");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.trino.FeaturesConfig;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.execution.QueryPreparer;
import io.trino.metadata.AbstractMockMetadata;
import io.trino.metadata.MetadataManager;
import io.trino.metadata.ResolvedFunction;
Expand All @@ -33,7 +34,6 @@
import io.trino.sql.PlannerContext;
import io.trino.sql.SessionSpecificationEvaluator;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.tree.SessionSpecification;
import io.trino.transaction.TestingTransactionManager;
import io.trino.transaction.TransactionManager;
import io.trino.type.InternalTypeManager;
Expand All @@ -58,54 +58,57 @@
final class TestSessionSpecifications
{
private static final SqlParser SQL_PARSER = new SqlParser();
private static final SessionPropertyManager SESSION_PROPERTY_MANAGER = new SessionPropertyManager(ImmutableSet.of(new SystemSessionProperties()), catalogHandle -> Map.of(
"catalog_property", PropertyMetadata.stringProperty("catalog_property", "Test catalog property", "", false)));
private static final SessionPropertyManager SESSION_PROPERTY_MANAGER = new SessionPropertyManager(
ImmutableSet.of(new SystemSessionProperties()),
_ -> Map.of("catalog_property", PropertyMetadata.stringProperty("catalog_property", "Test catalog property", "", false)));

@Test
void testParseSystemSessionProperty()
{
assertThatThrownBy(() -> analyze("invalid_key = 'invalid_value'"))
.hasMessageContaining("line 1:1: Session property invalid_key does not exist");
assertThatThrownBy(() -> analyze("WITH SESSION invalid_key = 'invalid_value' SELECT 1"))
.hasMessageContaining("line 1:14: Session property invalid_key does not exist");

assertThatThrownBy(() -> analyze("optimize_hash_generation = 'invalid_value'"))
assertThatThrownBy(() -> analyze("WITH SESSION optimize_hash_generation = 'invalid_value' SELECT 1"))
.hasMessageContaining("Unable to set session property 'optimize_hash_generation' to ''invalid_value'': Cannot cast type varchar(13) to boolean");

assertThat(analyze("optimize_hash_generation = true").getSystemProperties())
assertThat(analyze("WITH SESSION optimize_hash_generation = true SELECT 1").getSystemProperties())
.isEqualTo(Map.of("optimize_hash_generation", "true"));

assertThat(analyze("optimize_hash_generation = CAST('true' AS boolean)").getSystemProperties())
assertThat(analyze("WITH SESSION optimize_hash_generation = CAST('true' AS boolean) SELECT 1").getSystemProperties())
.isEqualTo(Map.of("optimize_hash_generation", "true"));

assertThatThrownBy(() -> analyze("optimize_hash_generation = true", "optimize_hash_generation = false"))
.hasMessageContaining("line 1:1: Session property optimize_hash_generation already set");
assertThatThrownBy(() -> analyze("WITH SESSION optimize_hash_generation = true, optimize_hash_generation = false SELECT 1"))
.hasMessageContaining("line 1:47: Session property optimize_hash_generation already set");
}

@Test
void testCatalogSessionProperty()
{
assertThatThrownBy(() -> analyze("test.invalid_key = 'invalid_value'"))
.hasMessageContaining("line 1:1: Session property test.invalid_key does not exist");
assertThatThrownBy(() -> analyze("WITH SESSION test.invalid_key = 'invalid_value' SELECT 1"))
.hasMessageContaining("line 1:14: Session property test.invalid_key does not exist");

assertThatThrownBy(() -> analyze("test.catalog_property = true"))
assertThatThrownBy(() -> analyze("WITH SESSION test.catalog_property = true SELECT 1"))
.hasMessageContaining("Unable to set session property 'test.catalog_property' to 'true': Cannot cast type boolean to varchar");

assertThat(analyze("test.catalog_property = 'true'").getCatalogProperties("test"))
assertThat(analyze("WITH SESSION test.catalog_property = 'true' SELECT 1").getCatalogProperties("test"))
.isEqualTo(Map.of("catalog_property", "true"));

assertThat(analyze("test.catalog_property = CAST(true AS varchar)").getCatalogProperties("test"))
assertThat(analyze("WITH SESSION test.catalog_property = CAST(true AS varchar) SELECT 1").getCatalogProperties("test"))
.isEqualTo(Map.of("catalog_property", "true"));

assertThatThrownBy(() -> analyze("test.catalog_property = 'true'", "test.catalog_property = 'false'").getCatalogProperties("test"))
.hasMessageContaining("line 1:1: Session property test.catalog_property already set");
assertThatThrownBy(() -> analyze("WITH SESSION test.catalog_property = 'true', test.catalog_property = 'false' SELECT 1").getCatalogProperties("test"))
.hasMessageContaining("line 1:46: Session property test.catalog_property already set");
}

private static Session analyze(@Language("SQL") String... statements)
@Test
void testInvalidSessionProperty()
{
ImmutableList.Builder<SessionSpecification> sessionSpecifications = ImmutableList.builder();
for (String statement : statements) {
sessionSpecifications.add(SQL_PARSER.createSessionSpecification(statement));
}
assertThatThrownBy(() -> analyze("WITH SESSION test.schema.invalid_key = 'invalid_value' SELECT 1"))
.hasMessageContaining("line 1:14: Invalid session property 'test.schema.invalid_key'");
}

private static Session analyze(@Language("SQL") String statement)
{
TransactionManager transactionManager = new TestingTransactionManager();
PlannerContext plannerContext = plannerContextBuilder()
.withMetadata(new MockMetadata())
Expand All @@ -115,7 +118,8 @@ private static Session analyze(@Language("SQL") String... statements)
return transaction(transactionManager, plannerContext.getMetadata(), new AllowAllAccessControl())
.execute(testSession(), transactionSession -> {
SessionSpecificationEvaluator evaluator = new SessionSpecificationEvaluator(plannerContext, new AllowAllAccessControl(), SESSION_PROPERTY_MANAGER);
return evaluator.prepareSession(transactionSession, sessionSpecifications.build(), Map.of());
QueryPreparer.PreparedQuery preparedQuery = new QueryPreparer.PreparedQuery(SQL_PARSER.createStatement(statement), ImmutableList.of(), Optional.empty());
return evaluator.getSessionSpecificationApplier(preparedQuery).apply(transactionSession);
});
}

Expand Down

0 comments on commit 644c66b

Please sign in to comment.