Skip to content

Commit fcc943d

Browse files
committed
[do not land] multifunction experiments
1 parent 913436a commit fcc943d

File tree

2 files changed

+418
-35
lines changed

2 files changed

+418
-35
lines changed
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import argparse
9+
import os
10+
import shutil
11+
import subprocess
12+
import sys
13+
from pathlib import Path
14+
15+
import coremltools as ct
16+
17+
18+
def extract_models(pte_path: str, output_dir: str) -> list[str]:
19+
"""
20+
Extract CoreML models from a PTE file.
21+
Returns list of paths to extracted .mlpackage files.
22+
"""
23+
# Create output directory
24+
os.makedirs(output_dir, exist_ok=True)
25+
26+
# Run the extraction script
27+
script_path = Path(__file__).parent.parent / "scripts" / "extract_coreml_models.py"
28+
29+
# Save current directory and change to output dir (extract script outputs to cwd)
30+
original_cwd = os.getcwd()
31+
os.chdir(output_dir)
32+
33+
try:
34+
result = subprocess.run(
35+
[sys.executable, str(script_path), "-m", pte_path],
36+
capture_output=True,
37+
text=True
38+
)
39+
if result.returncode != 0:
40+
print(f"Error extracting models: {result.stderr}")
41+
sys.exit(1)
42+
print(result.stdout)
43+
finally:
44+
os.chdir(original_cwd)
45+
46+
# Find extracted mlpackage files
47+
extracted_dir = Path(output_dir) / "extracted_coreml_models"
48+
49+
# Debug: print what we find
50+
print(f" Looking in: {extracted_dir}")
51+
for model_dir in sorted(extracted_dir.iterdir()):
52+
print(f" {model_dir.name}/")
53+
if model_dir.is_dir():
54+
for item in list(model_dir.iterdir())[:10]:
55+
print(f" {item.name}")
56+
57+
model_paths = []
58+
for model_dir in sorted(extracted_dir.iterdir()):
59+
if model_dir.is_dir():
60+
# Look for .mlpackage inside the model directory
61+
found = False
62+
for item in model_dir.iterdir():
63+
if item.suffix == ".mlpackage":
64+
model_paths.append(str(item))
65+
found = True
66+
break
67+
68+
# If no .mlpackage found, check for lowered_module directory
69+
if not found:
70+
lowered_module = model_dir / "lowered_module"
71+
if lowered_module.exists() and lowered_module.is_dir():
72+
# Debug: show contents of lowered_module
73+
print(f" Contents of {lowered_module}:")
74+
for item in list(lowered_module.iterdir())[:10]:
75+
print(f" {item.name}")
76+
77+
# Look for .mlpackage inside lowered_module
78+
for item in lowered_module.iterdir():
79+
if item.suffix == ".mlpackage":
80+
model_paths.append(str(item))
81+
found = True
82+
break
83+
84+
# If still not found, look for model.mlmodel file
85+
if not found:
86+
mlmodel_file = lowered_module / "model.mlmodel"
87+
if mlmodel_file.exists():
88+
# Load and save as mlpackage
89+
mlpackage_path = model_dir / f"{model_dir.name}.mlpackage"
90+
model = ct.models.MLModel(str(mlmodel_file))
91+
model.save(str(mlpackage_path))
92+
model_paths.append(str(mlpackage_path))
93+
found = True
94+
95+
return model_paths
96+
97+
98+
def create_multifunction_model(
99+
prefill_mlpackage: str,
100+
decode_mlpackage: str,
101+
output_path: str,
102+
compile_model: bool
103+
) -> str:
104+
"""
105+
Create a multifunction model combining prefill and decode.
106+
Returns the path to the output model.
107+
"""
108+
desc = ct.utils.MultiFunctionDescriptor()
109+
110+
desc.add_function(
111+
prefill_mlpackage,
112+
src_function_name="main",
113+
target_function_name="prefill"
114+
)
115+
desc.add_function(
116+
decode_mlpackage,
117+
src_function_name="main",
118+
target_function_name="decode"
119+
)
120+
121+
desc.default_function_name = "decode"
122+
123+
if compile_model:
124+
# Save mlpackage first, then compile
125+
mlpackage_path = output_path + ".mlpackage"
126+
ct.utils.save_multifunction(desc, mlpackage_path)
127+
128+
compiled_path = ct.utils.compile_model(mlpackage_path)
129+
dest_path = output_path + ".mlmodelc"
130+
131+
if os.path.exists(dest_path):
132+
shutil.rmtree(dest_path)
133+
shutil.move(compiled_path, dest_path)
134+
135+
# Clean up intermediate mlpackage
136+
shutil.rmtree(mlpackage_path)
137+
138+
print(f"Saved compiled model to {dest_path}")
139+
return dest_path
140+
else:
141+
mlpackage_path = output_path + ".mlpackage"
142+
ct.utils.save_multifunction(desc, mlpackage_path)
143+
print(f"Saved model to {mlpackage_path}")
144+
return mlpackage_path
145+
146+
147+
def main():
148+
parser = argparse.ArgumentParser(
149+
description="Create multifunction CoreML models from prefill/decode PTE files"
150+
)
151+
parser.add_argument(
152+
"--prefill_model",
153+
required=True,
154+
help="Path to the prefill PTE file"
155+
)
156+
parser.add_argument(
157+
"--decode_model",
158+
required=True,
159+
help="Path to the decode PTE file"
160+
)
161+
parser.add_argument(
162+
"--compile",
163+
action="store_true",
164+
default=False,
165+
help="Compile the models to .mlmodelc format"
166+
)
167+
parser.add_argument(
168+
"--output_dir",
169+
default=".",
170+
help="Output directory for the multifunction models (default: current directory)"
171+
)
172+
173+
args = parser.parse_args()
174+
175+
# Create temp directories for extraction
176+
temp_dir = Path(args.output_dir) / "temp_extraction"
177+
prefill_extract_dir = temp_dir / "prefill"
178+
decode_extract_dir = temp_dir / "decode"
179+
180+
print("Extracting prefill models...")
181+
prefill_models = extract_models(args.prefill_model, str(prefill_extract_dir))
182+
print(f"Found {len(prefill_models)} prefill models")
183+
184+
print("Extracting decode models...")
185+
decode_models = extract_models(args.decode_model, str(decode_extract_dir))
186+
print(f"Found {len(decode_models)} decode models")
187+
188+
if len(prefill_models) != len(decode_models):
189+
print(f"Error: Number of prefill models ({len(prefill_models)}) does not match decode models ({len(decode_models)})")
190+
sys.exit(1)
191+
192+
num_models = len(prefill_models)
193+
print(f"\nCreating {num_models} multifunction models...")
194+
195+
# Create multifunction models (mod1, mod2, mod3, ...)
196+
for i in range(num_models):
197+
model_num = i + 1
198+
output_path = str(Path(args.output_dir) / f"mod{model_num}")
199+
200+
print(f"\nCreating mod{model_num}...")
201+
print(f" Prefill: {prefill_models[i]}")
202+
print(f" Decode: {decode_models[i]}")
203+
204+
create_multifunction_model(
205+
prefill_mlpackage=prefill_models[i],
206+
decode_mlpackage=decode_models[i],
207+
output_path=output_path,
208+
compile_model=args.compile
209+
)
210+
211+
# Clean up temp directory
212+
print("\nCleaning up temporary files...")
213+
try:
214+
shutil.rmtree(temp_dir)
215+
except OSError as e:
216+
print(f"Warning: Could not fully clean up temp directory: {e}")
217+
print(f"You may want to manually delete: {temp_dir}")
218+
219+
print("\nDone!")
220+
221+
222+
if __name__ == "__main__":
223+
main()

0 commit comments

Comments
 (0)