diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index c79419ff..1e373b79 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -34,8 +34,7 @@ def pretrained_sevenn_model(): """Load a pretrained SevenNet model for testing.""" cp = sevenn.util.load_checkpoint(model_name) - backend = "e3nn" - model_loaded = cp.build_model(backend) + model_loaded = cp.build_model() model_loaded.set_is_batch_data(True) return model_loaded.to(DEVICE)