-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patharchaeological_georgia_biostyle_dataloader.py
138 lines (126 loc) · 5.78 KB
/
archaeological_georgia_biostyle_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# @Time : 2023/4/22 19:22
# @Author : Li Jiaqi
# @Description :
from torch.utils.data import Dataset
import glob
import os
from skimage.io import imread
from torch.utils.data import DataLoader
from collections import OrderedDict
import numpy as np
COLOR_MAP = OrderedDict(
Background=(255, 255, 255),
Building=(255, 0, 0),
Road=(255, 255, 0),
Water=(0, 0, 255),
Barren=(159, 129, 183),
Forest=(0, 255, 0),
Agricultural=(255, 195, 128),
)
LABEL_MAP = OrderedDict(
Background=0,
Building=1,
Road=2,
Water=3,
Barren=4,
Forest=5,
Agricultural=6
)
class SitesBingBook(Dataset):
def __init__(self, data_dir, mask_dir, transforms=None, has_bing=True, has_book=False, has_mask=True):
self.data_dir = data_dir
self.mask_dir = mask_dir
self.has_bing = has_bing
self.has_book = has_book
self.has_mask = has_mask
self.id_list = []
self.unlabeled=False
png_list = glob.glob(os.path.join(data_dir, '*.png'))
for fp in png_list:
if 'mask' not in fp and len(os.path.split(fp)[-1])>8:
self.id_list.append(os.path.split(fp)[-1][:-8])
elif 'mask' not in fp:
self.id_list.append(os.path.split(fp)[-1][:-4])
self.unlabeled =True
self.transforms = transforms
def __getitem__(self, idx):
file_id = self.id_list[idx]
bing_file_name = file_id + 'bing.png' if not self.unlabeled else file_id + '.png'
book_file_name = file_id + 'book.jpg'
bing_mask_name = file_id + 'bing_mask.png'
book_mask_name = file_id + 'book_mask.png'
bing_image, bing_mask_image, book_image, book_mask_image = [], [], [], []
if self.has_bing:
bing_image = imread(os.path.join(self.data_dir, bing_file_name))
bing_image = bing_image[:-23, :, 0:3] # delete the alpha dimension in png files and bing flag
if self.has_mask:
bing_mask_image = imread(os.path.join(self.mask_dir, bing_mask_name))
bing_mask_image = bing_mask_image[:-23, :, 0:3] # delete the alpha dimension in png files and bing flag
if self.has_book:
book_image = imread(os.path.join(self.data_dir, book_file_name))
book_image = book_image[:-75, :] # delete the book flag
if self.has_mask:
book_mask_image = imread(os.path.join(self.mask_dir, book_mask_name))
book_mask_image = book_mask_image[:-75, :] # delete the book flag
# 2d black and white book images to 3d images
if self.has_book and len(book_image.shape) <= 2:
new_book_image = np.zeros(dtype=np.uint8, shape=(book_image.shape[0], book_image.shape[1], 3))
new_book_image[:, :, 0] = book_image * 255
new_book_image[:, :, 1] = book_image * 255
new_book_image[:, :, 2] = book_image * 255
book_image = new_book_image
elif self.has_book and book_image.dtype != np.uint8:
book_image = book_image * 255
# normalize and resize
if self.transforms is not None:
if self.has_mask:
if self.has_bing:
bing_mask_image = bing_mask_image[:, :, 0]
blob = self.transforms(image=bing_image, mask=bing_mask_image)
bing_image = blob['image']
bing_mask_image = blob['mask']
bing_mask_image = (bing_mask_image - np.min(bing_mask_image)) / (
np.max(bing_mask_image) - np.min(bing_mask_image))
if self.has_book:
book_mask_image = book_mask_image[:, :, 0]
blob = self.transforms(image=book_image, mask=book_mask_image)
book_image = blob['image']
book_mask_image = blob['mask']
book_mask_image = (book_mask_image - np.min(book_mask_image)) / (
np.max(book_mask_image) - np.min(book_mask_image))
else:
if self.has_bing:
blob = self.transforms(image=bing_image)
bing_image = blob['image']
if self.has_book:
blob = self.transforms(image=book_image)
book_image = blob['image']
# to C,W,H
if self.has_bing:
bing_image = np.rollaxis(bing_image, 2, 0)
if self.has_book:
book_image = np.rollaxis(book_image, 2, 0)
return bing_image, bing_mask_image, book_image, book_mask_image
def __len__(self):
return len(self.id_list)
class SitesLoader(DataLoader):
def __init__(self, config, dataset=None, flag="train"):
self.config = config
self.flag = flag
if dataset:
pass
elif flag == 'train':
dataset = SitesBingBook(self.config["dataset"], self.config["maskdir"], self.config["transforms"])
elif flag == 'unlabeled':
dataset = SitesBingBook(self.config["unlabeledset"], None, self.config["transforms"], has_mask=False)
elif flag=='pseudo':
dataset = SitesBingBook(self.config["pseudoset"], None, self.config["transforms"], has_mask=False)
else:
dataset = SitesBingBook(self.config["evalset"], self.config["maskdir"], self.config["transforms"])
super(SitesLoader, self).__init__(dataset,
batch_size=self.config['batch_size'],
num_workers=self.config['num_workers'],
shuffle=self.config['shuffle'],
pin_memory=self.config['pin_memory'],
drop_last=self.config['drop_last']
)