From 7d6464147f3830d3ea4d8011039aaa2fc24278f0 Mon Sep 17 00:00:00 2001 From: mdasadul Date: Thu, 22 Sep 2016 11:49:22 -0700 Subject: [PATCH 1/2] updating preprocesor-shards.py to support additional input features --- preprocess-shards.py | 308 +++++++++++++++++++++++++++++++------------ 1 file changed, 221 insertions(+), 87 deletions(-) diff --git a/preprocess-shards.py b/preprocess-shards.py index 4859799..bd90346 100644 --- a/preprocess-shards.py +++ b/preprocess-shards.py @@ -25,7 +25,7 @@ def add_w(self, ws): for w in ws: if w not in self.d: self.d[w] = len(self.d) + 1 - + def convert(self, w): return self.d[w] if w in self.d else self.d[self.UNK] @@ -38,13 +38,16 @@ def clean(self, s): s = s.replace(self.BOS, "") s = s.replace(self.EOS, "") return s - - def write(self, outfile): + + def write(self, outfile, chars=1): out = open(outfile, "w") items = [(v, k) for k, v in self.d.iteritems()] items.sort() for v, k in items: - print >>out, k.encode('utf-8'), v + if chars == 1: + print >>out, k.encode('utf-8'), v + else: + print >>out, k, v out.close() def prune_vocab(self, k): @@ -55,62 +58,145 @@ def prune_vocab(self, k): for word in self.pruned_vocab: if word not in self.d: self.d[word] = len(self.d) + 1 - + + ###May be this is not required + def load_vocab(self, vocab_file, chars=0): + self.d = {} + for line in open(vocab_file, 'r'): + if chars == 1: + v, k = line.decode("utf-8").strip().split() + else: + v, k = line.strip().split() + self.d[v] = int(k) + def pad(ls, length, symbol): if len(ls) >= length: return ls[:length] return ls + [symbol] * (length -len(ls)) - + +def save_features(name, indexers, outputfile): + if len(indexers) > 0: + print("Number of additional features on {} side: {}".format(name, len(indexers))) + for i in range(len(indexers)): + indexers[i].write(outputfile + "." + name + "_feature_" + str(i+1) + ".dict", ) + print(" * {} feature {} of size: {}".format(name, i+1, len(indexers[i].d))) + +def load_features(name, indexers, outputfile): + for i in range(len(indexers)): + indexers[i].load_vocab(outputfile + "." + name + "_feature_" + str(i+1) + ".dict", ) + print(" * {} feature {} of size: {}".format(name, i+1, len(indexers[i].d))) + def get_data(args): -# src_indexer = Indexer(["`","|","@","^"]) -# target_indexer = Indexer(["`","|","@","^"]) src_indexer = Indexer(["","","",""]) - target_indexer = Indexer(["","","",""]) + src_feature_indexers = [] + target_indexer = Indexer(["","","",""]) char_indexer = Indexer(["","","{","}"]) char_indexer.add_w([src_indexer.PAD, src_indexer.UNK, src_indexer.BOS, src_indexer.EOS]) - - def make_vocab(srcfile, targetfile, seqlength, max_word_l=0, chars=0): + + def init_feature_indexers(indexers, count): + for i in range(count): + indexers.append(Indexer(["","","",""])) + + def load_sentence(sent, indexers): + sent_seq = sent.strip().split() + sent_words = '' + sent_features = [] + + for entry in sent_seq: + fields = entry.split('-|-') + word = fields[0] + sent_words += (' ' if sent_words else '') + word + + if len(fields) > 1: + count = len(fields) - 1 + if len(sent_features) == 0: + sent_features = [ [] for i in range(count) ] + if len(indexers) == 0: + init_feature_indexers(indexers, count) + for i in range(1, len(fields)): + sent_features[i-1].append(fields[i]) + + return sent_words, sent_features + + def add_features_vocab(orig_features, indexers): + if len(indexers) > 0: + index = 0 + for value in orig_features: + indexers[index].add_w(value) + index += 1 + + def make_vocab(srcfile, targetfile, seqlength, max_word_l=0, chars=0, train=1): num_sents = 0 for _, (src_orig, targ_orig) in \ enumerate(itertools.izip(open(srcfile,'r'), open(targetfile,'r'))): - src_orig = src_indexer.clean(src_orig.decode("utf-8").strip()) - targ_orig = target_indexer.clean(targ_orig.decode("utf-8").strip()) + src_orig, src_orig_features = load_sentence(src_orig, src_feature_indexers) + if chars == 1: + src_orig = src_indexer.clean(src_orig.decode("utf-8").strip()) + targ_orig = target_indexer.clean(targ_orig.decode("utf-8").strip()) + else: + src_orig = src_indexer.clean(src_orig.strip()) + targ_orig = target_indexer.clean(targ_orig.strip()) targ = targ_orig.strip().split() src = src_orig.strip().split() if len(targ) > seqlength or len(src) > seqlength or len(targ) < 1 or len(src) < 1: continue num_sents += 1 - for word in targ: - if chars == 1: - word = char_indexer.clean(word) - if len(word) == 0: - continue - max_word_l = max(len(word)+2, max_word_l) - for char in list(word): - char_indexer.vocab[char] += 1 - target_indexer.vocab[word] += 1 - - for word in src: - if chars == 1: - word = char_indexer.clean(word) - if len(word) == 0: - continue - max_word_l = max(len(word)+2, max_word_l) - for char in list(word): - char_indexer.vocab[char] += 1 - src_indexer.vocab[word] += 1 - + if train == 1: + for word in targ: + if chars == 1: + word = char_indexer.clean(word) + if len(word) == 0: + continue + max_word_l = max(len(word)+2, max_word_l) + for char in list(word): + char_indexer.vocab[char] += 1 + target_indexer.vocab[word] += 1 + + add_features_vocab(src_orig_features, src_feature_indexers) + + for word in src: + if chars == 1: + word = char_indexer.clean(word) + if len(word) == 0: + continue + max_word_l = max(len(word)+2, max_word_l) + for char in list(word): + char_indexer.vocab[char] += 1 + src_indexer.vocab[word] += 1 return max_word_l, num_sents - + + ###Change here def convert(srcfile, targetfile, batchsize, seqlength, outfile, total_num_sents, - num_sents, max_word_l=35, max_sent_l=0,chars=0): + num_sents, max_word_l = 35, max_sent_l=0,chars=1): + + def init_features_tensor(indexers): + return [ np.zeros((num_sents, newseqlength), dtype=int) + for i in range(len(indexers)) ] + + def load_features(orig_features, indexers, seqlength): + if len(orig_features) == 0: + return None + features = [] + for i in range(len(orig_features)): + features.append([[indexers[i].BOS]] + + orig_features[i] + + [[indexers[i].EOS]]) + + for i in range(len(features)): + features[i] = pad(features[i], seqlength, [indexers[i].PAD]) + for j in range(len(features[i])): + features[i][j] = indexers[i].convert_sequence(features[i][j])[0] + features[i] = np.array(features[i], dtype=int) + return features if num_sents > total_num_sents: - num_sents = total_num_sents + num_sents = total_num_sents + newseqlength = seqlength + 2 #add 2 for EOS and BOS targets = np.zeros((num_sents, newseqlength), dtype=int) target_output = np.zeros((num_sents, newseqlength), dtype=int) sources = np.zeros((num_sents, newseqlength), dtype=int) + sources_features = init_features_tensor(src_feature_indexers) source_lengths = np.zeros((num_sents,), dtype=int) target_lengths = np.zeros((num_sents,), dtype=int) if chars==1: @@ -119,39 +205,44 @@ def convert(srcfile, targetfile, batchsize, seqlength, outfile, total_num_sents, dropped = 0 sent_id = 0 total_sent_id = 0 - shards = 1 + shards = 1 for _, (src_orig, targ_orig) in \ enumerate(itertools.izip(open(srcfile,'r'), open(targetfile,'r'))): - src_orig = src_indexer.clean(src_orig.decode("utf-8").strip()) - targ_orig = target_indexer.clean(targ_orig.decode("utf-8").strip()) + src_orig, src_orig_features = load_sentence(src_orig, src_feature_indexers) + if chars == 1: + src_orig = src_indexer.clean(src_orig.decode("utf-8").strip()) + targ_orig = target_indexer.clean(targ_orig.decode("utf-8").strip()) + else: + src_orig = src_indexer.clean(src_orig.strip()) + targ_orig = target_indexer.clean(targ_orig.strip()) targ = [target_indexer.BOS] + targ_orig.strip().split() + [target_indexer.EOS] src = [src_indexer.BOS] + src_orig.strip().split() + [src_indexer.EOS] max_sent_l = max(len(targ), len(src), max_sent_l) if len(targ) > newseqlength or len(src) > newseqlength or len(targ) < 3 or len(src) < 3: dropped += 1 continue - targ = pad(targ, newseqlength+1, target_indexer.PAD) targ_char = [] for word in targ: if chars == 1: word = char_indexer.clean(word) + #use UNK for target, but not for source #word = word if word in target_indexer.d else target_indexer.UNK if chars == 1: char = [char_indexer.BOS] + list(word) + [char_indexer.EOS] if len(char) > max_word_l: char = char[:max_word_l] char[-1] = char_indexer.EOS - char_idx = char_indexer.convert_sequence(pad(char, max_word_l, - char_indexer.PAD)) - targ_char.append(char_idx) + char_idx = char_indexer.convert_sequence(pad(char, max_word_l, char_indexer.PAD)) + targ_char.append(char_idx) targ = target_indexer.convert_sequence(targ) + #targ = np.array(targ, dtype=int) targets[sent_id] = np.array(targ[:-1],dtype=int) target_lengths[sent_id] = (targets[sent_id] != 1).sum() if chars == 1: targets_char[sent_id] = np.array(targ_char[:-1], dtype=int) - target_output[sent_id] = np.array(targ[1:],dtype=int) - + target_output[sent_id] = np.array(targ[1:],dtype=int) + src = pad(src, newseqlength, src_indexer.PAD) src_char = [] for word in src: @@ -161,45 +252,52 @@ def convert(srcfile, targetfile, batchsize, seqlength, outfile, total_num_sents, if len(char) > max_word_l: char = char[:max_word_l] char[-1] = char_indexer.EOS - char_idx = char_indexer.convert_sequence(pad(char, max_word_l, - char_indexer.PAD)) + char_idx = char_indexer.convert_sequence(pad(char, max_word_l, char_indexer.PAD)) src_char.append(char_idx) src = src_indexer.convert_sequence(src) sources[sent_id] = np.array(src, dtype=int) - source_lengths[sent_id] = (sources[sent_id] != 1).sum() + source_lengths[sent_id] = (sources[sent_id] != 1).sum() if chars == 1: sources_char[sent_id] = np.array(src_char, dtype=int) + source_features = load_features(src_orig_features, src_feature_indexers, newseqlength) + for i in range(len(src_feature_indexers)): + sources_features[i][sent_id] = np.array(source_features[i], dtype=int) + sent_id += 1 total_sent_id += 1 if sent_id % 100000 == 0: print("{}/{} sentences processed, shard {}".format(total_sent_id, total_num_sents, shards)) + if sent_id % num_sents == 0 or total_sent_id == total_num_sents: if total_sent_id == total_num_sents: source_lengths = source_lengths[:sent_id] print(sent_id, num_sents) print("saving shard {}".format(shards)) sent_id = 0 - #break up batches based on source lengths + source_sort = np.argsort(source_lengths) - sources = sources[source_sort] targets = targets[source_sort] target_output = target_output[source_sort] target_l = target_lengths[source_sort] source_l = source_lengths[source_sort] print(sources.shape) + + for i in range(len(src_feature_indexers)): + sources_features[i] = sources_features[i][source_sort] + curr_l = 0 l_location = [] #idx where sent length changes - + for j,i in enumerate(source_sort): if source_lengths[i] > curr_l: curr_l = source_lengths[i] l_location.append(j+1) l_location.append(len(sources)) - + #get batch sizes curr_idx = 1 batch_idx = [1] @@ -212,35 +310,43 @@ def convert(srcfile, targetfile, batchsize, seqlength, outfile, total_num_sents, curr_idx = min(curr_idx + batchsize, l_location[i+1]) batch_idx.append(curr_idx) for i in range(len(batch_idx)-1): - batch_l.append(batch_idx[i+1] - batch_idx[i]) + batch_l.append(batch_idx[i+1] - batch_idx[i]) batch_w.append(source_l[batch_idx[i]-1]) nonzeros.append((target_output[batch_idx[i]-1:batch_idx[i+1]-1] != 1).sum().sum()) target_l_max.append(max(target_l[batch_idx[i]-1:batch_idx[i+1]-1])) - + # Write output - f = h5py.File(outfile + '.' + str(shards) + '.hdf5', "w") - + f = h5py.File(outfile, "w") + f["source"] = sources f["target"] = targets f["target_output"] = target_output f["target_l"] = np.array(target_l_max, dtype=int) - f["target_l_all"] = target_l + f["target_l_all"] = target_l f["batch_l"] = np.array(batch_l, dtype=int) f["batch_w"] = np.array(batch_w, dtype=int) f["batch_idx"] = np.array(batch_idx[:-1], dtype=int) f["target_nonzeros"] = np.array(nonzeros, dtype=int) f["source_size"] = np.array([len(src_indexer.d)]) f["target_size"] = np.array([len(target_indexer.d)]) + f["num_source_features"] = np.array([len(src_feature_indexers)]) + for i in range(len(src_feature_indexers)): + f["source_feature_" + str(i+1)] = sources_features[i] + f["source_feature_" + str(i+1) + "_size"] = np.array([len(src_feature_indexers[i].d)]) if chars == 1: + #del sources, targets, target_output sources_char = sources_char[source_sort] - targets_char = targets_char[source_sort] f["source_char"] = sources_char + #del sources_char + targets_char = targets_char[source_sort] f["target_char"] = targets_char f["char_size"] = np.array([len(char_indexer.d)]) - print("Saved {} sentences (dropped {} due to length)".format(len(f["source"]), dropped)) + print("Saved {} sentences (dropped {} due to length/unk filter)".format( + len(f["source"]), dropped)) + f.close() - shards += 1 - + + shards +=1 return max_sent_l print("First pass through data to get vocab...") @@ -248,8 +354,8 @@ def convert(srcfile, targetfile, batchsize, seqlength, outfile, total_num_sents, args.seqlength, 0, args.chars) print("Number of sentences in training: {}".format(num_sents_train)) max_word_l, num_sents_valid = make_vocab(args.srcvalfile, args.targetvalfile, - args.seqlength, max_word_l, args.chars) - print("Number of sentences in valid: {}".format(num_sents_valid)) + args.seqlength, max_word_l, args.chars, 0) + print("Number of sentences in valid: {}".format(num_sents_valid)) if args.chars == 1: print("Max word length (before cutting): {}".format(max_word_l)) max_word_l = min(max_word_l, args.maxwordlength) @@ -259,44 +365,72 @@ def convert(srcfile, targetfile, batchsize, seqlength, outfile, total_num_sents, src_indexer.prune_vocab(args.srcvocab) target_indexer.prune_vocab(args.targetvocab) src_indexer.write(args.outputfile + ".src.dict") - target_indexer.write(args.outputfile + ".targ.dict") + target_indexer.write(args.outputfile + ".targ.dict") if args.chars == 1: - char_indexer.prune_vocab(200) + + char_indexer.prune_vocab(500) char_indexer.write(args.outputfile + ".char.dict") print("Character vocab size: {}".format(len(char_indexer.pruned_vocab))) - - print("Source vocab size: Original = {}, Pruned = {}".format(len(src_indexer.vocab), + + if args.reusefeaturefile != '': + load_features('source', src_feature_indexers, args.reusefeaturefile) + + save_features('source', src_feature_indexers, args.outputfile) + + print("Source vocab size: Original = {}, Pruned = {}".format(len(src_indexer.vocab), len(src_indexer.pruned_vocab))) - print("Target vocab size: Original = {}, Pruned = {}".format(len(target_indexer.vocab), + print("Target vocab size: Original = {}, Pruned = {}".format(len(target_indexer.vocab), len(target_indexer.pruned_vocab))) - max_sent_l = 0 + max_sent_l = 0 max_sent_l = convert(args.srcvalfile, args.targetvalfile, args.batchsize, args.seqlength, - args.outputfile + "-val.hdf5", num_sents_valid, args.shardsize, + args.outputfile + "-val.hdf5", num_sents_valid,args.shardsize, max_word_l, max_sent_l, args.chars) max_sent_l = convert(args.srcfile, args.targetfile, args.batchsize, args.seqlength, - args.outputfile + "-train.hdf5", num_sents_train, args.shardsize, - max_word_l, max_sent_l, args.chars) - print("Max sent length (before dropping): {}".format(max_sent_l)) - + args.outputfile + "-train.hdf5", num_sents_train,args.shardsize, max_word_l, + max_sent_l, args.chars) + + print("Max sent length (before dropping): {}".format(max_sent_l)) + def main(arguments): parser = argparse.ArgumentParser( description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument('--srcvocab', help="Source vocab size", type=int) - parser.add_argument('--targetvocab', help="Target vocab size", type=int) - parser.add_argument('--srcfile', help="Source Input file") - parser.add_argument('--targetfile', help="Target Input file") - parser.add_argument('--srcvalfile', help="Source Val file") - parser.add_argument('--targetvalfile', help="Target val file") - parser.add_argument('--batchsize', help="Batchsize", type=int) - parser.add_argument('--shardsize', help="Num sents in each shard", type=int) - parser.add_argument('--seqlength', help="(Max) Sequence length", type=int) - parser.add_argument('--outputfile', help="HDF5 output file", type=str) - parser.add_argument('--maxwordlength', help="Max word length", type=int) - parser.add_argument('--chars', help="Use characters", type=int) + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--srcvocab', help="Size of source vocabulary, constructed " + "by taking the top X most frequent words. " + " Rest are replaced with special UNK tokens.", + type=int, default=50000) + parser.add_argument('--targetvocab', help="Size of target vocabulary, constructed " + "by taking the top X most frequent words. " + "Rest are replaced with special UNK tokens.", + type=int, default=50000) + parser.add_argument('--srcfile', help="Path to source training data, " + "where each line represents a single " + "source/target sequence.", required=True) + parser.add_argument('--targetfile', help="Path to target training data, " + "where each line represents a single " + "source/target sequence.", required=True) + parser.add_argument('--srcvalfile', help="Path to source validation data.", required=True) + parser.add_argument('--targetvalfile', help="Path to target validation data.", required=True) + parser.add_argument('--shardsize', help="Num sents in each shard", type=int) + parser.add_argument('--batchsize', help="Size of each minibatch.", type=int, default=64) + parser.add_argument('--seqlength', help="Maximum sequence length. Sequences longer " + "than this are dropped.", type=int, default=50) + parser.add_argument('--outputfile', help="Prefix of the output file names. ", type=str, required=True) + parser.add_argument('--maxwordlength', help="For the character models, words are " + "(if longer than maxwordlength) or zero-padded " + "(if shorter) to maxwordlength", type=int, default=35) + parser.add_argument('--chars', help="If 1, construct the character-level dataset as well. " + "This might take up a lot of space depending on your data " + "size, so you may want to break up the training data into " + "different shards.", type=int, default=0) + + parser.add_argument('--reusefeaturefile', help="Use existing feature vocabs", + type = str, default ='') + args = parser.parse_args(arguments) get_data(args) if __name__ == '__main__': sys.exit(main(sys.argv[1:])) + From bcd899ec990da6b2c5c616aab5ac77b5c7760dc6 Mon Sep 17 00:00:00 2001 From: mdasadul Date: Thu, 22 Sep 2016 12:05:29 -0700 Subject: [PATCH 2/2] fixing file writing issues --- preprocess-shards.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/preprocess-shards.py b/preprocess-shards.py index bd90346..96ad3a5 100644 --- a/preprocess-shards.py +++ b/preprocess-shards.py @@ -316,7 +316,9 @@ def load_features(orig_features, indexers, seqlength): target_l_max.append(max(target_l[batch_idx[i]-1:batch_idx[i+1]-1])) # Write output - f = h5py.File(outfile, "w") + #f = h5py.File(outfile, "w") + f = h5py.File(outfile + '.' + str(shards) + '.hdf5', "w") + f["source"] = sources f["target"] = targets