diff --git a/sqlmesh/core/engine_adapter/fabric.py b/sqlmesh/core/engine_adapter/fabric.py index 8e2fb0e496..8f84dab19f 100644 --- a/sqlmesh/core/engine_adapter/fabric.py +++ b/sqlmesh/core/engine_adapter/fabric.py @@ -14,6 +14,8 @@ ) from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.connection_pool import ConnectionPool +from sqlmesh.core.schema_diff import TableAlterOperation +from sqlmesh.utils import random_id logger = logging.getLogger(__name__) @@ -154,6 +156,113 @@ def set_current_catalog(self, catalog_name: str) -> None: f"Unable to switch catalog to {catalog_name}, catalog ended up as {catalog_after_switch}" ) + def alter_table( + self, alter_expressions: t.Union[t.List[exp.Alter], t.List[TableAlterOperation]] + ) -> None: + """ + Applies alter expressions to a table. Fabric has limited support for ALTER TABLE, + so this method implements a workaround for column type changes. + This method is self-contained and sets its own catalog context. + """ + if not alter_expressions: + return + + # Get the target table from the first expression to determine the correct catalog. + first_op = alter_expressions[0] + expression = first_op.expression if isinstance(first_op, TableAlterOperation) else first_op + if not isinstance(expression, exp.Alter) or not expression.this.catalog: + # Fallback for unexpected scenarios + logger.warning( + "Could not determine catalog from alter expression, executing with current context." + ) + super().alter_table(alter_expressions) + return + + target_catalog = expression.this.catalog + self.set_current_catalog(target_catalog) + + with self.transaction(): + for op in alter_expressions: + expression = op.expression if isinstance(op, TableAlterOperation) else op + + if not isinstance(expression, exp.Alter): + self.execute(expression) + continue + + for action in expression.actions: + table_name = expression.this + + table_name_without_catalog = table_name.copy() + table_name_without_catalog.set("catalog", None) + + is_type_change = isinstance(action, exp.AlterColumn) and action.args.get( + "dtype" + ) + + if is_type_change: + column_to_alter = action.this + new_type = action.args["dtype"] + temp_column_name_str = f"{column_to_alter.name}__{random_id(short=True)}" + temp_column_name = exp.to_identifier(temp_column_name_str) + + logger.info( + "Applying workaround for column '%s' on table '%s' to change type to '%s'.", + column_to_alter.sql(), + table_name.sql(), + new_type.sql(), + ) + + # Step 1: Add a temporary column. + add_column_expr = exp.Alter( + this=table_name_without_catalog.copy(), + kind="TABLE", + actions=[ + exp.ColumnDef(this=temp_column_name.copy(), kind=new_type.copy()) + ], + ) + add_sql = self._to_sql(add_column_expr) + self.execute(add_sql) + + # Step 2: Copy and cast data. + update_sql = self._to_sql( + exp.Update( + this=table_name_without_catalog.copy(), + expressions=[ + exp.EQ( + this=temp_column_name.copy(), + expression=exp.Cast( + this=column_to_alter.copy(), to=new_type.copy() + ), + ) + ], + ) + ) + self.execute(update_sql) + + # Step 3: Drop the original column. + drop_sql = self._to_sql( + exp.Alter( + this=table_name_without_catalog.copy(), + kind="TABLE", + actions=[exp.Drop(this=column_to_alter.copy(), kind="COLUMN")], + ) + ) + self.execute(drop_sql) + + # Step 4: Rename the temporary column. + old_name_qualified = f"{table_name_without_catalog.sql(dialect=self.dialect)}.{temp_column_name.sql(dialect=self.dialect)}" + new_name_unquoted = column_to_alter.sql( + dialect=self.dialect, identify=False + ) + rename_sql = f"EXEC sp_rename '{old_name_qualified}', '{new_name_unquoted}', 'COLUMN'" + self.execute(rename_sql) + else: + # For other alterations, execute directly. + direct_alter_expr = exp.Alter( + this=table_name_without_catalog.copy(), kind="TABLE", actions=[action] + ) + self.execute(direct_alter_expr) + class FabricHttpClient: def __init__(self, tenant_id: str, workspace_id: str, client_id: str, client_secret: str): diff --git a/tests/core/engine_adapter/test_fabric.py b/tests/core/engine_adapter/test_fabric.py index 6b80ef7337..8d7f804bd0 100644 --- a/tests/core/engine_adapter/test_fabric.py +++ b/tests/core/engine_adapter/test_fabric.py @@ -88,3 +88,58 @@ def test_replace_query(adapter: FabricEngineAdapter, mocker: MockerFixture): "TRUNCATE TABLE [test_table];", "INSERT INTO [test_table] ([a]) SELECT [a] FROM [tbl];", ] + + +def test_alter_table_column_type_workaround(adapter: FabricEngineAdapter, mocker: MockerFixture): + """ + Tests the alter_table method's workaround for changing a column's data type. + """ + # Mock set_current_catalog to avoid connection pool side effects + set_catalog_mock = mocker.patch.object(adapter, "set_current_catalog") + # Mock random_id to have a predictable temporary column name + mocker.patch("sqlmesh.core.engine_adapter.fabric.random_id", return_value="abcdef") + + alter_expression = exp.Alter( + this=exp.to_table("my_db.my_schema.my_table"), + actions=[ + exp.AlterColumn( + this=exp.to_column("col_a"), + dtype=exp.DataType.build("BIGINT"), + ) + ], + ) + + adapter.alter_table([alter_expression]) + + set_catalog_mock.assert_called_once_with("my_db") + + expected_calls = [ + "ALTER TABLE [my_schema].[my_table] ADD [col_a__abcdef] BIGINT;", + "UPDATE [my_schema].[my_table] SET [col_a__abcdef] = CAST([col_a] AS BIGINT);", + "ALTER TABLE [my_schema].[my_table] DROP COLUMN [col_a];", + "EXEC sp_rename 'my_schema.my_table.col_a__abcdef', 'col_a', 'COLUMN'", + ] + + assert to_sql_calls(adapter) == expected_calls + + +def test_alter_table_direct_alteration(adapter: FabricEngineAdapter, mocker: MockerFixture): + """ + Tests the alter_table method for direct alterations like adding a column. + """ + set_catalog_mock = mocker.patch.object(adapter, "set_current_catalog") + + alter_expression = exp.Alter( + this=exp.to_table("my_db.my_schema.my_table"), + actions=[exp.ColumnDef(this=exp.to_column("new_col"), kind=exp.DataType.build("INT"))], + ) + + adapter.alter_table([alter_expression]) + + set_catalog_mock.assert_called_once_with("my_db") + + expected_calls = [ + "ALTER TABLE [my_schema].[my_table] ADD [new_col] INT;", + ] + + assert to_sql_calls(adapter) == expected_calls