forked from msminhas93/DeepLabv3FineTuning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsegdataset.py
124 lines (112 loc) · 5.56 KB
/
segdataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Author: Manpreet Singh Minhas
Contact: msminhas at uwaterloo ca
"""
from pathlib import Path
from typing import Any, Callable, Optional
import numpy as np
from PIL import Image
from torchvision.datasets.vision import VisionDataset
import torch
class SegmentationDataset(VisionDataset):
    """A PyTorch dataset for image segmentation tasks.

    Images and masks are paired by their position in the sorted listing of
    ``image_folder`` and ``mask_folder``, so corresponding files must sort in
    the same order (e.g. share the same file names).

    ``transforms`` (if given) is applied to the image only. The mask is always
    returned as a ``torch.long`` tensor of its pixel values so it can be used
    directly as per-pixel class-index targets.
    """

    # Maps the user-facing colour mode to the PIL mode passed to Image.convert().
    _PIL_MODES = {"rgb": "RGB", "grayscale": "L"}

    def __init__(self,
                 root: str,
                 image_folder: str,
                 mask_folder: str,
                 transforms: Optional[Callable] = None,
                 seed: Optional[int] = None,
                 fraction: Optional[float] = None,
                 subset: Optional[str] = None,
                 image_color_mode: str = "rgb",
                 mask_color_mode: str = "grayscale") -> None:
        """
        Args:
            root (str): Root directory path.
            image_folder (str): Name of the folder inside ``root`` that contains the images.
            mask_folder (str): Name of the folder inside ``root`` that contains the masks.
            transforms (Optional[Callable], optional): A function/transform applied to the
                image only, e.g. ``transforms.ToTensor()``. Defaults to None.
            seed (int, optional): Seed for the train/test shuffle so the split is
                reproducible. Any integer, including 0, is honoured. The global NumPy
                RNG is not modified. Defaults to None.
            fraction (float, optional): Fraction (0 to 1) of the files that go to the
                'Test' subset. When falsy (None or 0) no split is made and every file
                is used. Defaults to None.
            subset (str, optional): 'Train' or 'Test'; required when ``fraction`` is set.
                Defaults to None.
            image_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'rgb'.
            mask_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'grayscale'.

        Note:
            Build the 'Train' and 'Test' datasets with the same ``seed``. Without a seed
            each instance draws its own shuffle, so the two subsets can overlap.

        Raises:
            OSError: If the image folder doesn't exist in root.
            OSError: If the mask folder doesn't exist in root.
            ValueError: If ``fraction`` is set and subset is neither 'Train' nor 'Test'.
            ValueError: If image_color_mode or mask_color_mode is not 'rgb' or 'grayscale'.
        """
        super().__init__(root, transforms)
        image_folder_path = Path(self.root) / image_folder
        mask_folder_path = Path(self.root) / mask_folder
        if not image_folder_path.exists():
            raise OSError(f"{image_folder_path} does not exist.")
        if not mask_folder_path.exists():
            raise OSError(f"{mask_folder_path} does not exist.")
        if image_color_mode not in self._PIL_MODES:
            raise ValueError(
                f"{image_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        if mask_color_mode not in self._PIL_MODES:
            raise ValueError(
                f"{mask_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        self.image_color_mode = image_color_mode
        self.mask_color_mode = mask_color_mode

        if not fraction:
            # No split requested: use every entry of both folders.
            self.image_names = sorted(image_folder_path.glob("*"))
            self.mask_names = sorted(mask_folder_path.glob("*"))
        else:
            if subset not in ["Train", "Test"]:
                raise ValueError(
                    f"{subset} is not a valid input. Acceptable values are Train and Test."
                )
            self.fraction = fraction
            self.image_list = np.array(sorted(image_folder_path.glob("*")))
            self.mask_list = np.array(sorted(mask_folder_path.glob("*")))

            indices = np.arange(len(self.image_list))
            if seed is not None:
                # A private RandomState gives exactly the permutation that
                # np.random.seed(seed) + np.random.shuffle would, without
                # clobbering the global RNG state other code may rely on.
                np.random.RandomState(seed).shuffle(indices)
            else:
                np.random.shuffle(indices)
            # The same permutation keeps image/mask pairs aligned.
            # NOTE(review): assumes both folders hold the same number of files.
            self.image_list = self.image_list[indices]
            self.mask_list = self.mask_list[indices]

            # First ceil(n * (1 - fraction)) shuffled items form 'Train', the rest 'Test'.
            split = int(np.ceil(len(self.image_list) * (1 - self.fraction)))
            if subset == "Train":
                self.image_names = self.image_list[:split]
                self.mask_names = self.mask_list[:split]
            else:
                self.image_names = self.image_list[split:]
                self.mask_names = self.mask_list[split:]

    def __len__(self) -> int:
        """Return the number of samples (the number of selected images)."""
        return len(self.image_names)

    def __getitem__(self, index: int) -> Any:
        """Load the image/mask pair at ``index``.

        Returns:
            dict: ``{"image": ..., "mask": ...}``. ``"image"`` is the output of
            ``transforms`` if one was given, otherwise a PIL image. ``"mask"`` is
            always a ``torch.long`` tensor of the mask's pixel values.
        """
        image_path = self.image_names[index]
        mask_path = self.mask_names[index]
        with open(image_path, "rb") as image_file, open(mask_path, "rb") as mask_file:
            # convert() forces the pixel data to be decoded while the files are open.
            image = Image.open(image_file).convert(self._PIL_MODES[self.image_color_mode])
            mask = Image.open(mask_file).convert(self._PIL_MODES[self.mask_color_mode])

        if self.transforms is not None:
            image = self.transforms(image)
        # Masks are not passed through transforms: they are used as integer class
        # indices, so convert them to an int64 (torch.long) tensor unconditionally.
        mask_tensor = torch.from_numpy(np.array(mask, dtype=np.int64))
        return {"image": image, "mask": mask_tensor}