5
5
import h5py
6
6
import numpy as np
7
7
import scipy .io
8
+ from selene_sdk .samplers .samples_batch import SamplesBatch
8
9
9
10
from .file_sampler import FileSampler
10
11
@@ -126,8 +127,8 @@ def sample(self, batch_size=1):
126
127
127
128
Returns
128
129
-------
129
- sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
130
- A tuple containing the numeric representation of the
130
+ SamplesBatch
131
+ A batch containing the numeric representation of the
131
132
sequence examples and their corresponding labels. The
132
133
shape of `sequences` will be
133
134
:math:`B \\times L \\times N`, where :math:`B` is
@@ -166,8 +167,8 @@ def sample(self, batch_size=1):
166
167
targets = self ._sample_tgts [:, use_indices ].astype (float )
167
168
targets = np .transpose (
168
169
targets , (1 , 0 ))
169
- return (sequences , targets )
170
- return sequences ,
170
+ return SamplesBatch (sequences , target_batch = targets )
171
+ return SamplesBatch ( sequences )
171
172
172
173
def get_data (self , batch_size , n_samples = None ):
173
174
"""
@@ -190,18 +191,20 @@ def get_data(self, batch_size, n_samples=None):
190
191
is `batch_size`, :math:`L` is the sequence length,
191
192
and :math:`N` is the size of the sequence type's alphabet.
192
193
"""
194
+ # TODO: Should this method return a collection of samples_batch.inputs()?
195
+
193
196
if not n_samples :
194
197
n_samples = self .n_samples
195
198
sequences = []
196
199
197
200
count = batch_size
198
201
while count < n_samples :
199
- seqs , = self .sample (batch_size = batch_size )
200
- sequences .append (seqs )
202
+ samples_batch = self .sample (batch_size = batch_size )
203
+ sequences .append (samples_batch . sequence_batch () )
201
204
count += batch_size
202
205
remainder = batch_size - (count - n_samples )
203
- seqs , = self .sample (batch_size = remainder )
204
- sequences .append (seqs )
206
+ samples_batch = self .sample (batch_size = remainder )
207
+ sequences .append (samples_batch . sequence_batch () )
205
208
return sequences
206
209
207
210
def get_data_and_targets (self , batch_size , n_samples = None ):
@@ -218,11 +221,11 @@ def get_data_and_targets(self, batch_size, n_samples=None):
218
221
219
222
Returns
220
223
-------
221
- sequences_and_targets , targets_matrix : \
222
- tuple(list(tuple(numpy.ndarray, numpy.ndarray) ), numpy.ndarray)
223
- Tuple containing the list of sequence-target pairs , as well
224
+ batches , targets_matrix : \
225
+ tuple(list(SamplesBatch ), numpy.ndarray)
226
+ Tuple containing the list of batches , as well
224
227
as a single matrix with all targets in the same order.
225
- Note that `sequences_and_targets `'s sequence elements are of
228
+ Note that `batches`'s sequence elements are of
226
229
the shape :math:`B \\times L \\times N` and its target
227
230
elements are of the shape :math:`B \\times F`, where
228
231
:math:`B` is `batch_size`, :math:`L` is the sequence length,
@@ -237,19 +240,19 @@ def get_data_and_targets(self, batch_size, n_samples=None):
237
240
"initialization. Please use `get_data` instead." )
238
241
if not n_samples :
239
242
n_samples = self .n_samples
240
- sequences_and_targets = []
243
+ batches = []
241
244
targets_mat = []
242
245
243
246
count = batch_size
244
247
while count < n_samples :
245
- seqs , tgts = self .sample (batch_size = batch_size )
246
- sequences_and_targets .append (( seqs , tgts ) )
247
- targets_mat .append (tgts )
248
+ samples_batch = self .sample (batch_size = batch_size )
249
+ batches .append (samples_batch )
250
+ targets_mat .append (samples_batch . targets () )
248
251
count += batch_size
249
252
remainder = batch_size - (count - n_samples )
250
- seqs , tgts = self .sample (batch_size = remainder )
251
- sequences_and_targets .append (( seqs , tgts ) )
252
- targets_mat .append (tgts )
253
+ samples_batch = self .sample (batch_size = remainder )
254
+ batches .append (samples_batch )
255
+ targets_mat .append (samples_batch . targets () )
253
256
# TODO: should not assume targets are always integers
254
257
targets_mat = np .vstack (targets_mat ).astype (float )
255
- return sequences_and_targets , targets_mat
258
+ return batches , targets_mat
0 commit comments