Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import io.trino.gateway.ha.router.RoutingRulesManager;
import io.trino.gateway.ha.router.StochasticRoutingManager;
import io.trino.gateway.ha.security.AuthorizedExceptionMapper;
import io.trino.gateway.ha.security.QueryMetadataParser;
import io.trino.gateway.ha.security.QueryUserInfoParser;
import io.trino.gateway.proxyserver.ForProxy;
import io.trino.gateway.proxyserver.ProxyRequestHandler;
import io.trino.gateway.proxyserver.RouteToBackendResource;
Expand Down Expand Up @@ -187,6 +189,8 @@ private static void registerProxyResources(Binder binder)
{
jaxrsBinder(binder).bind(RouteToBackendResource.class);
jaxrsBinder(binder).bind(RouterPreMatchContainerRequestFilter.class);
jaxrsBinder(binder).bind(QueryUserInfoParser.class);
jaxrsBinder(binder).bind(QueryMetadataParser.class);
jaxrsBinder(binder).bind(ProxyRequestHandler.class);
httpClientBinder(binder).bindHttpClient("proxy", ForProxy.class);
httpClientBinder(binder).bindHttpClient("monitor", ForMonitor.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ public class HttpUtils
public static final String TRINO_UI_PATH = "/ui";
public static final String OAUTH_PATH = "/oauth2";
public static final String USER_HEADER = "X-Trino-User";
public static final String TRINO_REQUEST_USER = "trinoRequestUser";
public static final String TRINO_QUERY_PROPERTIES = "trinoQueryProperties";

private HttpUtils() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.regex.Pattern;

import static com.google.common.base.Strings.isNullOrEmpty;
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH;
import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH;
import static java.nio.charset.StandardCharsets.UTF_8;
Expand Down Expand Up @@ -78,7 +79,7 @@ public static Optional<String> extractQueryIdIfPresent(
throw new RuntimeException("Error reading request body", e);
}
if (!isNullOrEmpty(queryText) && queryText.toLowerCase(ENGLISH).contains("kill_query")) {
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
return trinoQueryProperties.getQueryId();
}
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,10 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Pattern;
import java.util.stream.Stream;

import static com.google.common.base.Strings.isNullOrEmpty;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.gateway.ha.handler.HttpUtils.OAUTH_PATH;
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH;
import static io.trino.gateway.ha.handler.HttpUtils.UI_API_STATS_PATH;
import static io.trino.gateway.ha.handler.HttpUtils.USER_HEADER;
import static io.trino.gateway.ha.handler.HttpUtils.V1_INFO_PATH;
import static io.trino.gateway.ha.handler.HttpUtils.V1_NODE_PATH;
import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH;
import static io.trino.gateway.ha.handler.ProxyUtils.buildUriWithNewCluster;
import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
import static java.util.Objects.requireNonNull;
Expand All @@ -56,7 +48,6 @@ public class RoutingTargetHandler
private final RoutingGroupSelector routingGroupSelector;
private final String defaultRoutingGroup;
private final List<String> statementPaths;
private final List<Pattern> extraWhitelistPaths;
private final boolean requestAnalyserClientsUseV2Format;
private final int requestAnalyserMaxBodySize;
private final boolean cookiesEnabled;
Expand All @@ -71,7 +62,6 @@ public RoutingTargetHandler(
this.routingGroupSelector = requireNonNull(routingGroupSelector);
this.defaultRoutingGroup = haGatewayConfiguration.getRouting().getDefaultRoutingGroup();
statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths());
extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList());
requestAnalyserClientsUseV2Format = haGatewayConfiguration.getRequestAnalyzerConfig().isClientsUseV2Format();
requestAnalyserMaxBodySize = haGatewayConfiguration.getRequestAnalyzerConfig().getMaxBodySize();
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
Expand Down Expand Up @@ -118,18 +108,6 @@ private RoutingTargetResponse getRoutingTargetResponse(HttpServletRequest reques
modifiedRequest);
}

