fix: adhoc metrics #30202

Merged: 2 commits, Oct 10, 2024
2 changes: 2 additions & 0 deletions superset/connectors/sqla/models.py
@@ -1533,6 +1533,7 @@ def adhoc_metric_to_sqla(
expression = self._process_sql_expression(
expression=metric["sqlExpression"],
database_id=self.database_id,
+engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1566,6 +1567,7 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
expression = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
+engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
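The new engine argument is taken from self.database.backend. As a rough sketch of the assumption at play (the diff shows the Database.backend property being used; it is assumed here to resolve to the dialect portion of the database's SQLAlchemy URI, i.e. the part before any +driver suffix, and the helper below is hypothetical):

# Hypothetical helper, for illustration only: Superset's Database.backend
# property is assumed to resolve to the backend portion of the SQLAlchemy URI,
# which is what gets passed as the `engine` argument in the hunks above.
from sqlalchemy.engine.url import make_url


def backend_from_uri(sqlalchemy_uri: str) -> str:
    """Return the backend name of a SQLAlchemy URI, e.g. 'postgresql'."""
    return make_url(sqlalchemy_uri).get_backend_name()


print(backend_from_uri("postgresql+psycopg2://user:pass@host:5432/superset"))  # postgresql
print(backend_from_uri("sqlite:///superset.db"))  # sqlite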
21 changes: 19 additions & 2 deletions superset/models/helpers.py
@@ -63,6 +63,7 @@
ColumnNotFoundException,
QueryClauseValidationException,
QueryObjectValidationError,
+SupersetParseError,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
@@ -112,6 +113,7 @@
def validate_adhoc_subquery(
sql: str,
database_id: int,
+engine: str,
default_schema: str,
) -> str:
"""
@@ -126,7 +128,12 @@ def validate_adhoc_subquery(
    """
    statements = []
    for statement in sqlparse.parse(sql):
-        if has_table_query(statement):
+        try:
+            has_table = has_table_query(str(statement), engine)
+        except SupersetParseError:
+            has_table = True
+
+        if has_table:
            if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
                raise SupersetSecurityException(
                    SupersetError(
@@ -135,7 +142,9 @@ def validate_adhoc_subquery(
                        level=ErrorLevel.ERROR,
                    )
                )
+            # TODO (betodealmeida): reimplement with sqlglot
            statement = insert_rls_in_predicate(statement, database_id, default_schema)
+
        statements.append(statement)

    return ";\n".join(str(statement) for statement in statements)
@@ -810,10 +819,11 @@ def get_sqla_row_level_filters(
            # for datasources of type query
            return []

-    def _process_sql_expression(
+    def _process_sql_expression(  # pylint: disable=too-many-arguments
        self,
        expression: Optional[str],
        database_id: int,
+        engine: str,
        schema: str,
        template_processor: Optional[BaseTemplateProcessor],
    ) -> Optional[str]:
@@ -823,6 +833,7 @@ def _process_sql_expression(
expression = validate_adhoc_subquery(
expression,
database_id,
+engine,
schema,
)
try:
@@ -1108,6 +1119,7 @@ def adhoc_metric_to_sqla(
expression = self._process_sql_expression(
expression=metric["sqlExpression"],
database_id=self.database_id,
+engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1551,6 +1563,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
col["sqlExpression"] = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
+engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1613,6 +1626,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
selected = validate_adhoc_subquery(
selected,
self.database_id,
+self.database.backend,
self.schema,
)
outer = literal_column(f"({selected})")
@@ -1639,6 +1653,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
selected = validate_adhoc_subquery(
_sql,
self.database_id,
+self.database.backend,
self.schema,
)

@@ -1915,6 +1930,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
where = self._process_sql_expression(
expression=where,
database_id=self.database_id,
+engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
@@ -1933,6 +1949,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
having = self._process_sql_expression(
expression=having,
database_id=self.database_id,
+engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
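Two behavioral points in the helpers.py change above: the database backend is now threaded through to has_table_query so the expression is parsed with the right dialect, and parse failures fail closed, meaning an expression that raises SupersetParseError is treated as if it read from a table. A rough usage sketch of the validator with its new signature (illustrative only: feature-flag lookups need an initialized Superset app context, and the database ID and schema below are placeholders):

# Illustrative sketch, not a test: assumes a running Superset app context so
# that feature flags resolve; the database ID and schema are placeholders.
from superset.exceptions import SupersetSecurityException
from superset.models.helpers import validate_adhoc_subquery

# A plain aggregate contains no table reference and is returned unchanged.
print(validate_adhoc_subquery("COUNT(*)", 1, "postgresql", "public"))

# An expression that reads from a table (or that fails to parse) is rejected
# unless ALLOW_ADHOC_SUBQUERY is enabled, in which case row-level security
# predicates are inserted into it instead.
try:
    validate_adhoc_subquery(
        "(SELECT COUNT(*) FROM forbidden_table)", 1, "postgresql", "public"
    )
except SupersetSecurityException:
    print("rejected: custom SQL fields cannot contain sub-queries")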
1 change: 1 addition & 0 deletions superset/models/sql_lab.py
@@ -374,6 +374,7 @@ def adhoc_column_to_sqla(
expression = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
+engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
40 changes: 13 additions & 27 deletions superset/sql_parse.py
@@ -64,6 +64,7 @@
extract_tables_from_statement,
SQLGLOT_DIALECTS,
SQLScript,
+SQLStatement,
Table,
)
from superset.utils.backports import StrEnum
@@ -570,46 +571,31 @@ class InsertRLSState(StrEnum):
FOUND_TABLE = "FOUND_TABLE"


-def has_table_query(token_list: TokenList) -> bool:
+def has_table_query(expression: str, engine: str) -> bool:
"""
Return if a statement has a query reading from a table.

>>> has_table_query(sqlparse.parse("COUNT(*)")[0])
>>> has_table_query("COUNT(*)", "postgresql")
False
>>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
>>> has_table_query("SELECT * FROM table", "postgresql")
True

Note that queries reading from constant values return false:

>>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
>>> has_table_query("SELECT * FROM (SELECT 1)", "postgresql")
False

"""
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Ignore comments
if isinstance(token, sqlparse.sql.Comment):
continue

# Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True

# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
state = InsertRLSState.SEEN_SOURCE

# Found identifier/keyword after FROM/JOIN
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
):
return True
# Remove trailing semicolon.
expression = expression.strip().rstrip(";")

# Found nothing, leaving source
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
state = InsertRLSState.SCANNING
# Wrap the expression in parentheses if it's not already.
if not expression.startswith("("):
expression = f"({expression})"

return False
sql = f"SELECT {expression}"
statement = SQLStatement(sql, engine)
return any(statement.tables)


def add_table_name(rls: TokenList, table: str) -> None:
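The rewritten has_table_query hands parsing to SQLStatement (backed by sqlglot) instead of walking sqlparse tokens. A standalone sketch of the same idea, using sqlglot directly (SQLStatement's internals are assumed rather than shown here, and the dialect name is simplified):

# Sketch of the approach only: wrap the expression in a SELECT, parse it with
# sqlglot, and report whether any real table (as opposed to a derived table
# over constants) is referenced. Superset's SQLStatement wrapper is assumed to
# do roughly this via its `tables` property.
import sqlglot
from sqlglot import exp


def expression_reads_a_table(expression: str, dialect: str = "postgres") -> bool:
    expression = expression.strip().rstrip(";")
    if not expression.startswith("("):
        expression = f"({expression})"
    parsed = sqlglot.parse_one(f"SELECT {expression}", read=dialect)
    # Derived tables such as "(SELECT 1) AS a" parse as Subquery nodes, not
    # Table nodes, so constant-only subqueries do not count as table reads.
    return any(table.name for table in parsed.find_all(exp.Table))


print(expression_reads_a_table("COUNT(*)"))                  # False
print(expression_reads_a_table("SELECT * FROM some_table"))  # True
print(expression_reads_a_table("SELECT * FROM (SELECT 1)"))  # False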
4 changes: 4 additions & 0 deletions tests/integration_tests/datasource_tests.py
@@ -42,6 +42,7 @@
get_main_database,
)
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
+from tests.integration_tests.conftest import with_feature_flags
from tests.integration_tests.constants import ADMIN_USERNAME
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401
@@ -585,6 +586,7 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
assert "INCORRECT SQL" in rv.json.get("error")


+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
@@ -649,6 +651,7 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
assert rv.json["result"]["rowcount"] == 0


+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
@@ -669,6 +672,7 @@ def test_get_samples_with_time_filter(test_client, login_as_admin, physical_data
assert rv.json["result"]["total_count"] == 2


+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_with_multiple_filters(
test_client, login_as_admin, physical_dataset
):
7 changes: 6 additions & 1 deletion tests/integration_tests/query_context_tests.py
@@ -42,7 +42,11 @@
)
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.integration_tests.base_tests import SupersetTestCase
-from tests.integration_tests.conftest import only_postgresql, only_sqlite
+from tests.integration_tests.conftest import (
+    only_postgresql,
+    only_sqlite,
+    with_feature_flags,
+)
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401
load_birth_names_data, # noqa: F401
@@ -858,6 +862,7 @@ def test_non_time_column_with_time_grain(app_context, physical_dataset):
assert df["COL2 ALIAS"][0] == "a"


+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_special_chars_in_column_name(app_context, physical_dataset):
qc = QueryContextFactory().create(
datasource={
50 changes: 35 additions & 15 deletions tests/unit_tests/sql_parse_tests.py
@@ -1286,46 +1286,66 @@ def test_sqlparse_issue_652():


@pytest.mark.parametrize(
"sql,expected",
("engine", "sql", "expected"),
[
("SELECT * FROM table", True),
("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True),
("COUNT(*)", False),
("SELECT a FROM (SELECT 1 AS a)", False),
("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
("SELECT * FROM other_table", True),
("extract(HOUR from from_unixtime(hour_ts)", False),
("(SELECT * FROM table)", True),
("(SELECT COUNT(DISTINCT name) from birth_names)", True),
("postgresql", "extract(HOUR from from_unixtime(hour_ts))", False),
("postgresql", "SELECT * FROM table", True),
("postgresql", "(SELECT * FROM table)", True),
(
"postgresql",
"SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)",
True,
),
(
"postgresql",
"(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)",
True,
),
("postgresql", "COUNT(*)", False),
("postgresql", "SELECT a FROM (SELECT 1 AS a)", False),
("postgresql", "SELECT a FROM (SELECT 1 AS a) JOIN table", True),
(
"postgresql",
"SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar",
False,
),
("postgresql", "SELECT * FROM other_table", True),
("postgresql", "(SELECT COUNT(DISTINCT name) from birth_names)", True),
(
"postgresql",
"(SELECT table_name FROM information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
True,
),
(
"postgresql",
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
True,
),
(
"postgresql",
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
True,
),
(
"postgresql",
"SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
True,
),
(
"postgresql",
"((select users.id from (select 'majorie' as a) b, users where b.a = users.name and users.name in ('majorie') limit 1) like 'U%')",
True,
),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
def test_has_table_query(engine: str, sql: str, expected: bool) -> None:
"""
Test if a given statement queries a table.

This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing
row-level security.
"""
statement = sqlparse.parse(sql)[0]
assert has_table_query(statement) == expected
assert has_table_query(sql, engine) == expected


@pytest.mark.parametrize(