Skip to content

Commit d7e7ca0

Browse files
author
jhlu
committed
feat(vision): add detection local zoo CLI
1 parent 5e0ed24 commit d7e7ca0

3 files changed

Lines changed: 365 additions & 0 deletions

File tree

dlhub/vision/detection_zoo.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
4+
from dataclasses import dataclass
5+
from pathlib import Path
6+
7+
from torch import nn
8+
9+
10+
@dataclass(frozen=True)
11+
class BuildConfig:
12+
in_channels: int
13+
num_classes: int
14+
width_mult: float = 1.0
15+
16+
17+
class UnknownLocalArch(KeyError):
18+
pass
19+
20+
21+
def _split_arch_id(arch_id: str) -> tuple[str, str]:
22+
arch_id = str(arch_id).strip()
23+
if ":" not in arch_id:
24+
return "dldet", arch_id
25+
26+
prefix, name = arch_id.split(":", 1)
27+
prefix = prefix.strip().lower()
28+
name = name.strip()
29+
if not prefix or not name:
30+
raise ValueError(f"Invalid arch id: {arch_id!r}")
31+
return prefix, name
32+
33+
34+
Builder = Callable[[BuildConfig], nn.Module]
35+
36+
37+
def _extract_variants_from_source(src: str) -> list[str] | None:
38+
"""Extract `_VARIANTS` keys from a module source without importing it."""
39+
40+
import ast
41+
42+
try:
43+
tree = ast.parse(src)
44+
except SyntaxError:
45+
return None
46+
47+
for node in tree.body:
48+
target_name: str | None = None
49+
value = None
50+
51+
if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
52+
target_name = node.target.id
53+
value = node.value
54+
elif isinstance(node, ast.Assign):
55+
for t in node.targets:
56+
if isinstance(t, ast.Name):
57+
target_name = t.id
58+
break
59+
value = node.value
60+
61+
if target_name != "_VARIANTS" or not isinstance(value, ast.Dict):
62+
continue
63+
64+
keys: list[str] = []
65+
for k in value.keys:
66+
if isinstance(k, ast.Constant) and isinstance(k.value, str):
67+
keys.append(k.value)
68+
return keys or None
69+
70+
return None
71+
72+
73+
def _extract_builder_name_from_source(src: str) -> str | None:
74+
"""Extract the first `build_*_detector` function name without importing."""
75+
76+
import ast
77+
78+
try:
79+
tree = ast.parse(src)
80+
except SyntaxError:
81+
return None
82+
83+
for node in tree.body:
84+
if not isinstance(node, ast.FunctionDef):
85+
continue
86+
name = str(node.name)
87+
if name.startswith("build_") and name.endswith("_detector"):
88+
return name
89+
return None
90+
91+
92+
def _make_lazy_detector_builder(module_name: str, *, builder_name: str, variant: str) -> Builder:
93+
module_name = str(module_name).strip()
94+
builder_name = str(builder_name).strip()
95+
variant = str(variant).strip()
96+
97+
def _builder(cfg: BuildConfig) -> nn.Module:
98+
import importlib
99+
import inspect
100+
101+
mod = importlib.import_module(f"dlhub.vision.detection.{module_name}")
102+
fn = getattr(mod, builder_name, None)
103+
if fn is None:
104+
raise RuntimeError(f"Detection module {module_name!r} missing {builder_name}()")
105+
106+
kwargs: dict[str, object] = {
107+
"in_channels": int(cfg.in_channels),
108+
"num_classes": int(cfg.num_classes),
109+
"variant": str(variant),
110+
"width_mult": float(cfg.width_mult),
111+
}
112+
113+
try:
114+
sig = inspect.signature(fn)
115+
except (TypeError, ValueError):
116+
sig = None
117+
118+
if sig is not None:
119+
params = set(sig.parameters)
120+
kwargs = {k: v for k, v in kwargs.items() if k in params}
121+
122+
return fn(**kwargs)
123+
124+
return _builder
125+
126+
127+
def _extend_registry_with_discovered_detectors(r: dict[str, Builder]) -> None:
128+
"""Discover detector variants under `dlhub/vision/detection/*.py`."""
129+
130+
here = Path(__file__).resolve().parent
131+
det_dir = here / "detection"
132+
133+
if not det_dir.exists():
134+
return
135+
136+
hidden = {"__init__"}
137+
138+
for py in sorted(det_dir.glob("*.py")):
139+
module_name = py.stem
140+
if module_name in hidden or module_name.startswith("_"):
141+
continue
142+
143+
try:
144+
src = py.read_text(encoding="utf-8")
145+
except OSError:
146+
continue
147+
148+
if "_VARIANTS" not in src or "def build_" not in src:
149+
continue
150+
151+
variants = _extract_variants_from_source(src)
152+
if not variants:
153+
continue
154+
155+
builder_name = _extract_builder_name_from_source(src)
156+
if builder_name is None:
157+
continue
158+
159+
for v in variants:
160+
name = str(v).lower().strip()
161+
if not name or name in r:
162+
continue
163+
r[name] = _make_lazy_detector_builder(module_name, builder_name=builder_name, variant=name)
164+
165+
166+
def _registry() -> dict[str, Builder]:
167+
r: dict[str, Builder] = {}
168+
_extend_registry_with_discovered_detectors(r)
169+
return r
170+
171+
172+
_REGISTRY = _registry()
173+
174+
175+
def list_local_arches() -> list[str]:
176+
"""List all available local detection architecture ids (e.g. `dldet:ssd_tiny`)."""
177+
178+
return [f"dldet:{name}" for name in sorted(_REGISTRY)]
179+
180+
181+
def build_local_model(
182+
arch_id: str,
183+
*,
184+
in_channels: int,
185+
num_classes: int,
186+
width_mult: float = 1.0,
187+
) -> nn.Module:
188+
"""Build a local detection model by architecture id.
189+
190+
Architecture ids are variants extracted from each detector module's `_VARIANTS`,
191+
namespaced with `dldet:` (e.g. `dldet:ssd_tiny`).
192+
"""
193+
194+
prefix, name = _split_arch_id(arch_id)
195+
if prefix == "det":
196+
prefix = "dldet"
197+
if prefix not in {"dldet", "local"}:
198+
raise ValueError(f"Unsupported detection prefix: {prefix!r} (arch_id={arch_id!r})")
199+
200+
builder = _REGISTRY.get(str(name).lower().strip())
201+
if builder is None:
202+
raise UnknownLocalArch(f"Unknown detection arch: {arch_id!r}. Tip: run `python scripts/detection_zoo.py --list`.")
203+
204+
return builder(
205+
BuildConfig(
206+
in_channels=int(in_channels),
207+
num_classes=int(num_classes),
208+
width_mult=float(width_mult),
209+
)
210+
)
211+
212+
213+
__all__ = [
214+
"BuildConfig",
215+
"UnknownLocalArch",
216+
"build_local_model",
217+
"list_local_arches",
218+
]
219+

