diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index a2aa7e8de..d269290e4 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -3,6 +3,7 @@ ######################################################################################################## if __name__ == "__main__": + import deepspeed from argparse import ArgumentParser from pytorch_lightning import Trainer from pytorch_lightning.utilities import rank_zero_info, rank_zero_only