diff --git a/examples/llama/2d_llama.py b/examples/llama/2d_llama.py
new file mode 100644
index 000000000..f967e6d2b
--- /dev/null
+++ b/examples/llama/2d_llama.py
@@ -0,0 +1,95 @@
+# $ torchrun --nproc-per-node 8 2d_llama.py
+import os
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
+from torch.distributed._tensor import init_device_mesh, DTensor
+from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
+
+
+# We set this flag to True to allow operations that mix regular tensors and
+# DTensors as arguments. The mix is a result of `use_local_output=False` in
+# the tensor parallel plan below.
+DTensor._op_dispatcher._allow_implicit_replication = True
+
+
+# Grab the model
+llama = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True,
+)
+llama.eval()
+
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+prompts = (
+    "How do you", "I like to", "Can I help", "You need to",
+    "The weather is", "I found a", "What is your", "You are so",
+)  # bs = 8
+tokenizer.pad_token = tokenizer.eos_token
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
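+# inputs["input_ids"] has shape (8, padded_seq_len); it is reused below as the
+# example input for tracing the pipeline.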
+
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+
+# Initialize 2D device mesh
+pp_group_size = 2
+tp_group_size = 4
+mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
+pp_group = mesh_2d["pp"].get_group()
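+# With 8 ranks, the mesh is laid out row-major as [[0, 1, 2, 3], [4, 5, 6, 7]]:
+# ranks 0-3 form the TP group of the first pipeline stage, ranks 4-7 that of
+# the second, and each PP group pairs rank r with rank r + tp_group_size.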
+
+# Cut the model into pipeline stages with an equal number of layers per stage
+layers_per_stage = llama.config.num_hidden_layers // pp_group_size
+for i in range(1, pp_group_size):
+    annotate_split_points(llama,
+        {f"model.layers.{i * layers_per_stage}": PipeSplitWrapper.SplitPoint.BEGINNING})
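+# For Llama-2-7B (32 decoder layers) and pp_group_size=2, this annotates a
+# single split before "model.layers.16", so layers 0-15 form stage 0 and
+# layers 16-31 form stage 1.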
+
+# Create a pipeline representation from the model
+llama_pipe = Pipe.from_tracing(llama, pp_group_size, example_args=(inputs["input_ids"],))
+
+# Create pipeline stage for each rank
+stage_idx = rank // tp_group_size
+stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)
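+# All TP ranks in the same pipeline group (e.g. ranks 0-3) build the same stage;
+# its parameters are then sharded across those ranks by parallelize_module below.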
+
+# Tensor parallelize the attention and MLP projections in this stage's layers
+starting_layer = stage_idx * layers_per_stage
+attn_plan = {}
+mlp_plan = {}
+for i in range(layers_per_stage):
+    # HACK: PipeSplitWrapper stores the wrapped layer under a ".mod" attribute,
+    # which shows up as a "_mod" suffix in the traced module names; the right
+    # fix is to remove the ".mod" indirection in PipeSplitWrapper.
+    extra = "_mod" if starting_layer > 0 and i == 0 else ""
+    layer_name = f"L__self___model_layers_{starting_layer + i}{extra}"
+    attn_plan.update({
+        # We set `use_local_output` to False to keep the outputs in DTensor
+        # form, so that the subsequent view/reshape operations work on them
+        # without code changes.
+        f"{layer_name}_self_attn_q_proj": ColwiseParallel(use_local_output=False),
+        f"{layer_name}_self_attn_k_proj": ColwiseParallel(use_local_output=False),
+        f"{layer_name}_self_attn_v_proj": ColwiseParallel(use_local_output=False),
+        f"{layer_name}_self_attn_o_proj": RowwiseParallel(use_local_output=False),
+    })
+    mlp_plan.update({
+        f"{layer_name}_mlp_gate_proj": ColwiseParallel(),
+        f"{layer_name}_mlp_up_proj": ColwiseParallel(),
+        f"{layer_name}_mlp_down_proj": RowwiseParallel(),
+    })
+tp_mesh = mesh_2d["tp"]
+parallelize_module(
+    stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
+)
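+# This follows the usual Megatron-style pairing (assuming the head count and MLP
+# intermediate size are divisible by tp_group_size): ColwiseParallel shards each
+# Linear along its output features, RowwiseParallel along its input features, so
+# every attention/MLP block needs only a single all-reduce on the Rowwise output.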
+
+# Run the pipeline: the first stage consumes the tokenized prompts; later
+# stages receive their inputs (activations) from the previous stage.
+inputs = inputs.to(device)
+if stage_idx == 0:
+    args = inputs["input_ids"]
+else:
+    args = None
+output = stage(args)
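+# Only the last pipeline stage is expected to return the logits here; earlier
+# stages send their activations to the next stage and return None.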
+
+# Decode
+if output is not None:
+    next_token_logits = output[0]
+    if isinstance(next_token_logits, DTensor):
+        # Convert DTensor back to regular tensor
+        next_token_logits = next_token_logits.to_local()
+    next_token_logits = next_token_logits[:, -1, :]
+    next_token = torch.argmax(next_token_logits, dim=-1)
+    print(tokenizer.batch_decode(next_token))
diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py
index fe39a15d6..3efca131d 100644
--- a/pippy/PipelineStage.py
+++ b/pippy/PipelineStage.py
@@ -476,7 +476,11 @@ def _send_activations(
                 )
                 peer_rank = self.stage_index_to_group_rank[dst]
                 work = dist.isend(
-                    out,
+                    # HACK: convert DTensor to a regular (local) tensor so it
+                    # works with the send op. DTensors may show up in PP + TP cases.
+                    out.to_local()
+                    if isinstance(out, torch.distributed._tensor.DTensor)
+                    else out,
                     peer_rank
                     if self.group is None
                     else dist.get_global_rank(self.group, peer_rank),  # TODO