diff --git a/gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java b/gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java index aa56da8b3..1f52be6c3 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java +++ b/gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java @@ -39,6 +39,8 @@ import io.trino.gateway.ha.router.RoutingRulesManager; import io.trino.gateway.ha.router.StochasticRoutingManager; import io.trino.gateway.ha.security.AuthorizedExceptionMapper; +import io.trino.gateway.ha.security.QueryMetadataParser; +import io.trino.gateway.ha.security.QueryUserInfoParser; import io.trino.gateway.proxyserver.ForProxy; import io.trino.gateway.proxyserver.ProxyRequestHandler; import io.trino.gateway.proxyserver.RouteToBackendResource; @@ -187,6 +189,8 @@ private static void registerProxyResources(Binder binder) { jaxrsBinder(binder).bind(RouteToBackendResource.class); jaxrsBinder(binder).bind(RouterPreMatchContainerRequestFilter.class); + jaxrsBinder(binder).bind(QueryUserInfoParser.class); + jaxrsBinder(binder).bind(QueryMetadataParser.class); jaxrsBinder(binder).bind(ProxyRequestHandler.class); httpClientBinder(binder).bindHttpClient("proxy", ForProxy.class); httpClientBinder(binder).bindHttpClient("monitor", ForMonitor.class); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/HttpUtils.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/HttpUtils.java index 77053c567..bf683d1f4 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/HttpUtils.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/HttpUtils.java @@ -25,6 +25,8 @@ public class HttpUtils public static final String TRINO_UI_PATH = "/ui"; public static final String OAUTH_PATH = "/oauth2"; public static final String USER_HEADER = "X-Trino-User"; + public static final String TRINO_REQUEST_USER = "trinoRequestUser"; + public static final String TRINO_QUERY_PROPERTIES = "trinoQueryProperties"; private HttpUtils() {} } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java index e64ed07d8..afd314a48 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java @@ -30,6 +30,7 @@ import java.util.regex.Pattern; import static com.google.common.base.Strings.isNullOrEmpty; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES; import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH; import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH; import static java.nio.charset.StandardCharsets.UTF_8; @@ -78,7 +79,7 @@ public static Optional extractQueryIdIfPresent( throw new RuntimeException("Error reading request body", e); } if (!isNullOrEmpty(queryText) && queryText.toLowerCase(ENGLISH).contains("kill_query")) { - TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize); + TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES); return trinoQueryProperties.getQueryId(); } return Optional.empty(); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java index 9b3f03bc4..16521c44e 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java @@ -33,18 +33,10 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.regex.Pattern; import java.util.stream.Stream; import static com.google.common.base.Strings.isNullOrEmpty; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.gateway.ha.handler.HttpUtils.OAUTH_PATH; -import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH; -import static io.trino.gateway.ha.handler.HttpUtils.UI_API_STATS_PATH; import static io.trino.gateway.ha.handler.HttpUtils.USER_HEADER; -import static io.trino.gateway.ha.handler.HttpUtils.V1_INFO_PATH; -import static io.trino.gateway.ha.handler.HttpUtils.V1_NODE_PATH; -import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH; import static io.trino.gateway.ha.handler.ProxyUtils.buildUriWithNewCluster; import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent; import static java.util.Objects.requireNonNull; @@ -56,7 +48,6 @@ public class RoutingTargetHandler private final RoutingGroupSelector routingGroupSelector; private final String defaultRoutingGroup; private final List statementPaths; - private final List extraWhitelistPaths; private final boolean requestAnalyserClientsUseV2Format; private final int requestAnalyserMaxBodySize; private final boolean cookiesEnabled; @@ -71,7 +62,6 @@ public RoutingTargetHandler( this.routingGroupSelector = requireNonNull(routingGroupSelector); this.defaultRoutingGroup = haGatewayConfiguration.getRouting().getDefaultRoutingGroup(); statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths()); - extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList()); requestAnalyserClientsUseV2Format = haGatewayConfiguration.getRequestAnalyzerConfig().isClientsUseV2Format(); requestAnalyserMaxBodySize = haGatewayConfiguration.getRequestAnalyzerConfig().getMaxBodySize(); cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled(); @@ -118,18 +108,6 @@ private RoutingTargetResponse getRoutingTargetResponse(HttpServletRequest reques modifiedRequest); } - public boolean isPathWhiteListed(String path) - { - return statementPaths.stream().anyMatch(path::startsWith) - || path.startsWith(V1_QUERY_PATH) - || path.startsWith(TRINO_UI_PATH) - || path.startsWith(V1_INFO_PATH) - || path.startsWith(V1_NODE_PATH) - || path.startsWith(UI_API_STATS_PATH) - || path.startsWith(OAUTH_PATH) - || extraWhitelistPaths.stream().anyMatch(pattern -> pattern.matcher(path).matches()); - } - /** * A wrapper for HttpServletRequest that allows modifying multiple headers. */ diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java index 91594f9f1..2973655b8 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java @@ -47,6 +47,7 @@ import io.trino.gateway.ha.router.HaGatewayManager; import io.trino.gateway.ha.router.HaQueryHistoryManager; import io.trino.gateway.ha.router.HaResourceGroupsManager; +import io.trino.gateway.ha.router.PathFilter; import io.trino.gateway.ha.router.QueryHistoryManager; import io.trino.gateway.ha.router.ResourceGroupsManager; import io.trino.gateway.ha.router.RoutingGroupSelector; @@ -88,6 +89,7 @@ public class HaGatewayProviderModule private final ResourceGroupsManager resourceGroupsManager; private final GatewayBackendManager gatewayBackendManager; private final QueryHistoryManager queryHistoryManager; + private final PathFilter pathFilter; @Override protected void configure() @@ -97,11 +99,13 @@ protected void configure() binder().bind(GatewayBackendManager.class).toInstance(gatewayBackendManager); binder().bind(QueryHistoryManager.class).toInstance(queryHistoryManager); binder().bind(BackendStateManager.class).in(Scopes.SINGLETON); + binder().bind(PathFilter.class).toInstance(pathFilter); } public HaGatewayProviderModule(HaGatewayConfiguration configuration) { this.configuration = requireNonNull(configuration, "configuration is null"); + pathFilter = new PathFilter(configuration.getStatementPaths(), configuration.getExtraWhitelistPaths()); Map presetUsers = configuration.getPresetUsers(); oauthManager = getOAuthManager(configuration); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java index e5f7aa674..a7821e5b1 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java @@ -46,6 +46,8 @@ import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; import static io.airlift.http.client.Request.Builder.preparePost; import static io.airlift.json.JsonCodec.jsonCodec; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER; import static java.util.Collections.list; import static java.util.Objects.requireNonNull; @@ -58,7 +60,6 @@ public class ExternalRoutingGroupSelector private final boolean propagateErrors; private final HttpClient httpClient; private final RequestAnalyzerConfig requestAnalyzerConfig; - private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider; private static final JsonCodec ROUTING_GROUP_EXTERNAL_BODY_JSON_CODEC = jsonCodec(RoutingGroupExternalBody.class); private static final JsonResponseHandler ROUTING_GROUP_EXTERNAL_RESPONSE_JSON_RESPONSE_HANDLER = createJsonResponseHandler(jsonCodec(ExternalRouterResponse.class)); @@ -74,7 +75,6 @@ public class ExternalRoutingGroupSelector propagateErrors = rulesExternalConfiguration.isPropagateErrors(); this.requestAnalyzerConfig = requestAnalyzerConfig; - trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig); try { this.uri = new URI(requireNonNull(rulesExternalConfiguration.getUrlPath(), "Invalid URL provided, using routing group header as default.")); @@ -143,8 +143,8 @@ private RoutingGroupExternalBody createRequestBody(HttpServletRequest request) TrinoQueryProperties trinoQueryProperties = null; TrinoRequestUser trinoRequestUser = null; if (requestAnalyzerConfig.isAnalyzeRequest()) { - trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig.isClientsUseV2Format(), requestAnalyzerConfig.getMaxBodySize()); - trinoRequestUser = trinoRequestUserProvider.getInstance(request); + trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES); + trinoRequestUser = (TrinoRequestUser) request.getAttribute(TRINO_REQUEST_USER); } return new RoutingGroupExternalBody( diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/FileBasedRoutingGroupSelector.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/FileBasedRoutingGroupSelector.java index 188013c87..96106b889 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/FileBasedRoutingGroupSelector.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/FileBasedRoutingGroupSelector.java @@ -33,6 +33,8 @@ import java.util.Map; import static com.google.common.base.Suppliers.memoizeWithExpiration; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Collections.sort; @@ -46,16 +48,10 @@ public class FileBasedRoutingGroupSelector private final Supplier> rules; private final boolean analyzeRequest; - private final boolean clientsUseV2Format; - private final int maxBodySize; - private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider; public FileBasedRoutingGroupSelector(String rulesPath, Duration rulesRefreshPeriod, RequestAnalyzerConfig requestAnalyzerConfig) { analyzeRequest = requestAnalyzerConfig.isAnalyzeRequest(); - clientsUseV2Format = requestAnalyzerConfig.isClientsUseV2Format(); - maxBodySize = requestAnalyzerConfig.getMaxBodySize(); - trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig); rules = memoizeWithExpiration(() -> readRulesFromPath(Path.of(rulesPath)), rulesRefreshPeriod.toJavaTime()); } @@ -68,12 +64,9 @@ public RoutingSelectorResponse findRoutingDestination(HttpServletRequest request Map data; if (analyzeRequest) { - TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties( - request, - clientsUseV2Format, - maxBodySize); - TrinoRequestUser trinoRequestUser = trinoRequestUserProvider.getInstance(request); - data = ImmutableMap.of("request", request, "trinoQueryProperties", trinoQueryProperties, "trinoRequestUser", trinoRequestUser); + TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES); + TrinoRequestUser trinoRequestUser = (TrinoRequestUser) request.getAttribute(TRINO_REQUEST_USER); + data = ImmutableMap.of("request", request, TRINO_QUERY_PROPERTIES, trinoQueryProperties, TRINO_REQUEST_USER, trinoRequestUser); } else { data = ImmutableMap.of("request", request); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/PathFilter.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/PathFilter.java new file mode 100644 index 000000000..a4e197e89 --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/PathFilter.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.router; + +import com.google.inject.Inject; +import com.google.inject.Singleton; + +import java.util.List; +import java.util.Set; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.gateway.ha.handler.HttpUtils.OAUTH_PATH; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH; +import static io.trino.gateway.ha.handler.HttpUtils.UI_API_STATS_PATH; +import static io.trino.gateway.ha.handler.HttpUtils.V1_INFO_PATH; +import static io.trino.gateway.ha.handler.HttpUtils.V1_NODE_PATH; +import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH; +import static java.util.Objects.requireNonNull; + +/** + * A filter component that determines whether a given path should be whitelisted + * for routing to Trino clusters. + */ +@Singleton +public class PathFilter +{ + private final Set statementPaths; + private final List extraWhitelistPatterns; + + @Inject + public PathFilter( + List statementPaths, + List extraWhitelistPaths) + { + this.statementPaths = Set.copyOf(requireNonNull(statementPaths, "Required configuration 'statementPaths' can't be null")); + this.extraWhitelistPatterns = requireNonNull(extraWhitelistPaths, "extraWhitelistPaths cannot be null").stream() + .map(pattern -> { + try { + return Pattern.compile(pattern); + } + catch (PatternSyntaxException e) { + throw new IllegalArgumentException("Invalid regex pattern: " + pattern, e); + } + }) + .collect(toImmutableList()); + } + + /** + * Determines if the given path is whitelisted for routing to backend. + * + * @param path the request path to check + * @return true if the path should be routed to backend, false otherwise + */ + public boolean isPathWhiteListed(String path) + { + return statementPaths.stream().anyMatch(path::startsWith) + || path.startsWith(V1_QUERY_PATH) + || path.startsWith(TRINO_UI_PATH) + || path.startsWith(V1_INFO_PATH) + || path.startsWith(V1_NODE_PATH) + || path.startsWith(UI_API_STATS_PATH) + || path.startsWith(OAUTH_PATH) + || extraWhitelistPatterns.stream().anyMatch(pattern -> pattern.matcher(path).matches()); + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java index 52260dbe8..105363a49 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java @@ -65,13 +65,17 @@ import io.trino.sql.tree.Table; import io.trino.sql.tree.TableFunctionInvocation; import io.trino.sql.tree.WithQuery; -import jakarta.servlet.http.HttpServletRequest; import jakarta.ws.rs.HttpMethod; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.core.MediaType; import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; import java.net.URLDecoder; import java.util.ArrayList; +import java.util.Collections; import java.util.Enumeration; import java.util.HashSet; import java.util.List; @@ -142,35 +146,44 @@ public TrinoQueryProperties( maxBodySize = -1; } - public TrinoQueryProperties(HttpServletRequest request, boolean isClientsUseV2Format, int maxBodySize) + public TrinoQueryProperties() { - requireNonNull(request, "request is null"); + this("", "", "", ImmutableList.of(), Optional.empty(), Optional.empty(), + ImmutableSet.of(), ImmutableSet.of(), ImmutableSet.of(), false, Optional.empty()); + } + + public TrinoQueryProperties(ContainerRequestContext requestContext, boolean isClientsUseV2Format, int maxBodySize) + { + requireNonNull(requestContext, "requestContext is null"); this.isClientsUseV2Format = isClientsUseV2Format; this.maxBodySize = maxBodySize; - defaultCatalog = Optional.ofNullable(request.getHeader(TRINO_CATALOG_HEADER_NAME)); - defaultSchema = Optional.ofNullable(request.getHeader(TRINO_SCHEMA_HEADER_NAME)); - if (request.getMethod().equals(HttpMethod.POST)) { + defaultCatalog = Optional.ofNullable(requestContext.getHeaderString(TRINO_CATALOG_HEADER_NAME)); + defaultSchema = Optional.ofNullable(requestContext.getHeaderString(TRINO_SCHEMA_HEADER_NAME)); + if (requestContext.getMethod().equals(HttpMethod.POST)) { isNewQuerySubmission = true; - processRequestBody(request); + processRequestBody(requestContext); } } - private void processRequestBody(HttpServletRequest request) + private void processRequestBody(BufferedReader reader, Map preparedStatements) { - try (BufferedReader reader = request.getReader()) { + try (reader) { if (reader == null) { log.warn("HTTP request returned null reader"); body = ""; return; } - Map preparedStatements = getPreparedStatements(request); SqlParser parser = new SqlParser(); reader.mark(maxBodySize); char[] buffer = new char[maxBodySize]; int nChars = reader.read(buffer, 0, maxBodySize); reader.reset(); + if (nChars <= 0) { + log.warn("query text is empty"); + return; + } if (nChars == maxBodySize) { log.warn("Query length greater or equal to requestAnalyzerConfig.maxBodySize detected"); return; @@ -238,11 +251,51 @@ else if (statement instanceof ExecuteImmediate executeImmediate) { } } - private Map getPreparedStatements(HttpServletRequest request) + private void processRequestBody(ContainerRequestContext requestContext) + { + if (!requestContext.hasEntity()) { + return; + } + + MediaType mediaType = requestContext.getMediaType(); + if (mediaType == null) { + return; + } + + String charset = mediaType.getParameters().get("charset"); + if (charset == null) { + log.debug("charset is not set in the request"); + return; + } + + if (!UTF_8.name().equalsIgnoreCase(charset)) { + log.debug("Request charset is not UTF-8 (%s), skipping query parsing", charset); + return; + } + + InputStream entityStream = requestContext.getEntityStream(); + try (InputStreamReader entityReader = new InputStreamReader(entityStream, UTF_8); + BufferedReader reader = new BufferedReader(entityReader)) { + processRequestBody(reader, getPreparedStatements(requestContext)); + } + catch (IOException e) { + log.warn("Error extracting request body for rules processing: %s", e.getMessage()); + errorMessage = Optional.of(e.getMessage()); + } + catch (ParsingException e) { + log.info("Could not parse request body as SQL: %s; Message: %s", body, e.getMessage()); + errorMessage = Optional.of(e.getMessage()); + } + catch (RequestParsingException e) { + log.warn(e, "Error parsing request for rules"); + errorMessage = Optional.of(e.getMessage()); + } + } + + private Map getPreparedStatements(Enumeration headers) throws RequestParsingException { ImmutableMap.Builder preparedStatementsMapBuilder = ImmutableMap.builder(); - Enumeration headers = request.getHeaders(TRINO_PREPARED_STATEMENT_HEADER_NAME); if (headers == null) { return preparedStatementsMapBuilder.build(); } @@ -259,6 +312,19 @@ private Map getPreparedStatements(HttpServletRequest request) return preparedStatementsMapBuilder.build(); } + private Map getPreparedStatements(ContainerRequestContext requestContext) + throws RequestParsingException + { + if (requestContext.getHeaders() == null) { + return ImmutableMap.of(); + } + List headers = requestContext.getHeaders().get(TRINO_PREPARED_STATEMENT_HEADER_NAME); + if (headers == null || headers.isEmpty()) { + return ImmutableMap.of(); + } + return getPreparedStatements(Collections.enumeration(headers)); + } + private String decodePreparedStatementFromHeader(String headerValue) { // From io.trino.server.protocol.PreparedStatementEncoder diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoRequestUser.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoRequestUser.java index e9d693b9f..449a198f1 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoRequestUser.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoRequestUser.java @@ -33,14 +33,14 @@ import com.nimbusds.openid.connect.sdk.claims.UserInfo; import io.airlift.log.Logger; import io.trino.gateway.ha.config.RequestAnalyzerConfig; -import jakarta.servlet.http.Cookie; -import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.core.HttpHeaders; import java.io.IOException; import java.net.URI; import java.nio.charset.StandardCharsets; -import java.util.Arrays; import java.util.Base64; +import java.util.Map; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -62,7 +62,7 @@ public class TrinoRequestUser private final Optional> userInfoCache; - private TrinoRequestUser(HttpServletRequest request, String userField, Optional> userInfoCache) + private TrinoRequestUser(ContainerRequestContext request, String userField, Optional> userInfoCache) { this.userInfoCache = requireNonNull(userInfoCache); user = extractUser(request, userField); @@ -106,15 +106,17 @@ public boolean userExistsAndEquals(String testUser) return user.filter(testUser::equals).isPresent(); } - private Optional extractUserFromCookies(HttpServletRequest request, String userField) + private Optional extractUserFromCookies(ContainerRequestContext requestContext, String userField) { - if (request.getCookies() == null) { + Map cookies = requestContext.getCookies(); + if (cookies == null || cookies.isEmpty()) { + log.debug("cookies are empty"); return Optional.empty(); } - log.debug("Trying to get user from cookie"); - Optional uiToken = Arrays.stream(request.getCookies()) - .filter(cookie -> cookie.getName().equals(TRINO_UI_TOKEN_NAME) || cookie.getName().equals(TRINO_SECURE_UI_TOKEN_NAME)) - .findAny(); + + log.debug("Trying to get user from cookie from ContainerRequestContext"); + Optional uiToken = Optional.ofNullable(cookies.get(TRINO_UI_TOKEN_NAME)) + .or(() -> Optional.ofNullable(cookies.get(TRINO_SECURE_UI_TOKEN_NAME))); return uiToken.map(t -> { try { @@ -129,20 +131,20 @@ private Optional extractUserFromCookies(HttpServletRequest request, Stri }); } - private Optional extractUser(HttpServletRequest request, String userField) + private Optional extractUser(ContainerRequestContext requestContext, String userField) { String header; - header = request.getHeader(TRINO_USER_HEADER_NAME); + header = requestContext.getHeaderString(TRINO_USER_HEADER_NAME); if (header != null) { return Optional.of(header); } - Optional user = extractUserFromAuthorizationHeader(request.getHeader("Authorization"), userField); + Optional user = extractUserFromAuthorizationHeader(requestContext.getHeaderString(HttpHeaders.AUTHORIZATION), userField); if (user.isPresent()) { return user; } - return extractUserFromCookies(request, userField); + return extractUserFromCookies(requestContext, userField); } private Optional extractUserFromAuthorizationHeader(String header, String userField) @@ -225,7 +227,7 @@ public TrinoRequestUserProvider(RequestAnalyzerConfig config) } } - public TrinoRequestUser getInstance(HttpServletRequest request) + public TrinoRequestUser getInstance(ContainerRequestContext request) { return new TrinoRequestUser(request, userField, userInfoCache); } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/QueryMetadataParser.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/QueryMetadataParser.java new file mode 100644 index 000000000..f5ab5b49b --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/QueryMetadataParser.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.security; + +import com.google.inject.Inject; +import io.airlift.log.Logger; +import io.trino.gateway.ha.config.HaGatewayConfiguration; +import io.trino.gateway.ha.config.RequestAnalyzerConfig; +import io.trino.gateway.ha.router.PathFilter; +import io.trino.gateway.ha.router.TrinoQueryProperties; +import io.trino.gateway.ha.security.util.GatewayFilterPriorities; +import jakarta.annotation.Priority; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerRequestFilter; +import jakarta.ws.rs.container.PreMatching; +import org.glassfish.jersey.server.ContainerRequest; + +import java.io.IOException; + +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES; + +/** + * + * This filter parses the query statement and stores the TrinoQueryProperties object + * as a property to be accessed in the later processing + */ +@PreMatching +@Priority(GatewayFilterPriorities.PRE_AUTHORIZATION) +public class QueryMetadataParser + implements ContainerRequestFilter +{ + private static final Logger log = Logger.get(QueryMetadataParser.class); + private static final int MAX_QUERY_TEXT_LOG_LENGTH = 100; + private final boolean isAnalyzeRequest; + private final boolean isClientsUseV2Format; + private final int maxBodySize; + private final PathFilter pathFilter; + + @Inject + public QueryMetadataParser(HaGatewayConfiguration config, PathFilter pathFilter) + { + RequestAnalyzerConfig analyzerConfig = config.getRequestAnalyzerConfig(); + this.isAnalyzeRequest = analyzerConfig.isAnalyzeRequest(); + this.isClientsUseV2Format = analyzerConfig.isClientsUseV2Format(); + this.maxBodySize = analyzerConfig.getMaxBodySize(); + this.pathFilter = pathFilter; + } + + @Override + public void filter(ContainerRequestContext requestContext) + throws IOException + { + String path = requestContext.getUriInfo().getRequestUri().getPath(); + if (path == null || !isAnalyzeRequest || !pathFilter.isPathWhiteListed(path)) { + return; + } + + log.debug("Processing query metadata for path: %s", path); + // Buffer the entity (aka body of the request) for future reads during request processing + ContainerRequest jerseyRequest = (ContainerRequest) requestContext; + jerseyRequest.bufferEntity(); + + TrinoQueryProperties queryProps; + try { + queryProps = new TrinoQueryProperties(requestContext, isClientsUseV2Format, maxBodySize); + } + catch (Exception ex) { + log.warn(ex, "Failed to parse query properties for query text: [%s]. Error: %s. Using empty properties.", + getQueryTextForLogging(requestContext), ex.getMessage()); + queryProps = new TrinoQueryProperties(); + } + + requestContext.setProperty(TRINO_QUERY_PROPERTIES, queryProps); + } + + private String getQueryTextForLogging(ContainerRequestContext requestContext) + { + try { + ContainerRequest jerseyRequest = (ContainerRequest) requestContext; + + String body = jerseyRequest.readEntity(String.class); + if (body == null || body.isEmpty()) { + return ""; + } + else if (body.length() > MAX_QUERY_TEXT_LOG_LENGTH) { + return body.substring(0, MAX_QUERY_TEXT_LOG_LENGTH) + "..."; + } + + return body; + } + + catch (Exception e) { + log.error(e, "unable to read query text"); + return ""; + } + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/QueryUserInfoParser.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/QueryUserInfoParser.java new file mode 100644 index 000000000..6fa9e0f35 --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/QueryUserInfoParser.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.gateway.ha.security; + +import com.google.inject.Inject; +import io.airlift.log.Logger; +import io.trino.gateway.ha.config.HaGatewayConfiguration; +import io.trino.gateway.ha.config.RequestAnalyzerConfig; +import io.trino.gateway.ha.router.PathFilter; +import io.trino.gateway.ha.router.TrinoRequestUser; +import io.trino.gateway.ha.security.util.GatewayFilterPriorities; +import jakarta.annotation.Priority; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerRequestFilter; +import jakarta.ws.rs.container.PreMatching; + +import java.io.IOException; + +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER; + +/** + * A filter which parses and extracts Trino user identity from incoming request headers + * and stores it in the request context property TRINO_REQUEST_USER + * for downstream filters and handlers to use. + */ + +@PreMatching +@Priority(GatewayFilterPriorities.PRE_AUTHENTICATION) +public class QueryUserInfoParser + implements ContainerRequestFilter +{ + private static final Logger log = Logger.get(QueryUserInfoParser.class); + private final String tokenUserField; + private final String oauthTokenInfoUrl; + private final PathFilter pathFilter; + + @Inject + public QueryUserInfoParser(HaGatewayConfiguration config, PathFilter pathFilter) + { + RequestAnalyzerConfig requestAnalyzerConfig = config.getRequestAnalyzerConfig(); + this.tokenUserField = requestAnalyzerConfig.getTokenUserField(); + this.oauthTokenInfoUrl = requestAnalyzerConfig.getOauthTokenInfoUrl(); + this.pathFilter = pathFilter; + } + + @Override + public void filter(ContainerRequestContext requestContext) + throws IOException + { + String path = requestContext.getUriInfo().getRequestUri().getPath(); + if (!pathFilter.isPathWhiteListed(path)) { + return; + } + + RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig(); + requestAnalyzerConfig.setTokenUserField(tokenUserField); + requestAnalyzerConfig.setOauthTokenInfoUrl(oauthTokenInfoUrl); + + TrinoRequestUser user = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(requestContext); + requestContext.setProperty(TRINO_REQUEST_USER, user); + log.debug("Parsed user %s", user.getUser().orElse("None")); + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/util/GatewayFilterPriorities.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/util/GatewayFilterPriorities.java new file mode 100644 index 000000000..fee81e68e --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/util/GatewayFilterPriorities.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.security.util; + +public final class GatewayFilterPriorities +{ + private GatewayFilterPriorities() {} + + public static final int PRE_AUTHENTICATION = 500; + public static final int PRE_AUTHORIZATION = 1500; +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java index dc74f0410..8da9713ca 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java @@ -63,6 +63,7 @@ import static io.airlift.http.client.Request.Builder.preparePut; import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER; import static io.trino.gateway.ha.handler.ProxyUtils.SOURCE_HEADER; import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE; import static jakarta.ws.rs.core.Response.Status.BAD_GATEWAY; @@ -87,7 +88,6 @@ public class ProxyRequestHandler private final boolean addXForwardedHeaders; private final List statementPaths; private final boolean includeClusterInfoInResponse; - private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider; private final ProxyResponseConfiguration proxyResponseConfiguration; @Inject @@ -100,7 +100,6 @@ public ProxyRequestHandler( this.httpClient = requireNonNull(httpClient, "httpClient is null"); this.routingManager = requireNonNull(routingManager, "routingManager is null"); this.queryHistoryManager = requireNonNull(queryHistoryManager, "queryHistoryManager is null"); - trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(haGatewayConfiguration.getRequestAnalyzerConfig()); cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled(); asyncTimeout = haGatewayConfiguration.getRouting().getAsyncTimeout(); addXForwardedHeaders = haGatewayConfiguration.getRouting().isAddXForwardedHeaders(); @@ -189,7 +188,7 @@ private void performRequest( FluentFuture future = executeHttp(request); if (statementPaths.stream().anyMatch(request.getUri().getPath()::startsWith) && request.getMethod().equals(HttpMethod.POST)) { - Optional username = trinoRequestUserProvider.getInstance(servletRequest).getUser(); + Optional username = ((TrinoRequestUser) servletRequest.getAttribute(TRINO_REQUEST_USER)).getUser(); future = future.transform(response -> recordBackendForQueryId(request, response, username, routingDestination), executor); if (includeClusterInfoInResponse) { cookieBuilder.add(new NewCookie.Builder("trinoClusterHost").value(remoteUri.getHost()).build()); diff --git a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/RouterPreMatchContainerRequestFilter.java b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/RouterPreMatchContainerRequestFilter.java index 838181ddb..69a439df8 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/RouterPreMatchContainerRequestFilter.java +++ b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/RouterPreMatchContainerRequestFilter.java @@ -14,7 +14,7 @@ package io.trino.gateway.proxyserver; import com.google.inject.Inject; -import io.trino.gateway.ha.handler.RoutingTargetHandler; +import io.trino.gateway.ha.router.PathFilter; import jakarta.ws.rs.container.ContainerRequestContext; import jakarta.ws.rs.container.ContainerRequestFilter; import jakarta.ws.rs.container.PreMatching; @@ -22,8 +22,6 @@ import java.io.IOException; import java.net.URI; -import static java.util.Objects.requireNonNull; - /** * This pre-matching ContainerRequestFilter catches all requests and forwards * those that need to be routed to a Trino backend to {@link RouteToBackendResource}. @@ -35,19 +33,19 @@ public class RouterPreMatchContainerRequestFilter { public static final String ROUTE_TO_BACKEND = "/trino-gateway/internal/route_to_backend"; - private final RoutingTargetHandler routingTargetHandler; + private final PathFilter pathFilter; @Inject - public RouterPreMatchContainerRequestFilter(RoutingTargetHandler routingTargetHandler) + public RouterPreMatchContainerRequestFilter(PathFilter pathFilter) { - this.routingTargetHandler = requireNonNull(routingTargetHandler); + this.pathFilter = pathFilter; } @Override public void filter(ContainerRequestContext request) throws IOException { - if (routingTargetHandler.isPathWhiteListed(request.getUriInfo().getRequestUri().getPath())) { + if (pathFilter.isPathWhiteListed(request.getUriInfo().getRequestUri().getPath())) { request.setRequestUri(URI.create(ROUTE_TO_BACKEND)); } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java b/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java index 9ecc3c570..89d0074d2 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java @@ -14,27 +14,19 @@ package io.trino.gateway.ha.handler; import com.google.common.collect.ImmutableList; -import jakarta.servlet.ReadListener; -import jakarta.servlet.ServletInputStream; +import io.trino.gateway.ha.config.RequestAnalyzerConfig; +import io.trino.gateway.ha.util.QueryRequestMock; import jakarta.servlet.http.HttpServletRequest; -import jakarta.ws.rs.HttpMethod; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; -import org.mockito.Mockito; -import java.io.BufferedReader; -import java.io.ByteArrayInputStream; import java.io.IOException; -import java.io.StringReader; import java.util.List; import java.util.Optional; import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent; -import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.when; @TestInstance(Lifecycle.PER_CLASS) final class TestQueryIdCachingProxyHandler @@ -171,12 +163,6 @@ CALL KILL_QUERY ('20200416_160256_03078_6b4yt', 'If he dies, he dies') "system", "runtime"))).hasValue("20200416_160256_03078_6b4yt"); - assertThatThrownBy(() -> extractQueryId(request(""" - CALL KILL_QUERY (lower('20200416_160256_03078_6b4yt'), 'If he dies, he dies') - """, - "system", - "runtime"))).isInstanceOf(IllegalArgumentException.class); - assertThat(extractQueryId(request("CALL notsystem.runtime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')"))).isEmpty(); assertThat(extractQueryId(request("CALL runtime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')", "notsystem"))) @@ -200,63 +186,35 @@ private static Optional extractQueryId(HttpServletRequest request) private static HttpServletRequest request(String query, String defaultCatalog) throws IOException { - // Warning - this is not a fully featured mock of the behavior of HttpServlet with respect to headers. For example, - // getHeaderNames will return an empty list, and getHeader is not fully case-insensitive. This is only intended to be - // a minimal mock for this test. - HttpServletRequest request = request(query); - when(request.getHeader("X-Trino-Catalog")).thenReturn(defaultCatalog); - when(request.getHeader("X-trino-catalog")).thenReturn(defaultCatalog); - return request; + RequestAnalyzerConfig config = new RequestAnalyzerConfig(); + config.setAnalyzeRequest(true); + return new QueryRequestMock().query(query) + .httpHeader("X-Trino-Catalog", defaultCatalog) + .requestAnalyzerConfig(config) + .getHttpServletRequest(); } private static HttpServletRequest request(String query, String defaultCatalog, String defaultSchema) throws IOException { - HttpServletRequest request = request(query); - when(request.getHeader("X-Trino-Catalog")).thenReturn(defaultCatalog); - when(request.getHeader("X-trino-catalog")).thenReturn(defaultCatalog); - when(request.getHeader("X-Trino-Schema")).thenReturn(defaultSchema); - when(request.getHeader("X-trino-schema")).thenReturn(defaultSchema); - return request; + RequestAnalyzerConfig config = new RequestAnalyzerConfig(); + config.setAnalyzeRequest(true); + return new QueryRequestMock().query(query) + .httpHeader("X-Trino-Catalog", defaultCatalog) + .httpHeader("X-trino-catalog", defaultCatalog) + .httpHeader("X-Trino-Schema", defaultSchema) + .httpHeader("X-trino-schema", defaultSchema) + .requestAnalyzerConfig(config) + .getHttpServletRequest(); } private static HttpServletRequest request(String query) throws IOException { - HttpServletRequest request = Mockito.mock(HttpServletRequest.class); - - ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(query.getBytes(UTF_8)); - when(request.getMethod()).thenReturn(HttpMethod.POST); - when(request.getInputStream()).thenReturn(new ServletInputStream() - { - @Override - public boolean isFinished() - { - return byteArrayInputStream.available() > 0; - } - - @Override - public boolean isReady() - { - return true; - } - - @Override - public void setReadListener(ReadListener readListener) - {} - - @Override - public int read() - throws IOException - { - return byteArrayInputStream.read(); - } - }); - - when(request.getReader()).thenReturn(new BufferedReader(new StringReader(query))); - - when(request.getQueryString()).thenReturn(""); - - return request; + RequestAnalyzerConfig config = new RequestAnalyzerConfig(); + config.setAnalyzeRequest(true); + return new QueryRequestMock().query(query) + .requestAnalyzerConfig(config) + .getHttpServletRequest(); } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestExternalRoutingGroupSelector.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestExternalRoutingGroupSelector.java index bae7184ff..d5092f09a 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestExternalRoutingGroupSelector.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestExternalRoutingGroupSelector.java @@ -26,8 +26,8 @@ import io.trino.gateway.ha.router.schema.ExternalRouterResponse; import io.trino.gateway.ha.router.schema.RoutingGroupExternalBody; import io.trino.gateway.ha.router.schema.RoutingSelectorResponse; +import io.trino.gateway.ha.util.QueryRequestMock; import jakarta.servlet.http.HttpServletRequest; -import jakarta.ws.rs.HttpMethod; import jakarta.ws.rs.WebApplicationException; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -37,10 +37,10 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.URI; -import java.net.URISyntaxException; import java.util.Collections; import java.util.Enumeration; import java.util.List; @@ -49,10 +49,9 @@ import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; import static io.airlift.http.client.Request.Builder.preparePost; import static io.airlift.json.JsonCodec.jsonCodec; -import static io.trino.gateway.ha.handler.HttpUtils.USER_HEADER; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER; import static io.trino.gateway.ha.router.RoutingGroupSelector.ROUTING_GROUP_HEADER; -import static io.trino.gateway.ha.router.TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME; -import static io.trino.gateway.ha.router.TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; @@ -89,7 +88,7 @@ static RulesExternalConfiguration provideRoutingRuleExternalConfig() @Test void testByRoutingRulesExternalEngine() - throws URISyntaxException + throws Exception { RulesExternalConfiguration rulesExternalConfiguration = provideRoutingRuleExternalConfig(); HttpServletRequest mockRequest = prepareMockRequest(); @@ -133,6 +132,7 @@ void testByRoutingRulesExternalEngine() @Test void testFallbackToHeaderOnApiFailure() + throws IOException { // Mock this specific test an HTTP request HttpClient httpClient = mock(HttpClient.class); @@ -374,17 +374,13 @@ void testPropagateErrorsTrueResponseWithErrors() private HttpServletRequest prepareMockRequest() { - HttpServletRequest mockRequest = mock(HttpServletRequest.class); - when(mockRequest.getMethod()).thenReturn(HttpMethod.POST); - return mockRequest; + return new QueryRequestMock() + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); } private void setMockHeaders(HttpServletRequest mockRequest) { - when(mockRequest.getHeader(TRINO_CATALOG_HEADER_NAME)).thenReturn("default"); - when(mockRequest.getHeader(TRINO_SCHEMA_HEADER_NAME)).thenReturn("test"); - when(mockRequest.getHeader(USER_HEADER)).thenReturn("user"); - List defaultHeaderNames = List.of("Accept-Encoding"); List defaultAcceptEncodingValues = List.of("gzip", "deflate", "br"); Enumeration headerNamesEnumeration = Collections.enumeration(defaultHeaderNames); @@ -400,11 +396,8 @@ private RoutingGroupExternalBody createRequestBody(HttpServletRequest request) TrinoQueryProperties trinoQueryProperties = null; TrinoRequestUser trinoRequestUser = null; if (requestAnalyzerConfig.isAnalyzeRequest()) { - trinoQueryProperties = new TrinoQueryProperties( - request, - requestAnalyzerConfig.isClientsUseV2Format(), - requestAnalyzerConfig.getMaxBodySize()); - trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(request); + trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES); + trinoRequestUser = (TrinoRequestUser) request.getAttribute(TRINO_REQUEST_USER); } return new RoutingGroupExternalBody( diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestPathFilter.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestPathFilter.java new file mode 100644 index 000000000..286582e93 --- /dev/null +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestPathFilter.java @@ -0,0 +1,219 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.router; + +import com.google.common.collect.ImmutableList; +import io.trino.gateway.ha.config.HaGatewayConfiguration; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static io.trino.gateway.ha.handler.HttpUtils.OAUTH_PATH; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH; +import static io.trino.gateway.ha.handler.HttpUtils.UI_API_STATS_PATH; +import static io.trino.gateway.ha.handler.HttpUtils.V1_INFO_PATH; +import static io.trino.gateway.ha.handler.HttpUtils.V1_NODE_PATH; +import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH; +import static io.trino.gateway.ha.handler.HttpUtils.V1_STATEMENT_PATH; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TestPathFilter +{ + private PathFilter pathFilter; + + TestPathFilter() + { + List statementPaths = ImmutableList.of(V1_STATEMENT_PATH, "/v2/statement"); + List extraWhitelistPaths = ImmutableList.of( + "/api/v1/custom/.*", + "/health/.*", + "/metrics"); + pathFilter = new PathFilter(statementPaths, extraWhitelistPaths); + } + + @Test + void testHardcodedTrinoQueryPath() + { + assertThat(pathFilter.isPathWhiteListed(V1_QUERY_PATH)).isTrue(); + assertThat(pathFilter.isPathWhiteListed(V1_QUERY_PATH + "/query123")).isTrue(); + assertThat(pathFilter.isPathWhiteListed(V1_QUERY_PATH + "/query123/status")).isTrue(); + assertThat(pathFilter.isPathWhiteListed(TRINO_UI_PATH)).isTrue(); + assertThat(pathFilter.isPathWhiteListed(TRINO_UI_PATH + "/query.html")).isTrue(); + assertThat(pathFilter.isPathWhiteListed(TRINO_UI_PATH + "/assets/app.js")).isTrue(); + assertThat(pathFilter.isPathWhiteListed(V1_INFO_PATH)).isTrue(); + assertThat(pathFilter.isPathWhiteListed(V1_INFO_PATH + "/status")).isTrue(); + assertThat(pathFilter.isPathWhiteListed(V1_NODE_PATH)).isTrue(); + assertThat(pathFilter.isPathWhiteListed(V1_NODE_PATH + "/node123")).isTrue(); + assertThat(pathFilter.isPathWhiteListed(V1_NODE_PATH + "/node123/status")).isTrue(); + assertThat(pathFilter.isPathWhiteListed(UI_API_STATS_PATH)).isTrue(); + assertThat(pathFilter.isPathWhiteListed(UI_API_STATS_PATH + "/running")).isTrue(); + assertThat(pathFilter.isPathWhiteListed(UI_API_STATS_PATH + "/completed")).isTrue(); + assertThat(pathFilter.isPathWhiteListed(OAUTH_PATH)).isTrue(); + assertThat(pathFilter.isPathWhiteListed(OAUTH_PATH + "/callback")).isTrue(); + assertThat(pathFilter.isPathWhiteListed(OAUTH_PATH + "/token")).isTrue(); + } + + @Test + void testConfiguredStatementPaths() + { + // Test V1 statement path + assertThat(pathFilter.isPathWhiteListed(V1_STATEMENT_PATH)).isTrue(); + assertThat(pathFilter.isPathWhiteListed(V1_STATEMENT_PATH + "/executing")).isTrue(); + assertThat(pathFilter.isPathWhiteListed(V1_STATEMENT_PATH + "/queued")).isTrue(); + + // Test V2 statement path (from our configuration) + assertThat(pathFilter.isPathWhiteListed("/v2/statement")).isTrue(); + assertThat(pathFilter.isPathWhiteListed("/v2/statement/query456")).isTrue(); + assertThat(pathFilter.isPathWhiteListed("/v2/statement/batch")).isTrue(); + } + + @Test + void testDynamicRegexPaths() + { + // Test custom API regex pattern "/api/v1/custom/.*" + assertThat(pathFilter.isPathWhiteListed("/api/v1/custom/")).isTrue(); + assertThat(pathFilter.isPathWhiteListed("/api/v1/custom/users")).isTrue(); + assertThat(pathFilter.isPathWhiteListed("/api/v1/custom/users/123")).isTrue(); + assertThat(pathFilter.isPathWhiteListed("/api/v1/custom/reports/daily")).isTrue(); + + // Test health check regex pattern "/health/.*" + assertThat(pathFilter.isPathWhiteListed("/health/")).isTrue(); + assertThat(pathFilter.isPathWhiteListed("/health/status")).isTrue(); + assertThat(pathFilter.isPathWhiteListed("/health/ready")).isTrue(); + assertThat(pathFilter.isPathWhiteListed("/health/live")).isTrue(); + + // Test exact match for metrics + assertThat(pathFilter.isPathWhiteListed("/metrics")).isTrue(); + } + + @Test + void testNonWhitelistedPaths() + { + assertThat(pathFilter.isPathWhiteListed("/v3/statement")).isFalse(); // Not in our statement paths + assertThat(pathFilter.isPathWhiteListed("/api/v2/custom/users")).isFalse(); // Doesn't match v1 pattern + assertThat(pathFilter.isPathWhiteListed("/status")).isFalse(); // Not health/status + assertThat(pathFilter.isPathWhiteListed("/metrics/extra")).isFalse(); // metrics is exact match only + assertThat(pathFilter.isPathWhiteListed("")).isFalse(); // Empty path + assertThat(pathFilter.isPathWhiteListed("/")).isFalse(); // Root path + } + + @Test + void testEdgeCases() + { + // Test case sensitivity + assertThat(pathFilter.isPathWhiteListed("/V1/query")).isFalse(); // Case sensitive + assertThat(pathFilter.isPathWhiteListed("/API/v1/custom/test")).isFalse(); // Case sensitive + + // Test partial matches + assertThat(pathFilter.isPathWhiteListed("/v1/quer")).isFalse(); // Partial match of hardcoded path + assertThat(pathFilter.isPathWhiteListed("/api/v1/custo")).isFalse(); // Partial match of regex + } + + @Test + void testRegexPattern() + { + List complexStatementPaths = ImmutableList.of(V1_STATEMENT_PATH); + List complexRegexPaths = ImmutableList.of( + "/api/v[1-9]/.*"); + PathFilter complexFilter = new PathFilter(complexStatementPaths, complexRegexPaths); + + // Test version pattern + assertThat(complexFilter.isPathWhiteListed("/api/v1/users")).isTrue(); + assertThat(complexFilter.isPathWhiteListed("/api/v2/data")).isTrue(); + assertThat(complexFilter.isPathWhiteListed("/api/v9/reports")).isTrue(); + assertThat(complexFilter.isPathWhiteListed("/api/v0/test")).isFalse(); // v0 not in [1-9] + assertThat(complexFilter.isPathWhiteListed("/api/v10/test")).isFalse(); // v10 not single digit + } + + @Test + void testEmptyConfiguration() + { + // Test PathFilter with empty lists + PathFilter emptyFilter = new PathFilter(ImmutableList.of(), ImmutableList.of()); + + // Should still allow hardcoded paths + assertThat(emptyFilter.isPathWhiteListed(V1_QUERY_PATH)).isTrue(); + assertThat(emptyFilter.isPathWhiteListed(TRINO_UI_PATH)).isTrue(); + assertThat(emptyFilter.isPathWhiteListed(V1_INFO_PATH)).isTrue(); + assertThat(emptyFilter.isPathWhiteListed(OAUTH_PATH)).isTrue(); + + // Should not allow any custom paths + assertThat(emptyFilter.isPathWhiteListed(V1_STATEMENT_PATH)).isFalse(); + assertThat(emptyFilter.isPathWhiteListed("/custom/path")).isFalse(); + } + + @Test + void testStatementPathsOnly() + { + // Test PathFilter with only statement paths, no regex patterns + PathFilter statementOnlyFilter = new PathFilter( + ImmutableList.of(V1_STATEMENT_PATH, "/v2/statement", "/custom/execute"), + ImmutableList.of()); + + // Should allow hardcoded paths + assertThat(statementOnlyFilter.isPathWhiteListed(V1_QUERY_PATH)).isTrue(); + assertThat(statementOnlyFilter.isPathWhiteListed(TRINO_UI_PATH)).isTrue(); + + // Should allow configured statement paths + assertThat(statementOnlyFilter.isPathWhiteListed(V1_STATEMENT_PATH)).isTrue(); + assertThat(statementOnlyFilter.isPathWhiteListed("/v2/statement")).isTrue(); + assertThat(statementOnlyFilter.isPathWhiteListed("/custom/execute")).isTrue(); + assertThat(statementOnlyFilter.isPathWhiteListed("/custom/execute/batch")).isTrue(); + + // Should not allow other paths + assertThat(statementOnlyFilter.isPathWhiteListed("/api/custom")).isFalse(); + assertThat(statementOnlyFilter.isPathWhiteListed("/health")).isFalse(); + } + + @Test + void testInvalidRegexFilter() + { + List statementPaths = ImmutableList.of(V1_STATEMENT_PATH, "/v2/statement"); + List extraWhitelistPaths = ImmutableList.of( + "[/api/v1/custom/.*"); + assertThatThrownBy(() -> { + PathFilter invalidRegex = new PathFilter(statementPaths, extraWhitelistPaths); + invalidRegex.isPathWhiteListed("/api/v1/custom"); + }).isInstanceOf(IllegalArgumentException.class); + } + + @Test + void testDefaultHaConfigurationForPaths() + { + HaGatewayConfiguration configuration = new HaGatewayConfiguration(); + PathFilter filter = new PathFilter(configuration.getStatementPaths(), + configuration.getExtraWhitelistPaths()); + + assertThat(filter.isPathWhiteListed(V1_STATEMENT_PATH)).isTrue(); + assertThat(filter.isPathWhiteListed(V1_STATEMENT_PATH + "/executing")).isTrue(); + assertThat(filter.isPathWhiteListed(V1_STATEMENT_PATH + "/queued")).isTrue(); + assertThat(filter.isPathWhiteListed(V1_QUERY_PATH)).isTrue(); + assertThat(filter.isPathWhiteListed(TRINO_UI_PATH)).isTrue(); + assertThat(filter.isPathWhiteListed(OAUTH_PATH)).isTrue(); + + assertThat(filter.isPathWhiteListed("/v2/statement")).isFalse(); + } + + @Test + void testDefaultHaConfigurationNullPaths() + { + HaGatewayConfiguration configuration = new HaGatewayConfiguration(); + + assertThatThrownBy(() -> { + new PathFilter(null, + configuration.getExtraWhitelistPaths()); + }).isInstanceOf(NullPointerException.class); + } +} diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java index 9e3dcdeef..0ca8f41f7 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java @@ -16,9 +16,12 @@ import com.google.common.collect.ImmutableSet; import io.airlift.units.Duration; import io.trino.gateway.ha.config.RequestAnalyzerConfig; +import io.trino.gateway.ha.util.QueryRequestMock; import io.trino.sql.tree.QualifiedName; import jakarta.servlet.http.HttpServletRequest; import jakarta.ws.rs.HttpMethod; +import jakarta.ws.rs.core.MultivaluedHashMap; +import jakarta.ws.rs.core.MultivaluedMap; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @@ -31,16 +34,15 @@ import java.io.BufferedWriter; import java.io.File; import java.io.IOException; -import java.io.Reader; import java.io.StringReader; import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; import java.util.Base64; -import java.util.Collections; import java.util.Set; import java.util.stream.Stream; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES; import static io.trino.gateway.ha.router.RoutingGroupSelector.ROUTING_GROUP_HEADER; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.concurrent.TimeUnit.HOURS; @@ -104,11 +106,11 @@ void testByRoutingRulesEngine(String rulesConfigPath) RoutingGroupSelector routingGroupSelector = RoutingGroupSelector.byRoutingRulesEngine(rulesConfigPath, oneHourRefreshPeriod, requestAnalyzerConfig); - HttpServletRequest mockRequest = prepareMockRequest(); - - when(mockRequest.getHeader(TRINO_SOURCE_HEADER)).thenReturn("airflow"); + HttpServletRequest mockRequest = new QueryRequestMock() + .httpHeader(TRINO_SOURCE_HEADER, "airflow") + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); - assertThat(routingGroup).isEqualTo("etl"); } @@ -122,8 +124,11 @@ void testGetUserFromBasicAuth() requestAnalyzerConfig); String encodedUsernamePassword = Base64.getEncoder().encodeToString("will:supersecret".getBytes(UTF_8)); - HttpServletRequest mockRequest = prepareMockRequest(); - when(mockRequest.getHeader("Authorization")).thenReturn("Basic " + encodedUsernamePassword); + HttpServletRequest mockRequest = new QueryRequestMock() + .httpHeader("Authorization", "Basic " + encodedUsernamePassword) + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); + String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); assertThat(routingGroup).isEqualTo("will-group"); @@ -139,12 +144,12 @@ void testTrinoQueryPropertiesQueryDetails() oneHourRefreshPeriod, requestAnalyzerConfig); String query = "SELECT x.*, y.*, z.* FROM catx.schemx.tblx x, schemy.tbly y, tblz z"; - Reader reader = new StringReader(query); - BufferedReader bufferedReader = new BufferedReader(reader); - HttpServletRequest mockRequest = prepareMockRequest(); - when(mockRequest.getReader()).thenReturn(bufferedReader); - when(mockRequest.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn("cat_default"); - when(mockRequest.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn("schem_\\\"default"); + + HttpServletRequest mockRequest = new QueryRequestMock().query(query) + .httpHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME, "cat_default") + .httpHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME, "schem_\\\"default") + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); assertThat(routingGroup).isEqualTo("tbl-group"); @@ -160,14 +165,14 @@ void testTrinoQueryPropertiesCatalogSchemas() oneHourRefreshPeriod, requestAnalyzerConfig); String query = "SELECT x.*, y.* FROM catx.nondefault.tblx x, caty.default.tbly y"; - Reader reader = new StringReader(query); - BufferedReader bufferedReader = new BufferedReader(reader); - HttpServletRequest mockRequest = prepareMockRequest(); - when(mockRequest.getReader()).thenReturn(bufferedReader); - when(mockRequest.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn("catx"); - when(mockRequest.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn("default"); - String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); + HttpServletRequest mockRequest = new QueryRequestMock().query(query) + .httpHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME, "catx") + .httpHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME, "default") + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); + + String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); assertThat(routingGroup).isEqualTo("catalog-schema-group"); } @@ -179,12 +184,14 @@ void testTrinoQueryPropertiesSessionDefaults() "src/test/resources/rules/routing_rules_trino_query_properties.yml", oneHourRefreshPeriod, requestAnalyzerConfig); - HttpServletRequest mockRequest = prepareMockRequest(); - when(mockRequest.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn("other_catalog"); - when(mockRequest.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn("other_schema"); - String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); + HttpServletRequest mockRequest = new QueryRequestMock() + .httpHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME, "other_catalog") + .httpHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME, "other_schema") + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); + String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); assertThat(routingGroup).isEqualTo("defaults-group"); } @@ -198,12 +205,12 @@ void testTrinoQueryPropertiesQueryType() oneHourRefreshPeriod, requestAnalyzerConfig); String query = "INSERT INTO foo SELECT 1"; - Reader reader = new StringReader(query); - BufferedReader bufferedReader = new BufferedReader(reader); - HttpServletRequest mockRequest = prepareMockRequest(); - when(mockRequest.getReader()).thenReturn(bufferedReader); - String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); + HttpServletRequest mockRequest = new QueryRequestMock() + .query(query) + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); + String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); assertThat(routingGroup).isEqualTo("type-group"); } @@ -216,10 +223,13 @@ void testTrinoQueryPropertiesResourceGroupQueryType() "src/test/resources/rules/routing_rules_trino_query_properties.yml", oneHourRefreshPeriod, requestAnalyzerConfig); - HttpServletRequest mockRequest = prepareMockRequest(); - when(mockRequest.getReader()).thenReturn(new BufferedReader(new StringReader("CREATE TABLE cat.schem.foo (c1 int)"))); - String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); + String query = "CREATE TABLE cat.schem.foo (c1 int)"; + HttpServletRequest mockRequest = new QueryRequestMock() + .query(query) + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); + String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); assertThat(routingGroup).isEqualTo("resource-group-type-group"); } @@ -234,12 +244,11 @@ void testTrinoQueryPropertiesAlternateStatementFormat() oneHourRefreshPeriod, requestAnalyzerConfig); String body = "{\"preparedStatements\" : {\"statement1\":\"INSERT INTO foo SELECT 1\"}, \"query\": \"EXECUTE statement1\"}"; - Reader reader = new StringReader(body); - BufferedReader bufferedReader = new BufferedReader(reader); - HttpServletRequest mockRequest = prepareMockRequest(); - when(mockRequest.getReader()).thenReturn(bufferedReader); - String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); + HttpServletRequest mockRequest = new QueryRequestMock().query(body) + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); + String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); assertThat(routingGroup).isEqualTo("type-group"); } @@ -255,19 +264,48 @@ void testTrinoQueryPropertiesPreparedStatementInHeader() "src/test/resources/rules/routing_rules_trino_query_properties.yml", oneHourRefreshPeriod, requestAnalyzerConfig); - Reader reader = new StringReader(body); - BufferedReader bufferedReader = new BufferedReader(reader); - HttpServletRequest mockRequest = prepareMockRequest(); - when(mockRequest.getReader()).thenReturn(bufferedReader); - when(mockRequest.getHeader(TrinoQueryProperties.TRINO_PREPARED_STATEMENT_HEADER_NAME)).thenReturn(encodedStatements); - when(mockRequest.getHeaders(TrinoQueryProperties.TRINO_PREPARED_STATEMENT_HEADER_NAME)).thenReturn(Collections.enumeration(Arrays.asList(encodedStatements.split(",")))); - when(mockRequest.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn("cat"); - when(mockRequest.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn("schem"); - String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); + MultivaluedMap headers = new MultivaluedHashMap<>(); + headers.addAll(TrinoQueryProperties.TRINO_PREPARED_STATEMENT_HEADER_NAME, Arrays.asList(encodedStatements.split(","))); + + HttpServletRequest mockRequest = new QueryRequestMock().query(body).httpHeaders(headers) + .httpHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME, "cat") + .httpHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME, "schem") + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); + + String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); assertThat(routingGroup).isEqualTo("statement-header-group"); } + @Test + void testTrinoQueryPropertiesParsingError() + throws IOException + { + RoutingGroupSelector routingGroupSelector = + RoutingGroupSelector.byRoutingRulesEngine( + "src/test/resources/rules/routing_rules_trino_query_properties.yml", + oneHourRefreshPeriod, + requestAnalyzerConfig); + + // Invalid SQL that will cause a ParsingException + String invalidQuery = "SELECT * FROM table WHERE column = "; + HttpServletRequest mockRequest = new QueryRequestMock() + .query(invalidQuery) + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); + + // When parsing fails, the query should route to the default "no-match" group + String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); + assertThat(routingGroup).isEqualTo("no-match"); + + // Verify that the TrinoQueryProperties indicates a parsing failure + TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) mockRequest.getAttribute(TRINO_QUERY_PROPERTIES); + assertThat(trinoQueryProperties).isNotNull(); + assertThat(trinoQueryProperties.isQueryParsingSuccessful()).isFalse(); + assertThat(trinoQueryProperties.getErrorMessage()).isPresent(); + } + @ParameterizedTest @MethodSource("provideRoutingRuleConfigFiles") void testByRoutingRulesEngineSpecialLabel(String rulesConfigPath) @@ -275,13 +313,13 @@ void testByRoutingRulesEngineSpecialLabel(String rulesConfigPath) RoutingGroupSelector routingGroupSelector = RoutingGroupSelector.byRoutingRulesEngine(rulesConfigPath, oneHourRefreshPeriod, requestAnalyzerConfig); - HttpServletRequest mockRequest = prepareMockRequest(); + HttpServletRequest mockRequest = new QueryRequestMock() + .httpHeader(TRINO_SOURCE_HEADER, "airflow") + .httpHeader(TRINO_CLIENT_TAGS_HEADER, "email=test@example.com,label=special") + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); - when(mockRequest.getHeader(TRINO_SOURCE_HEADER)).thenReturn("airflow"); - when(mockRequest.getHeader(TRINO_CLIENT_TAGS_HEADER)).thenReturn( - "email=test@example.com,label=special"); String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); - assertThat(routingGroup).isEqualTo("etl-special"); } @@ -292,11 +330,13 @@ void testByRoutingRulesEngineNoMatch(String rulesConfigPath) RoutingGroupSelector routingGroupSelector = RoutingGroupSelector.byRoutingRulesEngine(rulesConfigPath, oneHourRefreshPeriod, requestAnalyzerConfig); - HttpServletRequest mockRequest = prepareMockRequest(); // even though special label is present, query is not from airflow. // should return no match - when(mockRequest.getHeader(TRINO_CLIENT_TAGS_HEADER)).thenReturn( - "email=test@example.com,label=special"); + HttpServletRequest mockRequest = new QueryRequestMock() + .httpHeader(TRINO_CLIENT_TAGS_HEADER, "email=test@example.com,label=special") + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); + String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); assertThat(routingGroup).isNull(); @@ -322,11 +362,12 @@ void testByRoutingRulesEngineFileChange() RoutingGroupSelector routingGroupSelector = RoutingGroupSelector.byRoutingRulesEngine(file.getPath(), refreshPeriod, requestAnalyzerConfig); - HttpServletRequest mockRequest = prepareMockRequest(); + HttpServletRequest mockRequest = new QueryRequestMock() + .httpHeader(TRINO_SOURCE_HEADER, "airflow") + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); - when(mockRequest.getHeader(TRINO_SOURCE_HEADER)).thenReturn("airflow"); String routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); - assertThat(routingGroup).isEqualTo("etl"); try (BufferedWriter writer = Files.newBufferedWriter(file.toPath(), UTF_8)) { @@ -339,10 +380,12 @@ void testByRoutingRulesEngineFileChange() + " - \"result.put(\\\"routingGroup\\\", \\\"etl2\\\")\""); // change from etl to etl2 } Thread.sleep(2 * refreshPeriod.toMillis()); + when(mockRequest.getHeader(TRINO_SOURCE_HEADER)).thenReturn("airflow"); routingGroup = routingGroupSelector.findRoutingDestination(mockRequest).routingGroup(); assertThat(routingGroup).isEqualTo("etl2"); + file.deleteOnExit(); } @@ -449,16 +492,13 @@ private Stream provideTableExtractionQueries() void testTrinoQueryPropertiesTableExtraction(String query, Set catalogs, Set schemas, Set tables) throws IOException { - BufferedReader bufferedReader = new BufferedReader(new StringReader(query)); - HttpServletRequest mockRequest = prepareMockRequest(); - when(mockRequest.getReader()).thenReturn(bufferedReader); - when(mockRequest.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn(DEFAULT_CATALOG); - when(mockRequest.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn(DEFAULT_SCHEMA); + HttpServletRequest mockRequest = new QueryRequestMock().query(query) + .httpHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME, DEFAULT_CATALOG) + .httpHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME, DEFAULT_SCHEMA) + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); - TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties( - mockRequest, - requestAnalyzerConfig.isClientsUseV2Format(), - requestAnalyzerConfig.getMaxBodySize()); + TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) mockRequest.getAttribute(TRINO_QUERY_PROPERTIES); assertThat(trinoQueryProperties.getTables()).isEqualTo(tables); assertThat(trinoQueryProperties.getSchemas()).isEqualTo(schemas); @@ -474,29 +514,28 @@ WITH dos AS (SELECT c1 from cat.schem.tbl1), uno as (SELECT c1 FROM dos) SELECT c1 FROM uno, dos """; - HttpServletRequest mockRequestWithDefaults = prepareMockRequest(); - when(mockRequestWithDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query))); - when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn(DEFAULT_CATALOG); - when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn(DEFAULT_SCHEMA); - - TrinoQueryProperties trinoQueryPropertiesWithDefaults = new TrinoQueryProperties( - mockRequestWithDefaults, - requestAnalyzerConfig.isClientsUseV2Format(), - requestAnalyzerConfig.getMaxBodySize()); + + HttpServletRequest mockRequestNoDefaults = new QueryRequestMock().query(query) + .httpHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME, DEFAULT_CATALOG) + .httpHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME, DEFAULT_SCHEMA) + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); + + TrinoQueryProperties trinoQueryPropertiesWithDefaults = getTrinoQueryProps(mockRequestNoDefaults); Set tablesWithDefaults = trinoQueryPropertiesWithDefaults.getTables(); assertThat(tablesWithDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1")); - - HttpServletRequest mockRequestNoDefaults = prepareMockRequest(); when(mockRequestNoDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query))); - TrinoQueryProperties trinoQueryPropertiesNoDefaults = new TrinoQueryProperties( - mockRequestNoDefaults, - requestAnalyzerConfig.isClientsUseV2Format(), - requestAnalyzerConfig.getMaxBodySize()); + TrinoQueryProperties trinoQueryPropertiesNoDefaults = (TrinoQueryProperties) mockRequestNoDefaults.getAttribute(TRINO_QUERY_PROPERTIES); Set tablesNoDefaults = trinoQueryPropertiesNoDefaults.getTables(); assertThat(tablesNoDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1")); } + private TrinoQueryProperties getTrinoQueryProps(HttpServletRequest request) + { + return (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES); + } + private HttpServletRequest prepareMockRequest() { HttpServletRequest mockRequest = mock(HttpServletRequest.class); @@ -508,13 +547,13 @@ private HttpServletRequest prepareMockRequest() void testLongQuery() throws IOException { - BufferedReader bufferedReader = Files.newBufferedReader(Path.of("src/test/resources/wide_select.sql"), UTF_8); - HttpServletRequest mockRequest = prepareMockRequest(); - when(mockRequest.getReader()).thenReturn(bufferedReader); - TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties( - mockRequest, - requestAnalyzerConfig.isClientsUseV2Format(), - requestAnalyzerConfig.getMaxBodySize()); + String query = Files.readString(Path.of("src/test/resources/wide_select.sql"), UTF_8); + + HttpServletRequest mockRequest = new QueryRequestMock().query(query) + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); + + TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) mockRequest.getAttribute(TRINO_QUERY_PROPERTIES); assertThat(trinoQueryProperties.tablesContains("kat.schem.widetable")).isTrue(); } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoRequestUser.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoRequestUser.java index bf83529e6..ca4d7f3e7 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoRequestUser.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoRequestUser.java @@ -17,6 +17,7 @@ import com.auth0.jwt.algorithms.Algorithm; import io.airlift.json.JsonCodec; import io.trino.gateway.ha.config.RequestAnalyzerConfig; +import io.trino.gateway.ha.util.QueryRequestMock; import jakarta.servlet.http.HttpServletRequest; import org.junit.jupiter.api.Test; @@ -26,12 +27,11 @@ import java.util.Optional; import static com.auth0.jwt.algorithms.Algorithm.HMAC256; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER; import static io.trino.gateway.ha.handler.HttpUtils.USER_HEADER; import static jakarta.ws.rs.core.HttpHeaders.AUTHORIZATION; import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; final class TestTrinoRequestUser { @@ -74,11 +74,14 @@ void testUserFromJwtToken() .withExpiresAt(Date.from(expiryTime)) .sign(algorithm); - HttpServletRequest mockRequest = mock(HttpServletRequest.class); - when(mockRequest.getHeader(USER_HEADER)).thenReturn(null); - when(mockRequest.getHeader(AUTHORIZATION)).thenReturn("Bearer " + token); + HttpServletRequest mockRequest = new QueryRequestMock() + .requestAnalyzerConfig(requestAnalyzerConfig) + .httpHeader(USER_HEADER, null) + .httpHeader(AUTHORIZATION, "Bearer " + token) + .requestAnalyzerConfig(requestAnalyzerConfig) + .getHttpServletRequest(); - TrinoRequestUser trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(mockRequest); + TrinoRequestUser trinoRequestUser = (TrinoRequestUser) mockRequest.getAttribute(TRINO_REQUEST_USER); assertThat(trinoRequestUser.getUser()).hasValue(claimUserValue); } @@ -90,13 +93,15 @@ void testGetBasicAuthUser() String password = "don't care"; String credentials = username + ":" + password; String encodedCredentials = Base64.getEncoder().encodeToString(credentials.getBytes(UTF_8)); + RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig(); - HttpServletRequest mockRequest = mock(HttpServletRequest.class); - when(mockRequest.getHeader(USER_HEADER)).thenReturn(null); - when(mockRequest.getHeader(AUTHORIZATION)).thenReturn("Basic " + encodedCredentials); + HttpServletRequest mockRequest = new QueryRequestMock() + .requestAnalyzerConfig(requestAnalyzerConfig) + .httpHeader(USER_HEADER, null) + .httpHeader(AUTHORIZATION, "Basic " + encodedCredentials) + .getHttpServletRequest(); - RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig(); - TrinoRequestUser trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(mockRequest); + TrinoRequestUser trinoRequestUser = (TrinoRequestUser) mockRequest.getAttribute(TRINO_REQUEST_USER); assertThat(trinoRequestUser.getUser()).hasValue(username); } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestQueryMetadataParser.java b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestQueryMetadataParser.java new file mode 100644 index 000000000..f78a2be6d --- /dev/null +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestQueryMetadataParser.java @@ -0,0 +1,208 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.security; + +import io.trino.gateway.ha.config.HaGatewayConfiguration; +import io.trino.gateway.ha.config.RequestAnalyzerConfig; +import io.trino.gateway.ha.handler.HttpUtils; +import io.trino.gateway.ha.router.PathFilter; +import io.trino.gateway.ha.router.TrinoQueryProperties; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.UriInfo; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.ExtendedUriInfo; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; + +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +final class TestQueryMetadataParser +{ + private QueryMetadataParser filter; + private RequestAnalyzerConfig requestAnalyzerConfig; + private PathFilter pathFilter; + + TestQueryMetadataParser() + { + HaGatewayConfiguration config = new HaGatewayConfiguration(); + requestAnalyzerConfig = new RequestAnalyzerConfig(); + requestAnalyzerConfig.setAnalyzeRequest(true); + config.setRequestAnalyzerConfig(requestAnalyzerConfig); + pathFilter = new PathFilter(config.getStatementPaths(), config.getExtraWhitelistPaths()); + filter = new QueryMetadataParser(config, pathFilter); + } + + @Test + void testFilterSetsTrinoQueryPropertiesWithEntityBody() + throws Exception + { + ContainerRequestContext requestContext = mock(ContainerRequest.class); + when(requestContext.getMethod()).thenReturn("POST"); + + UriInfo uriInfo = mock(ExtendedUriInfo.class); + try { + when(uriInfo.getRequestUri()).thenReturn(new URI("http://localhost" + HttpUtils.V1_STATEMENT_PATH)); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + when(requestContext.getUriInfo()).thenReturn(uriInfo); + + MediaType mediaType = new MediaType("application", "json", java.util.Map.of("charset", "UTF-8")); + when(requestContext.getMediaType()).thenReturn(mediaType); + + String query = "Select xyz from cat1.schema1.table1"; + InputStream entityStream = new ByteArrayInputStream(query.getBytes(StandardCharsets.UTF_8)); + when(requestContext.getEntityStream()).thenReturn(entityStream); + when(requestContext.hasEntity()).thenReturn(true); + filter.filter(requestContext); + + ArgumentCaptor captor = ArgumentCaptor.forClass(TrinoQueryProperties.class); + verify(requestContext).setProperty(eq(TRINO_QUERY_PROPERTIES), captor.capture()); + verify((ContainerRequest) requestContext).bufferEntity(); + verify(requestContext).getEntityStream(); + } + + @Test + void testFilterSetsTrinoQueryPropertiesWithNoV1Statement() + throws Exception + { + ContainerRequestContext requestContext = mock(ContainerRequest.class); + when(requestContext.getMethod()).thenReturn("POST"); + + UriInfo uriInfo = mock(ExtendedUriInfo.class); + try { + when(uriInfo.getRequestUri()).thenReturn(new URI("http://localhost" + HttpUtils.OAUTH_PATH)); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + when(requestContext.getUriInfo()).thenReturn(uriInfo); + + MediaType mediaType = new MediaType("application", "json", java.util.Map.of("charset", "UTF-8")); + when(requestContext.getMediaType()).thenReturn(mediaType); + + String query = "Select xyz from cat1.schema1.table1"; + InputStream entityStream = new ByteArrayInputStream(query.getBytes(StandardCharsets.UTF_8)); + when(requestContext.getEntityStream()).thenReturn(entityStream); + when(requestContext.hasEntity()).thenReturn(true); + filter.filter(requestContext); + + ArgumentCaptor captor = ArgumentCaptor.forClass(TrinoQueryProperties.class); + verify(requestContext).setProperty(eq(TRINO_QUERY_PROPERTIES), captor.capture()); + verify((ContainerRequest) requestContext).bufferEntity(); + verify(requestContext).getEntityStream(); + } + + @Test + void testFilterSetsTrinoQueryPropertiesWithNoMedia() + throws Exception + { + ContainerRequestContext requestContext = mock(ContainerRequest.class); + when(requestContext.getMethod()).thenReturn("POST"); + + UriInfo uriInfo = mock(ExtendedUriInfo.class); + try { + when(uriInfo.getRequestUri()).thenReturn(new URI("http://localhost" + HttpUtils.V1_STATEMENT_PATH)); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + when(requestContext.getUriInfo()).thenReturn(uriInfo); + + String query = "Select xyz from cat1.schema1.table1"; + InputStream entityStream = new ByteArrayInputStream(query.getBytes(StandardCharsets.UTF_8)); + when(requestContext.getEntityStream()).thenReturn(entityStream); + when(requestContext.hasEntity()).thenReturn(true); + filter.filter(requestContext); + + ArgumentCaptor captor = ArgumentCaptor.forClass(TrinoQueryProperties.class); + verify(requestContext).setProperty(eq(TRINO_QUERY_PROPERTIES), captor.capture()); + TrinoQueryProperties queryProperties = (TrinoQueryProperties) requestContext.getProperty(TRINO_QUERY_PROPERTIES); + assertThat(queryProperties).isEqualTo(null); + verify((ContainerRequest) requestContext).bufferEntity(); + } + + @Test + void testFilterSetsTrinoQueryPropertiesWithNoQueryText() + throws Exception + { + ContainerRequestContext requestContext = mock(ContainerRequest.class); + when(requestContext.getMethod()).thenReturn("POST"); + + UriInfo uriInfo = mock(ExtendedUriInfo.class); + try { + when(uriInfo.getRequestUri()).thenReturn(new URI("http://localhost" + HttpUtils.V1_STATEMENT_PATH)); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + when(requestContext.getUriInfo()).thenReturn(uriInfo); + + MediaType mediaType = new MediaType("application", "json", java.util.Map.of("charset", "UTF-8")); + when(requestContext.getMediaType()).thenReturn(mediaType); + + String query = ""; + InputStream entityStream = new ByteArrayInputStream(query.getBytes(StandardCharsets.UTF_8)); + when(requestContext.getEntityStream()).thenReturn(entityStream); + when(requestContext.hasEntity()).thenReturn(true); + filter.filter(requestContext); + + ArgumentCaptor captor = ArgumentCaptor.forClass(TrinoQueryProperties.class); + verify(requestContext).setProperty(eq(TRINO_QUERY_PROPERTIES), captor.capture()); + verify((ContainerRequest) requestContext).bufferEntity(); + verify(requestContext).getEntityStream(); + + TrinoQueryProperties queryProperties = (TrinoQueryProperties) requestContext.getProperty(TRINO_QUERY_PROPERTIES); + assertThat(queryProperties).isEqualTo(null); + } + + @Test + void testFilterSetsTrinoQueryPropertiesWithNoPostPayload() + throws Exception + { + ContainerRequestContext requestContext = mock(ContainerRequest.class); + when(requestContext.getMethod()).thenReturn("POST"); + + UriInfo uriInfo = mock(ExtendedUriInfo.class); + try { + when(uriInfo.getRequestUri()).thenReturn(new URI("http://localhost" + HttpUtils.V1_STATEMENT_PATH)); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + when(requestContext.getUriInfo()).thenReturn(uriInfo); + + MediaType mediaType = new MediaType("application", "json", java.util.Map.of("charset", "UTF-8")); + when(requestContext.getMediaType()).thenReturn(mediaType); + + filter.filter(requestContext); + + ArgumentCaptor captor = ArgumentCaptor.forClass(TrinoQueryProperties.class); + verify(requestContext).setProperty(eq(TRINO_QUERY_PROPERTIES), captor.capture()); + verify((ContainerRequest) requestContext).bufferEntity(); + } +} diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestQueryUserInfoParser.java b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestQueryUserInfoParser.java new file mode 100644 index 000000000..11c6d6145 --- /dev/null +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestQueryUserInfoParser.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.security; + +import io.trino.gateway.ha.config.HaGatewayConfiguration; +import io.trino.gateway.ha.config.RequestAnalyzerConfig; +import io.trino.gateway.ha.handler.HttpUtils; +import io.trino.gateway.ha.router.PathFilter; +import io.trino.gateway.ha.router.TrinoRequestUser; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.UriInfo; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.ExtendedUriInfo; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import java.net.URI; +import java.util.Base64; + +import static io.trino.gateway.ha.handler.HttpUtils.V1_STATEMENT_PATH; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +final class TestQueryUserInfoParser +{ + private QueryUserInfoParser filter; + + TestQueryUserInfoParser() + { + HaGatewayConfiguration config = new HaGatewayConfiguration(); + RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig(); + requestAnalyzerConfig.setAnalyzeRequest(true); + config.setRequestAnalyzerConfig(requestAnalyzerConfig); + + PathFilter pathFilter = new PathFilter(config.getStatementPaths(), config.getExtraWhitelistPaths()); + + filter = new QueryUserInfoParser(config, pathFilter); + } + + @Test + void testFilterSetsQueryUserInfo() + throws Exception + { + ContainerRequestContext requestContext = mock(ContainerRequest.class); + when(requestContext.getMethod()).thenReturn("POST"); + + UriInfo uriInfo = mock(ExtendedUriInfo.class); + when(uriInfo.getRequestUri()).thenReturn(new URI("http://localhost" + V1_STATEMENT_PATH)); + when(requestContext.getUriInfo()).thenReturn(uriInfo); + + MediaType mediaType = new MediaType("application", "json", java.util.Map.of("charset", "UTF-8")); + when(requestContext.getMediaType()).thenReturn(mediaType); + + String encodedUsernamePassword = Base64.getEncoder().encodeToString("MrXYZ:OutInTheOpen".getBytes(UTF_8)); + when(requestContext.getHeaderString(HttpHeaders.AUTHORIZATION)).thenReturn("Basic " + encodedUsernamePassword); + + filter.filter(requestContext); + + ArgumentCaptor userCaptor = ArgumentCaptor.forClass(TrinoRequestUser.class); + verify(requestContext).setProperty(eq(HttpUtils.TRINO_REQUEST_USER), userCaptor.capture()); + } +} diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/util/QueryRequestMock.java b/gateway-ha/src/test/java/io/trino/gateway/ha/util/QueryRequestMock.java new file mode 100644 index 000000000..e02a3cf7f --- /dev/null +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/util/QueryRequestMock.java @@ -0,0 +1,173 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.util; + +import io.trino.gateway.ha.config.HaGatewayConfiguration; +import io.trino.gateway.ha.config.RequestAnalyzerConfig; +import io.trino.gateway.ha.handler.HttpUtils; +import io.trino.gateway.ha.router.PathFilter; +import io.trino.gateway.ha.router.TrinoQueryProperties; +import io.trino.gateway.ha.router.TrinoRequestUser; +import io.trino.gateway.ha.security.QueryMetadataParser; +import io.trino.gateway.ha.security.QueryUserInfoParser; +import jakarta.servlet.ReadListener; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.HttpMethod; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.MultivaluedMap; +import jakarta.ws.rs.core.UriInfo; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.ExtendedUriInfo; +import org.mockito.ArgumentCaptor; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.StringReader; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; + +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES; +import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class QueryRequestMock +{ + private RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig(); + private ContainerRequestContext requestContext = mock(ContainerRequest.class); + private HttpServletRequest mockRequest = mock(HttpServletRequest.class); + + private void setDefaultMockParams() + { + when(requestContext.getMethod()).thenReturn("POST"); + when(mockRequest.getMethod()).thenReturn(HttpMethod.POST); + UriInfo uriInfo = mock(ExtendedUriInfo.class); + try { + when(uriInfo.getRequestUri()).thenReturn(new URI("http://localhost" + HttpUtils.V1_STATEMENT_PATH)); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + when(requestContext.getUriInfo()).thenReturn(uriInfo); + } + + public QueryRequestMock requestAnalyzerConfig(RequestAnalyzerConfig config) + { + requestAnalyzerConfig = config; + return this; + } + + public QueryRequestMock query(String query) + throws IOException + { + if (!query.isEmpty()) { + MediaType mediaType = new MediaType("application", "json", java.util.Map.of("charset", "UTF-8")); + when(requestContext.getMediaType()).thenReturn(mediaType); + InputStream entityStream = new ByteArrayInputStream(query.getBytes(StandardCharsets.UTF_8)); + when(requestContext.getEntityStream()).thenReturn(entityStream); + when(requestContext.hasEntity()).thenReturn(true); + } + else { + when(requestContext.hasEntity()).thenReturn(false); + } + + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(query.getBytes(UTF_8)); + when(mockRequest.getMethod()).thenReturn(HttpMethod.POST); + when(mockRequest.getInputStream()).thenReturn(new ServletInputStream() + { + @Override + public boolean isFinished() + { + return byteArrayInputStream.available() > 0; + } + + @Override + public boolean isReady() + { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) + {} + + @Override + public int read() + throws IOException + { + return byteArrayInputStream.read(); + } + }); + + when(mockRequest.getReader()).thenReturn(new BufferedReader(new StringReader(query))); + when(mockRequest.getQueryString()).thenReturn(""); + return this; + } + + public QueryRequestMock httpHeader(String name, String value) + { + when(requestContext.getHeaderString(name)).thenReturn(value); + when(mockRequest.getHeader(name)).thenReturn(value); + return this; + } + + public QueryRequestMock httpHeaders(MultivaluedMap headers) + { + when(requestContext.getHeaders()).thenReturn(headers); + return this; + } + + public HttpServletRequest getHttpServletRequest() + { + setDefaultMockParams(); + + HaGatewayConfiguration config = new HaGatewayConfiguration(); + config.setRequestAnalyzerConfig(requestAnalyzerConfig); + + PathFilter pathFilter = new PathFilter(config.getStatementPaths(), config.getExtraWhitelistPaths()); + + QueryUserInfoParser userInfoParser = new QueryUserInfoParser(config, pathFilter); + try { + userInfoParser.filter(requestContext); + ArgumentCaptor captorUserInfo = ArgumentCaptor.forClass(TrinoRequestUser.class); + verify(requestContext).setProperty(eq(TRINO_REQUEST_USER), captorUserInfo.capture()); + when(mockRequest.getAttribute(TRINO_REQUEST_USER)).thenReturn(captorUserInfo.getValue()); + } + catch (IOException ex) { + return null; + } + + QueryMetadataParser queryMetadataParser = new QueryMetadataParser(config, pathFilter); + try { + if (requestAnalyzerConfig.isAnalyzeRequest()) { + queryMetadataParser.filter(requestContext); + ArgumentCaptor captor = ArgumentCaptor.forClass(TrinoQueryProperties.class); + verify(requestContext).setProperty(eq(TRINO_QUERY_PROPERTIES), captor.capture()); + when(mockRequest.getAttribute(TRINO_QUERY_PROPERTIES)).thenReturn(captor.getValue()); + } + } + catch (IOException ex) { + return null; + } + return mockRequest; + } +}