Commit

removed vissl caching and added filename everywhere in torch.hub.load_state_dict_from_url
jonasd4 committed Sep 10, 2024
1 parent 200e047 commit 28faaa1
Showing 1 changed file with 19 additions and 20 deletions.
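
For context: torch.hub.load_state_dict_from_url caches downloads under $TORCH_HOME/hub/checkpoints/, and its optional file_name argument fixes the name of the cached file instead of deriving it from the URL. Passing a unique file_name everywhere is what replaces the custom vissl cache directory removed in this commit. Below is a minimal sketch of that caching behavior; the URL and model name are illustrative, not taken from the repository.

import torch

# Hypothetical checkpoint URL; the real URLs come from SSLExtractor.MODELS[...]["url"].
url = "https://example.com/checkpoints/simclr_rn50.torch"

# Downloaded once, then reused from $TORCH_HOME/hub/checkpoints/thingsvision_ssl_v0_simclr-rn50
state_dict = torch.hub.load_state_dict_from_url(
    url,
    map_location=torch.device("cpu"),
    file_name="thingsvision_ssl_v0_simclr-rn50",
)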
thingsvision/core/extraction/extractors.py (19 additions, 20 deletions)
@@ -350,15 +350,14 @@ def __init__(
device=device,
)

def _download_and_save_model(self, model_url: str,
output_model_filepath: str, unique_model_id: str):
def _load_vissl_state_dict(self, model_url: str, unique_model_filename: str):
"""
Downloads the model in vissl format, converts it to torchvision format and
saves it under output_model_filepath.
"""
model = load_state_dict_from_url(model_url,
map_location=torch.device("cpu"),
file_name=f'{unique_model_id}.pt')
file_name=unique_model_filename)

# get the model trunk to rename
if "classy_state_dict" in model.keys():
@@ -369,11 +368,10 @@ def _download_and_save_model(self, model_url: str,
model_trunk = model

converted_model = self._replace_module_prefix(model_trunk, "_feature_blocks.")
torch.save(converted_model, output_model_filepath)
return converted_model

def _replace_module_prefix(
self, state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
self, state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
):
"""
Remove prefixes in a state_dict needed when loading models that are not VISSL
@@ -394,25 +392,23 @@ def load_model_from_source(self) -> None:
Otherwise, loads it from the cache directory.
"""
if self.model_name in SSLExtractor.MODELS:

# unique model id name for all models
unique_model_filename = f'thingsvision_ssl_v0_{self.model_name}'

# defines how the model should be loaded
model_config = SSLExtractor.MODELS[self.model_name]

if model_config["type"] == "vissl":
cache_dir = os.path.join(get_torch_home(), "vissl")
model_filepath = os.path.join(cache_dir, self.model_name + ".torch")
if not os.path.exists(model_filepath):
os.makedirs(cache_dir, exist_ok=True)
model_state_dict = self._download_and_save_model(
model_url=model_config["url"],
output_model_filepath=model_filepath,
unique_model_id=f'thingsvision_vissl_{self.model_name}'
)
else:
model_state_dict = torch.load(
model_filepath, map_location=torch.device("cpu")
)
model_state_dict = self._load_vissl_state_dict(
model_url=model_config["url"],
unique_model_filename=unique_model_filename
)
self.model = getattr(torchvision.models, model_config["arch"])()
if model_config["arch"] == "resnet50":
self.model.fc = torch.nn.Identity()
self.model.load_state_dict(model_state_dict, strict=True)

elif model_config["type"] == "hub":
if self.model_name.startswith("dino-vit"):
if self.model_name == "dino-vit-tiny-p8":
@@ -430,7 +426,9 @@ def load_model_from_source(self) -> None:
else:
raise ValueError(f"\n{self.model_name} is not available.\n")
state_dict = torch.hub.load_state_dict_from_url(
model_config["checkpoint_url"]
model_config["checkpoint_url"],
# This is used to cache the file
file_name=unique_model_filename
)
model.load_state_dict(state_dict, strict=True)
self.model = model
@@ -444,7 +442,8 @@ def load_model_from_source(self) -> None:
else:
raise ValueError(f"\n{self.model_name} is not available.\n")
state_dict = torch.hub.load_state_dict_from_url(
model_config["checkpoint_url"]
model_config["checkpoint_url"],
file_name=unique_model_filename
)
checkpoint_model = state_dict["model"]
# interpolate position embedding
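
For readers unfamiliar with the VISSL checkpoint layout: the trunk weights are stored under keys prefixed with "_feature_blocks.", and _replace_module_prefix strips that prefix so the state dict matches the torchvision module names before load_state_dict is called. A minimal sketch of such a prefix replacement, not the repository's exact implementation:

from typing import Any, Dict

def replace_module_prefix(state_dict: Dict[str, Any], prefix: str, replace_with: str = "") -> Dict[str, Any]:
    # Rename every key that starts with `prefix`, e.g.
    # "_feature_blocks.conv1.weight" -> "conv1.weight"
    return {
        (replace_with + key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }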
