|
| 1 | +import numpy as np |
| 2 | +import keras |
| 3 | + |
| 4 | + |
| 5 | +class DataGenerator(keras.utils.Sequence): |
| 6 | + 'Generates data for Keras' |
| 7 | + def __init__(self, list_IDs, labels, batch_size=32, dim=(32,32,32), n_channels=1, |
| 8 | + n_classes=10, shuffle=True): |
| 9 | + 'Initialization' |
| 10 | + self.dim = dim |
| 11 | + self.batch_size = batch_size |
| 12 | + self.labels = labels |
| 13 | + self.list_IDs = list_IDs |
| 14 | + self.n_channels = n_channels |
| 15 | + self.n_classes = n_classes |
| 16 | + self.shuffle = shuffle |
| 17 | + self.on_epoch_end() |
| 18 | + |
| 19 | + def __len__(self): |
| 20 | + 'Denotes the number of batches per epoch' |
| 21 | + return int(np.floor(len(self.list_IDs) / self.batch_size)) |
| 22 | + |
| 23 | + def __getitem__(self, index): |
| 24 | + 'Generate one batch of data' |
| 25 | + # Generate indexes of the batch |
| 26 | + indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size] |
| 27 | + |
| 28 | + # Find list of IDs |
| 29 | + list_IDs_temp = [self.list_IDs[k] for k in indexes] |
| 30 | + |
| 31 | + # Generate data |
| 32 | + X, y = self.__data_generation(list_IDs_temp) |
| 33 | + |
| 34 | + return X, y |
| 35 | + |
| 36 | + def on_epoch_end(self): |
| 37 | + 'Updates indexes after each epoch' |
| 38 | + self.indexes = np.arange(len(self.list_IDs)) |
| 39 | + if self.shuffle == True: |
| 40 | + np.random.shuffle(self.indexes) |
| 41 | + |
| 42 | + def __data_generation(self, list_IDs_temp): |
| 43 | + 'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels) |
| 44 | + # Initialization |
| 45 | + X = np.empty((self.batch_size, *self.dim, self.n_channels)) |
| 46 | + y = np.empty((self.batch_size), dtype=int) |
| 47 | + |
| 48 | + # Generate data |
| 49 | + for i, ID in enumerate(list_IDs_temp): |
| 50 | + # Store sample |
| 51 | + X[i,] = np.load('data/' + ID + '.npy') |
| 52 | + |
| 53 | + # Store class |
| 54 | + y[i] = self.labels[ID] |
| 55 | + |
| 56 | + return X, keras.utils.to_categorical(y, num_classes=self.n_classes) |
0 commit comments