embedding.py
import os
import argparse
import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector


def ensure_dir(path):
    """Create the output directory if it does not already exist."""
    if not os.path.exists(path):
        os.makedirs(path)


def create_embedding():
    """Build a toy 3-D embedding with labeled clusters.

    Generates 20 noisy points around each of three unit vectors (one cluster
    per label), plus one extra 'profile' point halfway between two clusters.
    Returns the metadata header, the per-point metadata rows, and the
    embedding matrix.
    """
    header = ['id', 'category']
    labels = ['politics', 'economy', 'sport']
    points = np.array([
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0]
    ])
    embedding = []
    metadata = []
    point_id = 0
    for i in range(len(points)):
        for j in range(20):
            # Jitter each cluster center with Gaussian noise.
            embedding.append(points[i] + np.random.normal(loc=0.0, scale=0.2, size=3))
            label = '%s_%d' % (labels[i], j + 1)
            metadata.append([point_id, label])
            point_id += 1
    # One extra point placed between the first and third clusters.
    embedding.append(np.array([0.5, 0.0, 0.5]))
    metadata.append([1000, 'profile_1'])
    embedding = np.array(embedding)
    return header, metadata, embedding


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--output-dir', help='output path to store TensorBoard artifacts',
                        default='/tmp/embedding')
    args = parser.parse_args()
    outpath = args.output_dir
    ensure_dir(outpath)

    header, metadata, embedding = create_embedding()
    # The projector reads the embedding from a checkpointed variable.
    embedding_v = tf.Variable(embedding, trainable=True, name='embedding')

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver()
        writer = tf.summary.FileWriter(outpath, sess.graph)

        # Point the projector at the embedding tensor and its metadata file.
        config = projector.ProjectorConfig()
        embed = config.embeddings.add()
        embed.tensor_name = embedding_v.name  # 'embedding:0'
        embed.metadata_path = os.path.join(outpath, 'metadata.tsv')
        projector.visualize_embeddings(writer, config)

        # Save a checkpoint so TensorBoard can load the variable's values.
        saver.save(sess, os.path.join(outpath, 'a_model.ckpt'))

    # Write the metadata TSV: a header row followed by one 'id<TAB>category'
    # row per embedded point (e.g. '0<TAB>politics_1').
    with open(os.path.join(outpath, 'metadata.tsv'), 'w') as f:
        f.write('{}\t{}\n'.format(*header))
        for i in range(len(metadata)):
            f.write('{}\t{}\n'.format(*metadata[i]))


if __name__ == "__main__":
    main()
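
# Usage sketch (an assumption, not part of the original file): this script
# targets TensorFlow 1.x, where tf.contrib and tf.Session are available, and
# assumes TensorBoard is installed alongside it.
#
#   python embedding.py --output-dir /tmp/embedding
#   tensorboard --logdir /tmp/embedding
#
# Then open the Projector tab in TensorBoard to explore the labeled points.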