Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature](security) Support block specific query with AST names (#43533) #43887

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3206,4 +3206,10 @@ public static int metaServiceRpcRetryTimes() {
"For testing purposes, all queries are forcibly forwarded to the master to verify"
+ "the behavior of forwarding queries."})
public static boolean force_forward_all_queries = false;

@ConfField(description = {"用于禁用某些SQL,配置项为AST的class simple name列表(例如CreateRepositoryStmt,"
+ "CreatePolicyCommand),用逗号间隔开",
"For disabling certain SQL queries, the configuration item is a list of simple class names of AST"
+ "(for example CreateRepositoryStmt, CreatePolicyCommand), separated by commas."})
public static String block_sql_ast_names = "";
}
3 changes: 3 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@
import org.apache.doris.qe.JournalObservable;
import org.apache.doris.qe.QueryCancelWorker;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.qe.StmtExecutor;
import org.apache.doris.qe.VariableMgr;
import org.apache.doris.resource.AdmissionControl;
import org.apache.doris.resource.Tag;
Expand Down Expand Up @@ -1116,6 +1117,8 @@ public void initialize(String[] args) throws Exception {
notifyNewFETypeTransfer(FrontendNodeType.MASTER);
}
queryCancelWorker.start();

StmtExecutor.initBlockSqlAstNames();
}

