-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmnist.py
87 lines (68 loc) · 2.42 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gzip
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.utils.data
# This is a version of: https://github.com/gpapamak/maf/blob/master/datasets/mnist.py,
# adapted to work with Python 3.x and PyTorch.
# (Kept as a comment: a string placed after the imports is not a module
# docstring, just a discarded expression statement.)

# Mini-batch size. Not referenced by MNISTDataset itself; presumably consumed
# by training code elsewhere -- verify before changing.
batch_size = 100
class MNISTDataset:
    """
    MNIST loaded from a gzipped pickle and split into train/val/test.

    Each split is wrapped in ``MNISTDataset.Data`` and can optionally be
    dequantized (uniform noise added to pixels) and logit-transformed.
    """

    # Pixels are squashed into [alpha, 1 - alpha] before the logit so that
    # pixel values of exactly 0 do not map to -inf.
    alpha = 1e-6

    class Data:
        """
        Constructs the dataset.

        Holds one split's pixel matrix, optionally dequantized and/or
        logit-transformed.
        """

        def __init__(self, data, logit, dequantize, rng):
            """
            Args:
                data: sequence whose first element is the pixel array, with
                    one datapoint per row; further elements (presumably the
                    labels -- verify) are ignored.
                logit: if true, apply ``_logit_transform`` (after dequantizing).
                dequantize: if true, add uniform noise via ``_dequantize``.
                rng: random generator with a ``rand`` method; only used when
                    ``dequantize`` is true.
            """
            pixels = data[0]
            if dequantize:
                pixels = self._dequantize(pixels, rng)
            self.x = self._logit_transform(pixels) if logit else pixels
            self.N = self.x.shape[0]  # number of datapoints

        @staticmethod
        def _dequantize(x, rng):
            """
            Adds noise drawn uniformly from [0, 1/256) to every pixel, turning
            discrete pixel levels into a continuous distribution.
            """
            return x + rng.rand(*x.shape) / 256.0

        @staticmethod
        def _logit_transform(x):
            """
            Transforms pixel values with logit to be unconstrained.

            Assumes x lies in [0, 1] -- TODO confirm that dequantized pixels
            in the dataset file stay below 1, otherwise the logit yields NaN.
            """
            x = MNISTDataset.alpha + (1 - 2 * MNISTDataset.alpha) * x
            return np.log(x / (1.0 - x))

    def __init__(self, logit=True, dequantize=True, *, root="datasets/maf_data/", seed=42):
        """
        Load the train/val/test splits from ``<root>/mnist/mnist.pkl.gz``.

        Args:
            logit: apply the logit transform to every split.
            dequantize: add uniform dequantization noise to every split.
            root: directory that contains ``mnist/mnist.pkl.gz``. Defaults
                to the relative path this class always used.
            seed: seed of the ``np.random.RandomState`` used for the
                dequantization noise. One generator is shared, consumed by
                train, then val, then test.
        """
        path = os.path.join(root, "mnist", "mnist.pkl.gz")
        # SECURITY: pickle.load can execute arbitrary code stored in the file;
        # only point ``root`` at trusted data.
        # The ``with`` block closes the file even if unpickling fails.
        with gzip.open(path, "rb") as f:
            # latin1 decoding -- presumably because the pickle was written by
            # Python 2 (see the note at the top of the file).
            train, val, test = pickle.load(f, encoding="latin1")

        rng = np.random.RandomState(seed)
        # Order matters: the splits draw noise from the shared rng in turn.
        self.train = self.Data(train, logit, dequantize, rng)
        self.val = self.Data(val, logit, dequantize, rng)
        self.test = self.Data(test, logit, dequantize, rng)

        self.n_dims = self.train.x.shape[1]
        # Assumes square images: side length = sqrt(number of pixels).
        self.image_size = [int(np.sqrt(self.n_dims))] * 2

    def show_pixel_histograms(self, split, pixel=None):
        """
        Shows the histogram of pixel values, or of a specific pixel if given.

        Args:
            split: name of a split attribute ("train", "val" or "test").
            pixel: optional (row, col) coordinate. Pixels are addressed
                row-major within the flattened image.

        Raises:
            ValueError: if ``split`` does not name a data split, or if
                ``pixel`` lies outside the image.
        """
        data_split = getattr(self, split, None)
        # Reject any attribute that is not a split (e.g. "n_dims"), not just
        # missing ones.
        if not isinstance(data_split, self.Data):
            raise ValueError("Invalid data split")

        if pixel is None:
            data = data_split.x.flatten()
        else:
            row, col = pixel
            height, width = self.image_size
            # Without this check, out-of-range coordinates would silently
            # wrap to a different pixel.
            if not (0 <= row < height and 0 <= col < width):
                raise ValueError("Invalid pixel")
            data = data_split.x[:, row * width + col]

        n_bins = max(1, int(np.sqrt(data_split.N)))  # hist() needs >= 1 bin
        _, ax = plt.subplots(1, 1)
        ax.hist(data, n_bins, density=True, color="lightblue")
        ax.set_yticks([])  # removes y ticks and their labels
        plt.show()