Skip to content

Commit

Permalink
Add WITH SESSION clause grammar
Browse files Browse the repository at this point in the history
Co-Authored-By: Mateusz "Serafin" Gajewski <[email protected]>
  • Loading branch information
ebyhr and wendigo committed Sep 18, 2024
1 parent ae167ba commit a429c24
Show file tree
Hide file tree
Showing 14 changed files with 287 additions and 11 deletions.
20 changes: 18 additions & 2 deletions core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ standaloneFunctionSpecification
: functionSpecification EOF
;

standaloneSessionSpecification
: sessionSpecification EOF
;

statement
: rootQuery #statementDefault
| USE schema=identifier #use
Expand Down Expand Up @@ -199,11 +203,23 @@ statement
;

rootQuery
: withFunction? query
: queryScoped? query
;

queryScoped
: WITH withFunction? (SESSION withSession)?
;

withFunction
: WITH functionSpecification (',' functionSpecification)*
: functionSpecification (',' functionSpecification)*
;

withSession
: sessionSpecification (',' sessionSpecification)*
;

sessionSpecification
: qualifiedName EQ expression
;

query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ public static Query singleValueQuery(String columnName, boolean value)
public static Query query(QueryBody body)
{
return new Query(
ImmutableList.of(),
ImmutableList.of(),
Optional.empty(),
body,
Expand Down
29 changes: 27 additions & 2 deletions core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
import io.trino.sql.tree.SecurityCharacteristic;
import io.trino.sql.tree.Select;
import io.trino.sql.tree.SelectItem;
import io.trino.sql.tree.SessionSpecification;
import io.trino.sql.tree.SetColumnType;
import io.trino.sql.tree.SetPath;
import io.trino.sql.tree.SetProperties;
Expand Down Expand Up @@ -640,8 +641,13 @@ protected Void visitDescribeInput(DescribeInput node, Integer indent)
@Override
protected Void visitQuery(Query node, Integer indent)
{
if (!node.getFunctions().isEmpty()) {
builder.append("WITH\n");
if (!node.getFunctions().isEmpty() || !node.getSessionProperties().isEmpty()) {
builder.append("WITH");
if (!node.getSessionProperties().isEmpty()) {
builder.append(" SESSION");
}
builder.append("\n");

Iterator<FunctionSpecification> functions = node.getFunctions().iterator();
while (functions.hasNext()) {
process(functions.next(), indent + 1);
Expand All @@ -650,6 +656,15 @@ protected Void visitQuery(Query node, Integer indent)
}
builder.append('\n');
}

Iterator<SessionSpecification> sessionProperties = node.getSessionProperties().iterator();
while (sessionProperties.hasNext()) {
process(sessionProperties.next(), indent + 1);
if (sessionProperties.hasNext()) {
builder.append(',');
}
builder.append('\n');
}
}

node.getWith().ifPresent(with -> {
Expand Down Expand Up @@ -2300,6 +2315,16 @@ protected Void visitFunctionSpecification(FunctionSpecification node, Integer in
return null;
}

@Override
protected Void visitSessionSpecification(SessionSpecification node, Integer indent)
{
append(indent, "")
.append(formatName(node.getName()))
.append(" = ")
.append(formatExpression(node.getValue()));
return null;
}

@Override
protected Void visitParameterDeclaration(ParameterDeclaration node, Integer indent)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@
import io.trino.sql.tree.SecurityCharacteristic;
import io.trino.sql.tree.Select;
import io.trino.sql.tree.SelectItem;
import io.trino.sql.tree.SessionSpecification;
import io.trino.sql.tree.SetColumnType;
import io.trino.sql.tree.SetPath;
import io.trino.sql.tree.SetProperties;
Expand Down Expand Up @@ -413,6 +414,12 @@ public Node visitStandaloneFunctionSpecification(SqlBaseParser.StandaloneFunctio
return visit(context.functionSpecification());
}

@Override
public Node visitStandaloneSessionSpecification(SqlBaseParser.StandaloneSessionSpecificationContext context)
{
return visit(context.sessionSpecification());
}

// ******************* statements **********************

@Override
Expand Down Expand Up @@ -1097,10 +1104,16 @@ public Node visitRootQuery(SqlBaseParser.RootQueryContext context)

return new Query(
getLocation(context),
Optional.ofNullable(context.withFunction())
Optional.ofNullable(context.queryScoped())
.map(SqlBaseParser.QueryScopedContext::withFunction)
.map(SqlBaseParser.WithFunctionContext::functionSpecification)
.map(contexts -> visit(contexts, FunctionSpecification.class))
.orElseGet(ImmutableList::of),
Optional.ofNullable(context.queryScoped())
.map(SqlBaseParser.QueryScopedContext::withSession)
.map(SqlBaseParser.WithSessionContext::sessionSpecification)
.map(contexts -> visit(contexts, SessionSpecification.class))
.orElseGet(ImmutableList::of),
query.getWith(),
query.getQueryBody(),
query.getOrderBy(),
Expand All @@ -1116,6 +1129,7 @@ public Node visitQuery(SqlBaseParser.QueryContext context)
return new Query(
getLocation(context),
ImmutableList.of(),
ImmutableList.of(),
visitIfPresent(context.with(), With.class),
body.getQueryBody(),
body.getOrderBy(),
Expand Down Expand Up @@ -1211,6 +1225,7 @@ else if (context.limit.rowCount().INTEGER_VALUE() != null) {
return new Query(
getLocation(context),
ImmutableList.of(),
ImmutableList.of(),
Optional.empty(),
new QuerySpecification(
getLocation(context),
Expand All @@ -1231,6 +1246,7 @@ else if (context.limit.rowCount().INTEGER_VALUE() != null) {
return new Query(
getLocation(context),
ImmutableList.of(),
ImmutableList.of(),
Optional.empty(),
term,
orderBy,
Expand Down Expand Up @@ -3722,6 +3738,15 @@ public Node visitFunctionSpecification(SqlBaseParser.FunctionSpecificationContex
statement);
}

@Override
public Node visitSessionSpecification(SqlBaseParser.SessionSpecificationContext context)
{
return new SessionSpecification(
getLocation(context),
getQualifiedName(context.qualifiedName()),
(Expression) visit(context.expression()));
}

@Override
public Node visitParameterDeclaration(SqlBaseParser.ParameterDeclarationContext context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.trino.sql.tree.NodeLocation;
import io.trino.sql.tree.PathSpecification;
import io.trino.sql.tree.RowPattern;
import io.trino.sql.tree.SessionSpecification;
import io.trino.sql.tree.Statement;
import org.antlr.v4.runtime.ANTLRErrorListener;
import org.antlr.v4.runtime.BaseErrorListener;
Expand Down Expand Up @@ -120,6 +121,11 @@ public FunctionSpecification createFunctionSpecification(String sql)
return (FunctionSpecification) invokeParser("function specification", sql, SqlBaseParser::standaloneFunctionSpecification);
}

public SessionSpecification createSessionSpecification(String sql)
{
return (SessionSpecification) invokeParser("session specification", sql, SqlBaseParser::standaloneSessionSpecification);
}

private Node invokeParser(String name, String sql, Function<SqlBaseParser, ParserRuleContext> parseFunction)
{
return invokeParser(name, sql, Optional.empty(), parseFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,11 @@ protected R visitFunctionSpecification(FunctionSpecification node, C context)
return visitNode(node, context);
}

protected R visitSessionSpecification(SessionSpecification node, C context)
{
return visitNode(node, context);
}

protected R visitParameterDeclaration(ParameterDeclaration node, C context)
{
return visitNode(node, context);
Expand Down
22 changes: 18 additions & 4 deletions core/trino-parser/src/main/java/io/trino/sql/tree/Query.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class Query
extends Statement
{
private final List<FunctionSpecification> functions;
private final List<SessionSpecification> sessionProperties;
private final Optional<With> with;
private final QueryBody queryBody;
private final Optional<OrderBy> orderBy;
Expand All @@ -35,38 +36,42 @@ public class Query

public Query(
List<FunctionSpecification> functions,
List<SessionSpecification> sessionProperties,
Optional<With> with,
QueryBody queryBody,
Optional<OrderBy> orderBy,
Optional<Offset> offset,
Optional<Node> limit)
{
this(Optional.empty(), functions, with, queryBody, orderBy, offset, limit);
this(Optional.empty(), functions, sessionProperties, with, queryBody, orderBy, offset, limit);
}

public Query(
NodeLocation location,
List<FunctionSpecification> functions,
List<SessionSpecification> sessionProperties,
Optional<With> with,
QueryBody queryBody,
Optional<OrderBy> orderBy,
Optional<Offset> offset,
Optional<Node> limit)
{
this(Optional.of(location), functions, with, queryBody, orderBy, offset, limit);
this(Optional.of(location), functions, sessionProperties, with, queryBody, orderBy, offset, limit);
}

private Query(
Optional<NodeLocation> location,
List<FunctionSpecification> functions,
List<SessionSpecification> sessionProperties,
Optional<With> with,
QueryBody queryBody,
Optional<OrderBy> orderBy,
Optional<Offset> offset,
Optional<Node> limit)
{
super(location);
requireNonNull(functions, "function si snull");
requireNonNull(functions, "functions is null");
requireNonNull(sessionProperties, "sessionProperties is null");
requireNonNull(with, "with is null");
requireNonNull(queryBody, "queryBody is null");
requireNonNull(orderBy, "orderBy is null");
Expand All @@ -75,6 +80,7 @@ private Query(
checkArgument(!limit.isPresent() || limit.get() instanceof FetchFirst || limit.get() instanceof Limit, "limit must be optional of either FetchFirst or Limit type");

this.functions = ImmutableList.copyOf(functions);
this.sessionProperties = ImmutableList.copyOf(sessionProperties);
this.with = with;
this.queryBody = queryBody;
this.orderBy = orderBy;
Expand All @@ -87,6 +93,11 @@ public List<FunctionSpecification> getFunctions()
return functions;
}

public List<SessionSpecification> getSessionProperties()
{
return sessionProperties;
}

public Optional<With> getWith()
{
return with;
Expand Down Expand Up @@ -123,6 +134,7 @@ public List<Node> getChildren()
{
ImmutableList.Builder<Node> nodes = ImmutableList.builder();
nodes.addAll(functions);
nodes.addAll(sessionProperties);
with.ifPresent(nodes::add);
nodes.add(queryBody);
orderBy.ifPresent(nodes::add);
Expand All @@ -136,6 +148,7 @@ public String toString()
{
return toStringHelper(this)
.add("functions", functions.isEmpty() ? null : functions)
.add("sessionProperties", sessionProperties.isEmpty() ? null : sessionProperties)
.add("with", with.orElse(null))
.add("queryBody", queryBody)
.add("orderBy", orderBy)
Expand All @@ -156,6 +169,7 @@ public boolean equals(Object obj)
}
Query o = (Query) obj;
return Objects.equals(functions, o.functions) &&
Objects.equals(sessionProperties, o.sessionProperties) &&
Objects.equals(with, o.with) &&
Objects.equals(queryBody, o.queryBody) &&
Objects.equals(orderBy, o.orderBy) &&
Expand All @@ -166,7 +180,7 @@ public boolean equals(Object obj)
@Override
public int hashCode()
{
return Objects.hash(functions, with, queryBody, orderBy, offset, limit);
return Objects.hash(functions, sessionProperties, with, queryBody, orderBy, offset, limit);
}

@Override
Expand Down
Loading

0 comments on commit a429c24

Please sign in to comment.