// wait until FE is ready.
Expand Down
21 changes: 21 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ public class StmtExecutor {
private static final AtomicLong STMT_ID_GENERATOR = new AtomicLong(0);
public static final int MAX_DATA_TO_SEND_FOR_TXN = 100;
public static final String NULL_VALUE_FOR_LOAD = "\\N";
private static Set<String> blockSqlAstNames = Sets.newHashSet();

private Pattern beIpPattern = Pattern.compile("\\[(\\d+):");
private ConnectContext context;
private final StatementContext statementContext;
Expand Down Expand Up @@ -690,6 +692,7 @@ private void executeByNereids(TUniqueId queryId) throws Exception {
"Nereids only process LogicalPlanAdapter, but parsedStmt is " + parsedStmt.getClass().getName());
context.getState().setNereids(true);
LogicalPlan logicalPlan = ((LogicalPlanAdapter) parsedStmt).getLogicalPlan();
checkSqlBlocked(logicalPlan.getClass());
if (context.getCommand() == MysqlCommand.COM_STMT_PREPARE) {
if (isForwardToMaster()) {
throw new UserException("Forward master command is not supported for prepare statement");
Expand Down Expand Up @@ -829,6 +832,23 @@ private void executeByNereids(TUniqueId queryId) throws Exception {
}
}

public static void initBlockSqlAstNames() {
blockSqlAstNames.clear();
blockSqlAstNames = Pattern.compile(",")
.splitAsStream(Config.block_sql_ast_names)
.map(String::trim)
.collect(Collectors.toSet());
if (blockSqlAstNames.isEmpty() && !Config.block_sql_ast_names.isEmpty()) {
blockSqlAstNames.add(Config.block_sql_ast_names);
}
}

public void checkSqlBlocked(Class<?> clazz) throws UserException {
if (blockSqlAstNames.contains(clazz.getSimpleName())) {
throw new UserException("SQL is blocked with AST name: " + clazz.getSimpleName());
}
}

private void parseByNereids() {
if (parsedStmt != null) {
return;
Expand Down Expand Up @@ -976,6 +996,7 @@ public void executeByLegacy(TUniqueId queryId) throws Exception {
try {
// parsedStmt maybe null here, we parse it. Or the predicate will not work.
parseByLegacy();
checkSqlBlocked(parsedStmt.getClass());
if (context.isTxnModel() && !(parsedStmt instanceof InsertStmt)
&& !(parsedStmt instanceof TransactionStmt)) {
throw new TException("This is in a transaction, only insert, update, delete, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.StmtExecutor;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -660,4 +662,38 @@ public void testCreateRole() {
String sql = "create role a comment 'create user'";
nereidsParser.parseSingle(sql);
}

@Test
public void testBlockSqlAst() {
String sql = "plan replayer dump select `AD``D` from t1 where a = 1";
NereidsParser nereidsParser = new NereidsParser();
LogicalPlan logicalPlan = nereidsParser.parseSingle(sql);

Config.block_sql_ast_names = "ReplayCommand";
StmtExecutor.initBlockSqlAstNames();
StmtExecutor stmtExecutor = new StmtExecutor(new ConnectContext(), "");
try {
stmtExecutor.checkSqlBlocked(logicalPlan.getClass());
Assertions.fail();
} catch (Exception ignore) {
// do nothing
}

Config.block_sql_ast_names = "CreatePolicyCommand, ReplayCommand";
StmtExecutor.initBlockSqlAstNames();
try {
stmtExecutor.checkSqlBlocked(logicalPlan.getClass());
Assertions.fail();
} catch (Exception ignore) {
// do nothing
}

Config.block_sql_ast_names = "";
StmtExecutor.initBlockSqlAstNames();
try {
stmtExecutor.checkSqlBlocked(logicalPlan.getClass());
} catch (Exception ex) {
Assertions.fail(ex);
}
}
}
90 changes: 90 additions & 0 deletions fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import org.apache.doris.analysis.AccessTestUtil;
import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.CreateFileStmt;
import org.apache.doris.analysis.CreateFunctionStmt;
import org.apache.doris.analysis.DdlStmt;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.KillStmt;
Expand All @@ -31,6 +33,7 @@
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.analysis.UseStmt;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.Config;
import org.apache.doris.common.DdlException;
import org.apache.doris.common.jmockit.Deencapsulation;
import org.apache.doris.common.profile.Profile;
Expand Down Expand Up @@ -799,4 +802,91 @@ public void testUseWithCatalogFail(@Mocked UseStmt useStmt, @Mocked SqlParser pa

Assert.assertEquals(QueryState.MysqlStateType.ERR, state.getStateType());
}

@Test
public void testBlockSqlAst(@Mocked UseStmt useStmt, @Mocked CreateFileStmt createFileStmt,
@Mocked CreateFunctionStmt createFunctionStmt, @Mocked SqlParser parser) throws Exception {
new Expectations() {
{
useStmt.analyze((Analyzer) any);
minTimes = 0;

useStmt.getDatabase();
minTimes = 0;
result = "testDb";

useStmt.getRedirectStatus();
minTimes = 0;
result = RedirectStatus.NO_FORWARD;

useStmt.getCatalogName();
minTimes = 0;
result = InternalCatalog.INTERNAL_CATALOG_NAME;

Symbol symbol = new Symbol(0, Lists.newArrayList(createFileStmt));
parser.parse();
minTimes = 0;
result = symbol;
}
};

Config.block_sql_ast_names = "CreateFileStmt";
StmtExecutor.initBlockSqlAstNames();
StmtExecutor executor = new StmtExecutor(ctx, "");
try {
executor.execute();
} catch (Exception ignore) {
// do nothing
}
Assert.assertEquals(QueryState.MysqlStateType.ERR, state.getStateType());
Assert.assertTrue(state.getErrorMessage().contains("SQL is blocked with AST name: CreateFileStmt"));

Config.block_sql_ast_names = "AlterStmt, CreateFileStmt";
StmtExecutor.initBlockSqlAstNames();
executor = new StmtExecutor(ctx, "");
try {
executor.execute();
} catch (Exception ignore) {
// do nothing
}
Assert.assertEquals(QueryState.MysqlStateType.ERR, state.getStateType());
Assert.assertTrue(state.getErrorMessage().contains("SQL is blocked with AST name: CreateFileStmt"));

new Expectations() {
{
Symbol symbol = new Symbol(0, Lists.newArrayList(createFunctionStmt));
parser.parse();
minTimes = 0;
result = symbol;
}
};
Config.block_sql_ast_names = "CreateFunctionStmt, CreateFileStmt";
StmtExecutor.initBlockSqlAstNames();
executor = new StmtExecutor(ctx, "");
try {
executor.execute();
} catch (Exception ignore) {
// do nothing
}
Assert.assertEquals(QueryState.MysqlStateType.ERR, state.getStateType());
Assert.assertTrue(state.getErrorMessage().contains("SQL is blocked with AST name: CreateFunctionStmt"));

new Expectations() {
{
Symbol symbol = new Symbol(0, Lists.newArrayList(useStmt));
parser.parse();
minTimes = 0;
result = symbol;
}
};
executor = new StmtExecutor(ctx, "");
executor.execute();
Assert.assertEquals(QueryState.MysqlStateType.OK, state.getStateType());

Config.block_sql_ast_names = "";
StmtExecutor.initBlockSqlAstNames();
executor = new StmtExecutor(ctx, "");
executor.execute();
Assert.assertEquals(QueryState.MysqlStateType.OK, state.getStateType());
}
}
Loading