diff --git a/test/test_transformer.py b/test/test_transformer.py
index 899f68c55..cee785ac0 100644
--- a/test/test_transformer.py
+++ b/test/test_transformer.py
@@ -1,6 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import torch
 from pippy import annotate_split_points, Pipe, SplitPoint
+import torch.distributed.checkpoint as dcp
+import tempfile
 
 
 d_hid = 16
@@ -66,6 +68,71 @@ def get_layers(module):
     return layers
 
 
+def pipe_to_sd(pipe):
+    """Collect every stage module of `pipe` into one state dict for DCP."""
+    sd = {}
+    for stage_idx in range(pipe.num_stages):
+        stage_mod = pipe.get_stage_module(stage_idx)
+        sd[f"stage_{stage_idx}"] = stage_mod
+    return sd
+
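+# Round-trip the pipe through torch.distributed.checkpoint (DCP): save all
+# stage modules into one checkpoint, then load them back into a freshly
+# traced pipe.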
+with tempfile.TemporaryDirectory() as tmpdir:
+    # Simulate saving the pipe
+    # Option 1: save each stage module as its own checkpoint
+    # for stage_idx in range(pipe.num_stages):
+    #     print(f"Saving pipeline stage {stage_idx}")
+    #     stage_mod = pipe.get_stage_module(stage_idx)
+    #     dcp.save(
+    #         {f"stage_{stage_idx}": stage_mod},
+    #         checkpoint_id=f"{tmpdir}/stage_{stage_idx}"
+    #     )
+    # Option 2: save all stage modules into a single checkpoint
+    sd = pipe_to_sd(pipe)
+    dcp.save(sd, checkpoint_id=tmpdir)
+
+
+    # Simulate loading the pipe
+    # Option 1: load each stage module from its own checkpoint
+    # for stage_idx in range(pipe.num_stages):
+    #     print(f"Loading pipeline stage {stage_idx}")
+    #     stage_mod = pipe.get_stage_module(stage_idx)
+    #     dcp.load(
+    #         {f"stage_{stage_idx}": stage_mod},
+    #         checkpoint_id=f"{tmpdir}/stage_{stage_idx}"
+    #     )
+
+    # Option 2: retrace the model, then load all stages from the single checkpoint
+    new_pipe = Pipe.from_tracing(
+        transformer,
+        1,
+        (x,),
+    )
+    sd = pipe_to_sd(new_pipe)
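+    # Sanity-check sketch (not part of the original test): zero out the
+    # retraced pipe's parameters so the comparison after dcp.load actually
+    # exercises the checkpoint; the retrace alone would already carry the
+    # same weights as `transformer`.
+    with torch.no_grad():
+        for stage_idx in range(new_pipe.num_stages):
+            for p in new_pipe.get_stage_module(stage_idx).parameters():
+                p.zero_()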
+    dcp.load(sd, checkpoint_id=tmpdir)
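+    # After loading, every stage's parameters should match the saved pipe
+    # bit-exactly.
+    for stage_idx in range(new_pipe.num_stages):
+        old_mod = pipe.get_stage_module(stage_idx)
+        new_mod = new_pipe.get_stage_module(stage_idx)
+        for (name, p_old), (_, p_new) in zip(
+            old_mod.named_parameters(), new_mod.named_parameters()
+        ):
+            torch.testing.assert_close(p_new, p_old, msg=f"mismatch in {name}")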
+
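+# Continue the rest of the test with the reloaded pipe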
+pipe = new_pipe
+
 # Collect all layers in pipe
 layers = []
 for stage_idx in range(pipe.num_stages):