scripts/detection_zoo.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import sys
5+
from collections.abc import Iterable
6+
from pathlib import Path
7+
8+
9+
def _ensure_repo_root_on_path() -> None:
10+
repo_root = Path(__file__).resolve().parents[1]
11+
sys.path.insert(0, str(repo_root))
12+
13+
14+
def _summarize(obj) -> str:
15+
try:
16+
import torch
17+
except Exception:
18+
return f"{type(obj).__name__}"
19+
20+
if isinstance(obj, torch.Tensor):
21+
return f"Tensor(shape={tuple(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
22+
if isinstance(obj, dict):
23+
keys = ", ".join(sorted(map(str, obj.keys())))
24+
return f"dict(keys=[{keys}])"
25+
if isinstance(obj, list | tuple):
26+
head = ", ".join(_summarize(x) for x in obj[:2])
27+
tail = "" if len(obj) <= 2 else f", ... (+{len(obj) - 2})"
28+
return f"{type(obj).__name__}([{head}{tail}])"
29+
return f"{type(obj).__name__}"
30+
31+
32+
def _print_lines(lines: Iterable[str], *, limit: int = 80) -> None:
33+
lines = list(lines)
34+
if len(lines) <= limit:
35+
for line in lines:
36+
print(line)
37+
return
38+
39+
head = max(10, limit - 10)
40+
for line in lines[:head]:
41+
print(line)
42+
print(f"... ({len(lines) - limit} more) ...")
43+
for line in lines[-10:]:
44+
print(line)
45+
46+
47+
def parse_args() -> argparse.Namespace:
48+
parser = argparse.ArgumentParser(description="Detection local model zoo utilities (no downloads).")
49+
50+
parser.add_argument("--list", action="store_true", help="List available architecture ids.")
51+
parser.add_argument("--search", type=str, default=None, help="Filter list by substring (case-insensitive).")
52+
parser.add_argument("--limit", type=int, default=80, help="Max lines to print when listing.")
53+
54+
parser.add_argument("--smoke", type=str, default=None, metavar="ARCH_ID", help="Run a forward smoke on an arch id.")
55+
parser.add_argument("--batch-size", type=int, default=2, help="Batch size for smoke inputs.")
56+
parser.add_argument("--image-size", type=int, default=64, help="Image size for smoke inputs.")
57+
parser.add_argument("--in-channels", type=int, default=3, help="Input channels for local detectors.")
58+
parser.add_argument("--num-classes", type=int, default=2, help="Detector classes.")
59+
parser.add_argument("--width-mult", type=float, default=1.0, help="Width multiplier for local detectors.")
60+
61+
return parser.parse_args()
62+
63+
64+
def main() -> int:
65+
_ensure_repo_root_on_path()
66+
67+
from dlhub.vision.detection_zoo import build_local_model, list_local_arches
68+
69+
args = parse_args()
70+
71+
if not args.list and args.smoke is None:
72+
print("Nothing to do. Try one of:")
73+
print("- python scripts/detection_zoo.py --list")
74+
print("- python scripts/detection_zoo.py --smoke dldet:ssd_tiny")
75+
return 2
76+
77+
arches = list_local_arches()
78+
if args.search:
79+
needle = str(args.search).lower()
80+
arches = [a for a in arches if needle in a.lower()]
81+
82+
if args.list:
83+
print("Detection local zoo")
84+
print(f"- total_arches={len(arches)}")
85+
print("")
86+
_print_lines(arches, limit=int(args.limit))
87+
88+
if args.smoke is not None:
89+
arch_id = str(args.smoke).strip()
90+
if ":" not in arch_id:
91+
arch_id = f"dldet:{arch_id}"
92+
93+
import torch
94+
95+
x = torch.randn(int(args.batch_size), int(args.in_channels), int(args.image_size), int(args.image_size))
96+
model = build_local_model(
97+
arch_id,
98+
in_channels=int(args.in_channels),
99+
num_classes=int(args.num_classes),
100+
width_mult=float(args.width_mult),
101+
)
102+
model.eval()
103+
with torch.no_grad():
104+
out = model(x)
105+
106+
print("")
107+
print(f"smoke: {arch_id}")
108+
print(f"- output={_summarize(out)}")
109+
110+
return 0
111+
112+
113+
if __name__ == "__main__":
114+
raise SystemExit(main())
115+
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
3+
4+
torch = pytest.importorskip("torch")
5+
6+
7+
def _sum_tensor_means(x):
8+
if torch.is_tensor(x):
9+
return x.to(torch.float32).mean()
10+
if isinstance(x, dict):
11+
return sum((_sum_tensor_means(v) for v in x.values()), start=torch.tensor(0.0))
12+
if isinstance(x, (list, tuple)):
13+
return sum((_sum_tensor_means(v) for v in x), start=torch.tensor(0.0))
14+
raise TypeError(f"Unsupported output type in detection zoo smoke: {type(x)!r}")
15+
16+
17+
def test_detection_zoo_list_and_build_smoke() -> None:
18+
from dlhub.vision.detection_zoo import build_local_model, list_local_arches
19+
20+
arches = list_local_arches()
21+
assert "dldet:ssd_tiny" in arches
22+
assert "dldet:detr_tiny" in arches
23+
24+
for arch_id in ["dldet:ssd_tiny", "dldet:detr_tiny", "dldet:yolo_v1_tiny"]:
25+
model = build_local_model(arch_id, in_channels=3, num_classes=2, width_mult=0.5)
26+
x = torch.randn(2, 3, 64, 64)
27+
out = model(x)
28+
loss = _sum_tensor_means(out)
29+
assert torch.isfinite(loss)
30+
loss.backward()
31+

0 commit comments

Comments
 (0)