utils.py
import os

import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder, ImageNet
from transformers import AutoTokenizer, CLIPTextModelWithProjection

from models.clip import Clip
from models.dino import Dino
from models.siglip import Siglip


def get_collate_fn(processor):
    """Build a DataLoader collate_fn that batches PIL images through a HF processor."""
    def collate_fn(batch):
        images = [img[0] for img in batch]
        return processor(images=images, return_tensors="pt", padding=True)
    return collate_fn


def get_dataset(args, preprocess, processor, split, subset=1.0):
    """Return (dataset, dataloader) for the requested dataset name and split."""
    if args.dataset_name == 'cc3m':
        # CC3M loading is disabled in this version:
        # if subset < 1.0:
        #     raise NotImplementedError
        # return get_cc3m(args, preprocess, split)
        raise NotImplementedError
    elif args.dataset_name in ('inat_birds', 'inat', 'cub'):
        ds = ImageFolder(root=os.path.join(args.data_path, split), transform=preprocess)
    elif args.dataset_name == 'imagenet':
        ds = ImageNet(root=args.data_path, split=split, transform=preprocess)
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

    # Optionally keep a regular subset of the data (e.g. subset=0.1 keeps every 10th sample).
    keep_every = int(1.0 / subset)
    if keep_every > 1:
        ds = Subset(ds, list(range(0, len(ds), keep_every)))

    if processor is not None:
        dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.num_workers, collate_fn=get_collate_fn(processor))
    else:
        dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    return ds, dl


def get_model(args):
    """Instantiate the vision backbone named by args.model_name and return (model, processor)."""
    if args.model_name.startswith('clip'):
        clip = Clip(args.model_name, args.device)
        return clip, clip.processor
    elif args.model_name.startswith('dino'):
        dino = Dino(args.model_name, args.device)
        return dino, dino.processor
    elif args.model_name.startswith('siglip'):
        siglip = Siglip(args.model_name, args.device)
        return siglip, siglip.processor
    raise ValueError(f"Unknown model: {args.model_name}")


def get_text_model(args):
    """Load the matching text encoder and tokenizer (currently CLIP only)."""
    if args.model_name.startswith('clip'):
        model = CLIPTextModelWithProjection.from_pretrained(f"openai/{args.model_name}").to(args.device)
        tokenizer = AutoTokenizer.from_pretrained(f"openai/{args.model_name}")
        return model, tokenizer
    raise NotImplementedError(f"No text model available for: {args.model_name}")


class IdentitySAE(nn.Module):
    """Pass-through stand-in for a sparse autoencoder: encode/decode return the input unchanged."""
    def encode(self, x):
        return x

    def decode(self, x):
        return x
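

# Example usage (a minimal sketch): this assumes an args namespace exposing
# model_name, dataset_name, data_path, batch_size, num_workers and device,
# matching the attribute accesses above. The concrete values below are
# placeholders, not settings taken from this repository.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        model_name="clip-vit-base-patch32",   # picks the CLIP branch in get_model / get_text_model
        dataset_name="imagenet",
        data_path="/path/to/imagenet",        # placeholder path
        batch_size=32,
        num_workers=4,
        device="cuda",
    )

    model, processor = get_model(args)
    # Pass the HF processor so batching goes through get_collate_fn; keep every 10th sample.
    ds, dl = get_dataset(args, preprocess=None, processor=processor, split="val", subset=0.1)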