diff --git a/openfold/utils/checkpointing.py b/openfold/utils/checkpointing.py index b2bb752cd..3351e6607 100644 --- a/openfold/utils/checkpointing.py +++ b/openfold/utils/checkpointing.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import importlib from typing import Any, Tuple, List, Callable, Optional @@ -34,7 +35,7 @@ def get_checkpoint_fn(): if(deepspeed_is_configured): checkpoint = deepspeed.checkpointing.checkpoint else: - checkpoint = torch.utils.checkpoint.checkpoint + checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) return checkpoint