diff --git a/rerun_py/tests/e2e_redap_tests/conftest.py b/rerun_py/tests/e2e_redap_tests/conftest.py index 988552c36c02..d4d8ed436b67 100644 --- a/rerun_py/tests/e2e_redap_tests/conftest.py +++ b/rerun_py/tests/e2e_redap_tests/conftest.py @@ -31,7 +31,7 @@ ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def table_filepath() -> Generator[pathlib.Path, None, None]: """ Copies test data to a temp directory. diff --git a/rerun_py/tests/e2e_redap_tests/test_table_write.py b/rerun_py/tests/e2e_redap_tests/test_table_write.py index 374f6bf943b6..8faf9df9bfe3 100644 --- a/rerun_py/tests/e2e_redap_tests/test_table_write.py +++ b/rerun_py/tests/e2e_redap_tests/test_table_write.py @@ -1,8 +1,10 @@ from __future__ import annotations +import threading from typing import TYPE_CHECKING import pyarrow as pa +import pytest from datafusion import DataFrameWriteOptions, InsertOp, SessionContext, col from rerun.catalog import TableInsertMode @@ -90,3 +92,61 @@ def test_client_append_to_table(server_instance: ServerInstance) -> None: table_name, id=[3, 4, 5], bool_col=[False, True, None], double_col=[2.0, None, 1.0] ) assert ctx.table(table_name).count() == original_rows + 4 + + +@pytest.mark.parametrize("is_append", [True, False]) +def test_concurrent_write_tables(server_instance: ServerInstance, is_append: bool) -> None: + num_writes = 100 + + table_name = "simple_datatypes" + ctx: SessionContext = server_instance.client.ctx + + df_prior = ctx.table(table_name) + prior_count = df_prior.count() + + df_low = ctx.table(table_name).filter(col("id") < 3).cache() + low_count = df_low.count() + + df_high = ctx.table(table_name).filter(col("id") >= 3).cache() + high_count = df_high.count() + + # Track any exceptions from threads + exceptions = [] + + insert_mode = InsertOp.APPEND if is_append else InsertOp.OVERWRITE + + def write_low() -> None: + for _ in range(num_writes): + try: + df_low.write_table(table_name, write_options=DataFrameWriteOptions(insert_operation=insert_mode)) + except Exception as e: + exceptions.append(e) + return + + def write_high() -> None: + for _ in range(num_writes): + try: + df_high.write_table(table_name, write_options=DataFrameWriteOptions(insert_operation=insert_mode)) + except Exception as e: + exceptions.append(e) + return + + thread1 = threading.Thread(target=write_low) + thread2 = threading.Thread(target=write_high) + + thread1.start() + thread2.start() + + thread1.join() + thread2.join() + + if exceptions: + raise exceptions[0] + + final_count = ctx.table(table_name).count() + + expected = ( + [prior_count + (num_writes * low_count) + (num_writes * high_count)] if is_append else [low_count, high_count] + ) + + assert final_count in expected, f"Expected rows in {expected} rows, got {final_count}"