Throughout the API we often take parameters like `kT`, `dt`, and `alpha`. In some cases these are typed as `float | torch.Tensor`; in others they are just `torch.Tensor`. We should:
- Standardize on `float | torch.Tensor` in `__init__` and `step` functions.
- Use `x = torch.as_tensor(x)` instead of the `isinstance` check often used.
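A minimal sketch of the proposed pattern; the `step` function and its update rule here are hypothetical, chosen only to illustrate the conversion:

```python
import torch


def step(x: torch.Tensor, dt: "float | torch.Tensor") -> torch.Tensor:
    # Accept either a Python float or a tensor and normalize once,
    # replacing the usual `if not isinstance(dt, torch.Tensor): ...` branch.
    dt = torch.as_tensor(dt)
    # From here on, dt is always a tensor (placeholder update for illustration).
    return x + dt * x
```

Note that `torch.as_tensor` avoids a copy when the input is already a tensor, so passing tensors through repeatedly is cheap.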