|
| 1 | +from __future__ import absolute_import, division, print_function |
| 2 | + |
| 3 | +import os |
| 4 | +import pickle |
| 5 | +from six.moves import urllib |
| 6 | + |
| 7 | +import tflearn |
| 8 | +from tflearn.data_utils import * |
| 9 | + |
| 10 | +path = "shakespeare_input.txt" |
| 11 | +char_idx_file = 'char_idx.pickle' |
| 12 | + |
| 13 | +if not os.path.isfile(path): |
| 14 | + urllib.request.urlretrieve("https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/shakespeare_input.txt", path) |
| 15 | + |
| 16 | +maxlen = 25 |
| 17 | + |
| 18 | +char_idx = None |
| 19 | +if os.path.isfile(char_idx_file): |
| 20 | + print('Loading previous char_idx') |
| 21 | + char_idx = pickle.load(open(char_idx_file, 'rb')) |
| 22 | + |
| 23 | +X, Y, char_idx = \ |
| 24 | + textfile_to_semi_redundant_sequences(path, seq_maxlen=maxlen, redun_step=3, |
| 25 | + pre_defined_char_idx=char_idx) |
| 26 | + |
| 27 | +pickle.dump(char_idx, open(char_idx_file,'wb')) |
| 28 | + |
| 29 | +g = tflearn.input_data([None, maxlen, len(char_idx)]) |
| 30 | +g = tflearn.lstm(g, 512, return_seq=True) |
| 31 | +g = tflearn.dropout(g, 0.5) |
| 32 | +g = tflearn.lstm(g, 512, return_seq=True) |
| 33 | +g = tflearn.dropout(g, 0.5) |
| 34 | +g = tflearn.lstm(g, 512) |
| 35 | +g = tflearn.dropout(g, 0.5) |
| 36 | +g = tflearn.fully_connected(g, len(char_idx), activation='softmax') |
| 37 | +g = tflearn.regression(g, optimizer='adam', loss='categorical_crossentropy', |
| 38 | + learning_rate=0.001) |
| 39 | + |
| 40 | +m = tflearn.SequenceGenerator(g, dictionary=char_idx, |
| 41 | + seq_maxlen=maxlen, |
| 42 | + clip_gradients=5.0, |
| 43 | + checkpoint_path='model_shakespeare') |
| 44 | + |
| 45 | +for i in range(50): |
| 46 | + seed = random_sequence_from_textfile(path, maxlen) |
| 47 | + m.fit(X, Y, validation_set=0.1, batch_size=128, |
| 48 | + n_epoch=1, run_id='shakespeare') |
| 49 | + print("-- TESTING...") |
| 50 | + print("-- Test with temperature of 1.0 --") |
| 51 | + print(m.generate(600, temperature=1.0, seq_seed=seed)) |
| 52 | + print("-- Test with temperature of 0.5 --") |
| 53 | + print(m.generate(600, temperature=0.5, seq_seed=seed)) |
0 commit comments