public boolean isPathWhiteListed(String path)
{
return statementPaths.stream().anyMatch(path::startsWith)
|| path.startsWith(V1_QUERY_PATH)
|| path.startsWith(TRINO_UI_PATH)
|| path.startsWith(V1_INFO_PATH)
|| path.startsWith(V1_NODE_PATH)
|| path.startsWith(UI_API_STATS_PATH)
|| path.startsWith(OAUTH_PATH)
|| extraWhitelistPaths.stream().anyMatch(pattern -> pattern.matcher(path).matches());
}

/**
* A wrapper for HttpServletRequest that allows modifying multiple headers.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import io.trino.gateway.ha.router.HaGatewayManager;
import io.trino.gateway.ha.router.HaQueryHistoryManager;
import io.trino.gateway.ha.router.HaResourceGroupsManager;
import io.trino.gateway.ha.router.PathFilter;
import io.trino.gateway.ha.router.QueryHistoryManager;
import io.trino.gateway.ha.router.ResourceGroupsManager;
import io.trino.gateway.ha.router.RoutingGroupSelector;
Expand Down Expand Up @@ -88,6 +89,7 @@ public class HaGatewayProviderModule
private final ResourceGroupsManager resourceGroupsManager;
private final GatewayBackendManager gatewayBackendManager;
private final QueryHistoryManager queryHistoryManager;
private final PathFilter pathFilter;

@Override
protected void configure()
Expand All @@ -97,11 +99,13 @@ protected void configure()
binder().bind(GatewayBackendManager.class).toInstance(gatewayBackendManager);
binder().bind(QueryHistoryManager.class).toInstance(queryHistoryManager);
binder().bind(BackendStateManager.class).in(Scopes.SINGLETON);
binder().bind(PathFilter.class).toInstance(pathFilter);
}

public HaGatewayProviderModule(HaGatewayConfiguration configuration)
{
this.configuration = requireNonNull(configuration, "configuration is null");
pathFilter = new PathFilter(configuration.getStatementPaths(), configuration.getExtraWhitelistPaths());
Map<String, UserConfiguration> presetUsers = configuration.getPresetUsers();

oauthManager = getOAuthManager(configuration);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler;
import static io.airlift.http.client.Request.Builder.preparePost;
import static io.airlift.json.JsonCodec.jsonCodec;
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER;
import static java.util.Collections.list;
import static java.util.Objects.requireNonNull;

Expand All @@ -58,7 +60,6 @@ public class ExternalRoutingGroupSelector
private final boolean propagateErrors;
private final HttpClient httpClient;
private final RequestAnalyzerConfig requestAnalyzerConfig;
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;
private static final JsonCodec<RoutingGroupExternalBody> ROUTING_GROUP_EXTERNAL_BODY_JSON_CODEC = jsonCodec(RoutingGroupExternalBody.class);
private static final JsonResponseHandler<ExternalRouterResponse> ROUTING_GROUP_EXTERNAL_RESPONSE_JSON_RESPONSE_HANDLER =
createJsonResponseHandler(jsonCodec(ExternalRouterResponse.class));
Expand All @@ -74,7 +75,6 @@ public class ExternalRoutingGroupSelector
propagateErrors = rulesExternalConfiguration.isPropagateErrors();

this.requestAnalyzerConfig = requestAnalyzerConfig;
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig);
try {
this.uri = new URI(requireNonNull(rulesExternalConfiguration.getUrlPath(),
"Invalid URL provided, using routing group header as default."));
Expand Down Expand Up @@ -143,8 +143,8 @@ private RoutingGroupExternalBody createRequestBody(HttpServletRequest request)
TrinoQueryProperties trinoQueryProperties = null;
TrinoRequestUser trinoRequestUser = null;
if (requestAnalyzerConfig.isAnalyzeRequest()) {
trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig.isClientsUseV2Format(), requestAnalyzerConfig.getMaxBodySize());
trinoRequestUser = trinoRequestUserProvider.getInstance(request);
trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
trinoRequestUser = (TrinoRequestUser) request.getAttribute(TRINO_REQUEST_USER);
}

return new RoutingGroupExternalBody(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import java.util.Map;

import static com.google.common.base.Suppliers.memoizeWithExpiration;
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.sort;

Expand All @@ -46,16 +48,10 @@ public class FileBasedRoutingGroupSelector

private final Supplier<List<RoutingRule>> rules;
private final boolean analyzeRequest;
private final boolean clientsUseV2Format;
private final int maxBodySize;
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;

public FileBasedRoutingGroupSelector(String rulesPath, Duration rulesRefreshPeriod, RequestAnalyzerConfig requestAnalyzerConfig)
{
analyzeRequest = requestAnalyzerConfig.isAnalyzeRequest();
clientsUseV2Format = requestAnalyzerConfig.isClientsUseV2Format();
maxBodySize = requestAnalyzerConfig.getMaxBodySize();
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig);

rules = memoizeWithExpiration(() -> readRulesFromPath(Path.of(rulesPath)), rulesRefreshPeriod.toJavaTime());
}
Expand All @@ -68,12 +64,9 @@ public RoutingSelectorResponse findRoutingDestination(HttpServletRequest request

Map<String, Object> data;
if (analyzeRequest) {
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(
request,
clientsUseV2Format,
maxBodySize);
TrinoRequestUser trinoRequestUser = trinoRequestUserProvider.getInstance(request);
data = ImmutableMap.of("request", request, "trinoQueryProperties", trinoQueryProperties, "trinoRequestUser", trinoRequestUser);
TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
TrinoRequestUser trinoRequestUser = (TrinoRequestUser) request.getAttribute(TRINO_REQUEST_USER);
data = ImmutableMap.of("request", request, TRINO_QUERY_PROPERTIES, trinoQueryProperties, TRINO_REQUEST_USER, trinoRequestUser);
}
else {
data = ImmutableMap.of("request", request);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.ha.router;

import com.google.inject.Inject;
import com.google.inject.Singleton;

import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.gateway.ha.handler.HttpUtils.OAUTH_PATH;
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH;
import static io.trino.gateway.ha.handler.HttpUtils.UI_API_STATS_PATH;
import static io.trino.gateway.ha.handler.HttpUtils.V1_INFO_PATH;
import static io.trino.gateway.ha.handler.HttpUtils.V1_NODE_PATH;
import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH;
import static java.util.Objects.requireNonNull;

/**
* A filter component that determines whether a given path should be whitelisted
* for routing to Trino clusters.
*/
@Singleton
public class PathFilter
{
private final Set<String> statementPaths;
private final List<Pattern> extraWhitelistPatterns;

@Inject
public PathFilter(
List<String> statementPaths,
List<String> extraWhitelistPaths)
Comment on lines +44 to +46
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please inject HaGatewayConfiguration instead.

{
this.statementPaths = Set.copyOf(requireNonNull(statementPaths, "Required configuration 'statementPaths' can't be null"));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use ImmutableSet.copyOf and remove requireNonNull.

this.extraWhitelistPatterns = requireNonNull(extraWhitelistPaths, "extraWhitelistPaths cannot be null").stream()
.map(pattern -> {
try {
return Pattern.compile(pattern);
}
catch (PatternSyntaxException e) {
throw new IllegalArgumentException("Invalid regex pattern: " + pattern, e);
}
})
.collect(toImmutableList());
}

/**
* Determines if the given path is whitelisted for routing to backend.
*
* @param path the request path to check
* @return true if the path should be routed to backend, false otherwise
*/
public boolean isPathWhiteListed(String path)
{
return statementPaths.stream().anyMatch(path::startsWith)
|| path.startsWith(V1_QUERY_PATH)
|| path.startsWith(TRINO_UI_PATH)
|| path.startsWith(V1_INFO_PATH)
|| path.startsWith(V1_NODE_PATH)
|| path.startsWith(UI_API_STATS_PATH)
|| path.startsWith(OAUTH_PATH)
|| extraWhitelistPatterns.stream().anyMatch(pattern -> pattern.matcher(path).matches());
}
}
Loading