Skip to content

Commit f060414

Browse files
committed
Simplify random seed in epoch data for reproducibility
1 parent eab4770 commit f060414

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

src/nanotron/data/nanoset.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,13 @@ def build_nanoset_index(self) -> np.ndarray:
113113
)
114114

115115
# Shuffle indices in each epoch with different random seeds and concatenate them
116-
r = np.random.RandomState(self.random_seed)
117-
epoch_random_seeds = r.randint(0, 2**32 - 1, num_epochs)
118116
dataset_indices = []
119117
dataset_sample_indices = []
120-
for i in range(num_epochs):
118+
for num_epoch in range(num_epochs):
121119
# Shuffle the sample and dataset indices in epoch with a given seed
122-
numpy_random_state = np.random.RandomState(epoch_random_seeds[i])
120+
numpy_random_state = np.random.RandomState(self.random_seed + num_epoch)
123121
numpy_random_state.shuffle(dataset_index)
124-
numpy_random_state = np.random.RandomState(epoch_random_seeds[i])
122+
numpy_random_state = np.random.RandomState(self.random_seed + num_epoch)
125123
numpy_random_state.shuffle(dataset_sample_index)
126124

127125
dataset_indices.append(dataset_index)

0 commit comments

Comments
 (0)