Skip to content

Commit fac6332

Browse files
committed
changed dataset
1 parent b9144d4 commit fac6332

File tree

2 files changed

+40
-40
lines changed

2 files changed

+40
-40
lines changed

configs/vae.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exp_params:
2121

2222
trainer_params:
2323
gpus: [1]
24-
max_epochs: 10
24+
max_epochs: 100
2525

2626
logging_params:
2727
save_dir: "logs/"

dataset.py

+39-39
Original file line numberDiff line numberDiff line change
@@ -100,56 +100,56 @@ def __init__(
100100
def setup(self, stage: Optional[str] = None) -> None:
101101
# ========================= OxfordPets Dataset =========================
102102

103-
train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
104-
transforms.CenterCrop(self.patch_size),
105-
# transforms.Resize(self.patch_size),
106-
transforms.ToTensor(),
107-
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
108-
109-
val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
110-
transforms.CenterCrop(self.patch_size),
111-
# transforms.Resize(self.patch_size),
112-
transforms.ToTensor(),
113-
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
114-
115-
self.train_dataset = OxfordPets(
116-
self.data_dir,
117-
split='train',
118-
transform=train_transforms,
119-
)
120-
121-
self.val_dataset = OxfordPets(
122-
self.data_dir,
123-
split='val',
124-
transform=val_transforms,
125-
)
126-
127-
# ========================= CelebA Dataset =========================
128-
129103
# train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
130-
# transforms.CenterCrop(148),
131-
# transforms.Resize(self.patch_size),
132-
# transforms.ToTensor(),])
104+
# transforms.CenterCrop(self.patch_size),
105+
# # transforms.Resize(self.patch_size),
106+
# transforms.ToTensor(),
107+
# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
133108

134109
# val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
135-
# transforms.CenterCrop(148),
136-
# transforms.Resize(self.patch_size),
137-
# transforms.ToTensor(),])
138-
139-
# self.train_dataset = MyCelebA(
110+
# transforms.CenterCrop(self.patch_size),
111+
# # transforms.Resize(self.patch_size),
112+
# transforms.ToTensor(),
113+
# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
114+
115+
# self.train_dataset = OxfordPets(
140116
# self.data_dir,
141117
# split='train',
142118
# transform=train_transforms,
143-
# download=False,
144119
# )
145120

146-
# # Replace CelebA with your dataset
147-
# self.val_dataset = MyCelebA(
121+
# self.val_dataset = OxfordPets(
148122
# self.data_dir,
149-
# split='test',
123+
# split='val',
150124
# transform=val_transforms,
151-
# download=False,
152125
# )
126+
127+
# ========================= CelebA Dataset =========================
128+
129+
train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
130+
transforms.CenterCrop(148),
131+
transforms.Resize(self.patch_size),
132+
transforms.ToTensor(),])
133+
134+
val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
135+
transforms.CenterCrop(148),
136+
transforms.Resize(self.patch_size),
137+
transforms.ToTensor(),])
138+
139+
self.train_dataset = MyCelebA(
140+
self.data_dir,
141+
split='train',
142+
transform=train_transforms,
143+
download=False,
144+
)
145+
146+
# Replace CelebA with your dataset
147+
self.val_dataset = MyCelebA(
148+
self.data_dir,
149+
split='test',
150+
transform=val_transforms,
151+
download=False,
152+
)
153153
# ===============================================================
154154

155155
def train_dataloader(self) -> DataLoader:

0 commit comments

Comments
 (0)