Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support setting session params in query #44

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions presto/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,12 @@ func (c *Client) UserPassword(user, password string) *Client {
return c
}

func (c *Client) GetSessionParams() string {
return c.getHeader(SessionHeader)
func (c *Client) GetSessionParams() map[string]any {
params := make(map[string]any)
for k, v := range c.sessionParams {
params[k] = v
}
return params
}

func (c *Client) ClearSessionParams() *Client {
Expand Down
185 changes: 185 additions & 0 deletions presto/query.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
package presto

import (
"bufio"
"context"
"io"
"net/http"
"pbench/presto/query_json"
"strconv"
"strings"
)

// QueryWithSession represents a query and its additional session parameters
type QueryWithSession struct {
Query string
SessionParams map[string]any
}

func (c *Client) requestQueryResults(ctx context.Context, req *http.Request) (*QueryResults, *http.Response, error) {
qr := new(QueryResults)
resp, err := c.Do(ctx, req, qr)
Expand Down Expand Up @@ -77,3 +86,179 @@ func (c *Client) GetQueryInfo(ctx context.Context, queryId string, pretty bool,
}
return queryInfo, resp, nil
}

// ParseSessionCommand checks if a query is a session parameter command and returns the parameter and value
// Format: --session parameter_name=parameter_value or --SET SESSION parameter_name=parameter_value
func ParseSessionCommand(query string) (paramName string, paramValue string, isSession bool) {
query = strings.TrimSpace(query)

// Check if query starts with --session or --set session (case insensitive)
queryLower := strings.ToLower(query)
if !strings.HasPrefix(queryLower, "--session") && !strings.HasPrefix(queryLower, "--set session") {
return "", "", false
}

// Remove the prefix
if strings.HasPrefix(queryLower, "--set session") {
query = strings.TrimSpace(query[13:]) // len("--set session") = 13
} else {
query = strings.TrimSpace(query[9:]) // len("--session") = 9
}

// Split on equals sign and handle spaces
parts := strings.SplitN(query, "=", 2)
if len(parts) != 2 {
return "", "", false
}

paramName = strings.ToLower(strings.TrimSpace(parts[0]))
paramValue = strings.TrimSpace(parts[1])

// Remove quotes if present
if strings.HasPrefix(paramValue, "'") && strings.HasSuffix(paramValue, "'") {
paramValue = paramValue[1:len(paramValue)-1]
}
if strings.HasPrefix(paramValue, "\"") && strings.HasSuffix(paramValue, "\"") {
paramValue = paramValue[1:len(paramValue)-1]
}

// Remove trailing semicolon if present
if strings.HasSuffix(paramValue, ";") {
paramValue = strings.TrimSuffix(paramValue, ";")
}

// Convert value to uppercase for enum values
paramValue = strings.ToUpper(paramValue)

return paramName, paramValue, true
}

// cleanQuery removes unnecessary whitespace, newlines, comments and trailing semicolon from a query
func cleanQuery(query string) string {
// Split into lines and handle each line
lines := strings.Split(query, "\n")
cleanLines := make([]string, 0, len(lines))

for _, line := range lines {
// Remove inline comments
if idx := strings.Index(line, "--"); idx >= 0 {
line = line[:idx]
}

trimmed := strings.TrimSpace(line)
if trimmed != "" {
cleanLines = append(cleanLines, trimmed)
}
}

// Join with single spaces
query = strings.Join(cleanLines, " ")

// Remove trailing semicolon
if strings.HasSuffix(query, ";") {
query = strings.TrimSuffix(query, ";")
}

return query
}

// SplitQueriesWithSession splits a SQL file into individual queries and their associated session parameters
func SplitQueriesWithSession(r io.Reader) ([]QueryWithSession, error) {
queries := make([]QueryWithSession, 0)
currentSessionParams := make(map[string]any)

scanner := bufio.NewScanner(r)
var currentQuery strings.Builder
inMultilineComment := false

for scanner.Scan() {
line := scanner.Text()
trimmedLine := strings.TrimSpace(line)

// Skip empty lines
if len(trimmedLine) == 0 {
continue
}

// Handle multiline comments
if strings.HasPrefix(trimmedLine, "/*") {
inMultilineComment = true
}
if inMultilineComment {
if strings.HasSuffix(trimmedLine, "*/") {
inMultilineComment = false
}
continue
}

// Handle single line comments and session parameters
if strings.HasPrefix(trimmedLine, "--") {
paramName, paramValue, isSession := ParseSessionCommand(trimmedLine)
if isSession {
// Try to parse value as number or boolean first
if val, err := strconv.ParseInt(paramValue, 10, 64); err == nil {
currentSessionParams[paramName] = val
} else if val, err := strconv.ParseFloat(paramValue, 64); err == nil {
currentSessionParams[paramName] = val
} else if val, err := strconv.ParseBool(paramValue); err == nil {
currentSessionParams[paramName] = val
} else {
// Remove any remaining quotes from string values
if strings.HasPrefix(paramValue, "'") && strings.HasSuffix(paramValue, "'") {
paramValue = paramValue[1:len(paramValue)-1]
}
if strings.HasPrefix(paramValue, "\"") && strings.HasSuffix(paramValue, "\"") {
paramValue = paramValue[1:len(paramValue)-1]
}
// Treat as string if not a number or boolean
currentSessionParams[paramName] = paramValue
}
}
continue
}

currentQuery.WriteString(line)
currentQuery.WriteString("\n")

// Check if line ends with semicolon
if strings.HasSuffix(trimmedLine, ";") {
query := strings.TrimSpace(currentQuery.String())
if len(query) > 0 {
// Clean up the query formatting
query = cleanQuery(query)

// Create a copy of current session parameters for this query
sessionParams := make(map[string]any, len(currentSessionParams))
for k, v := range currentSessionParams {
sessionParams[k] = v
}
queries = append(queries, QueryWithSession{
Query: query,
SessionParams: sessionParams,
})
// Clear session parameters after query
currentSessionParams = make(map[string]any)
}
currentQuery.Reset()
}
}

// Handle last query if it doesn't end with semicolon
lastQuery := strings.TrimSpace(currentQuery.String())
if len(lastQuery) > 0 {
sessionParams := make(map[string]any, len(currentSessionParams))
for k, v := range currentSessionParams {
sessionParams[k] = v
}
queries = append(queries, QueryWithSession{
Query: lastQuery,
SessionParams: sessionParams,
})
}

if err := scanner.Err(); err != nil {
return nil, err
}

return queries, nil
}
31 changes: 31 additions & 0 deletions presto/query_splitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,34 @@ another query;;missing semicolon, should be discarded
}
}
}

func TestSplitQueriesWithSession(t *testing.T) {
input := `/* header comment */
--SET SESSION join_reordering_strategy = 'NONE';
--session query_max_memory = '1GB'
--session max_splits_per_node = 1234
--session optimize_hash_generation = true
-- normal comment
SELECT
* -- inline comment
FROM
table1
WHERE
id > 0;`

expected := []presto.QueryWithSession{
{
Query: "SELECT * FROM table1 WHERE id > 0",
SessionParams: map[string]any{
"join_reordering_strategy": "NONE",
"query_max_memory": "1GB",
"max_splits_per_node": int64(1234),
"optimize_hash_generation": true,
},
},
}

queries, err := presto.SplitQueriesWithSession(strings.NewReader(input))
assert.NoError(t, err)
assert.Equal(t, expected, queries)
}
Loading