Skip to content

Commit 4b3f34f

Browse files
authored
Create course_7_shakespeare_gen.py
1 parent 487c291 commit 4b3f34f

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

course_7_shakespeare_gen.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from __future__ import absolute_import, division, print_function
2+
3+
import os
4+
import pickle
5+
from six.moves import urllib
6+
7+
import tflearn
8+
from tflearn.data_utils import *
9+
10+
path = "shakespeare_input.txt"
11+
char_idx_file = 'char_idx.pickle'
12+
13+
if not os.path.isfile(path):
14+
urllib.request.urlretrieve("https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/shakespeare_input.txt", path)
15+
16+
maxlen = 25
17+
18+
char_idx = None
19+
if os.path.isfile(char_idx_file):
20+
print('Loading previous char_idx')
21+
char_idx = pickle.load(open(char_idx_file, 'rb'))
22+
23+
X, Y, char_idx = \
24+
textfile_to_semi_redundant_sequences(path, seq_maxlen=maxlen, redun_step=3,
25+
pre_defined_char_idx=char_idx)
26+
27+
pickle.dump(char_idx, open(char_idx_file,'wb'))
28+
29+
g = tflearn.input_data([None, maxlen, len(char_idx)])
30+
g = tflearn.lstm(g, 512, return_seq=True)
31+
g = tflearn.dropout(g, 0.5)
32+
g = tflearn.lstm(g, 512, return_seq=True)
33+
g = tflearn.dropout(g, 0.5)
34+
g = tflearn.lstm(g, 512)
35+
g = tflearn.dropout(g, 0.5)
36+
g = tflearn.fully_connected(g, len(char_idx), activation='softmax')
37+
g = tflearn.regression(g, optimizer='adam', loss='categorical_crossentropy',
38+
learning_rate=0.001)
39+
40+
m = tflearn.SequenceGenerator(g, dictionary=char_idx,
41+
seq_maxlen=maxlen,
42+
clip_gradients=5.0,
43+
checkpoint_path='model_shakespeare')
44+
45+
for i in range(50):
46+
seed = random_sequence_from_textfile(path, maxlen)
47+
m.fit(X, Y, validation_set=0.1, batch_size=128,
48+
n_epoch=1, run_id='shakespeare')
49+
print("-- TESTING...")
50+
print("-- Test with temperature of 1.0 --")
51+
print(m.generate(600, temperature=1.0, seq_seed=seed))
52+
print("-- Test with temperature of 0.5 --")
53+
print(m.generate(600, temperature=0.5, seq_seed=seed))

0 commit comments

Comments
 (0)