Skip to content

Conversation

@Dev-Sudarshan
Copy link

This PR adds optional GPU support to L-C2ST by introducing a PyTorch-based
MLP classifier implemented via skorch. This addresses issue #1160 .

Changes:

  • Add a PyTorch MLP with skorch to support GPU training
  • Preserve sklearn-like defaults and training dynamics
  • Add device handling (cpu / cuda)
  • Support user overrides via classifier_kwargs while preserving sbi defaults
  • Extend tests to cover:
    • GPU/CPU device placement
    • Default parameter behavior
    • User override merging

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant