run_mvsc.py
"""
Given a pre-trained model, run mvsc on it, and print scores vs gold standard
we'll use view1 encoder to encode each of view1 and view2, and then pass that through mvsc algo
this should probably be folded innto run_clustering.py (originally kind of forked from
run_clustering.py, and combined with some things from train_pca.py and train.py)
"""
import time
import random
import datetime
import argparse

import sklearn.cluster
import sklearn.metrics  # silhouette_score and davies_bouldin_score are used below
import numpy as np
import torch

from metrics import cluster_metrics
from model import multiview_encoders
from proc_data import Dataset
try:
    import multiview
except ImportError:
    print('please install https://github.com/mariceli3/multiview')
    print('e.g. pip install git+https://github.com/mariceli3/multiview')
    raise

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 32  # mini-batch size used when encoding utterances in transform()

def transform(dataset, perm_idx, model, view):
    """
    For a view-1 utterance, simply encode it using the view-1 encoder.
    For view-2 utterances:
    - encode each utterance using the view-1 encoder, to get utterance embeddings
    - take the average of the utterance embeddings to form the view-2 embedding
    """
    model.eval()
    latent_zs, golds = [], []
    n_batch = (len(perm_idx) + BATCH_SIZE - 1) // BATCH_SIZE
    for i in range(n_batch):
        indices = perm_idx[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        v1_batch, v2_batch = list(zip(*[dataset[idx][0] for idx in indices]))
        golds += [dataset[idx][1] for idx in indices]
        if view == 'v1':
            latent_z = model(v1_batch, encoder='v1')
        elif view == 'v2':
            # one embedding per conversation: mean over its utterance embeddings
            latent_z_l = [model(conv, encoder='v1').mean(dim=0) for conv in v2_batch]
            latent_z = torch.stack(latent_z_l)
        latent_zs.append(latent_z.detach().cpu().numpy())
    latent_zs = np.concatenate(latent_zs)
    return latent_zs, golds
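
# transform() usage, as in run() below:
#   latent_z1s, golds = transform(dataset, trn_idx, mvc_encoder, view='v1')
#   latent_z2s, _ = transform(dataset, trn_idx, mvc_encoder, view='v2')
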
def run(
        ref, model_path, num_clusters, num_cluster_samples, seed,
        out_cluster_samples_file_hier,
        max_examples, out_cluster_samples_file,
        data_path, view1_col, view2_col, label_col,
        sampling_strategy, mvsc_no_unk):
    # ref, num_cluster_samples, sampling_strategy, and the out_cluster_samples_file*
    # paths are accepted (run() receives the full CLI namespace) but are unused below
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    id_to_token, token_to_id, vocab_size, word_emb_size, mvc_encoder = \
        multiview_encoders.load_model(model_path)
    print('loaded model')

    print('loading dataset')
    dataset = Dataset(data_path, view1_col=view1_col, view2_col=view2_col, label_col=label_col)
    n_cluster = len(dataset.id_to_label) - 1
    print('loaded dataset, number of classes = %d' % n_cluster)

    idxes = dataset.trn_idx_no_unk if mvsc_no_unk else dataset.trn_idx
    trn_idx = [x.item() for x in np.random.permutation(idxes)]
    if max_examples is not None:
        trn_idx = trn_idx[:max_examples]

    num_clusters = n_cluster if num_clusters is None else num_clusters
    print('clustering over num clusters', num_clusters)
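    # MVSC is multi-view spectral clustering, from the mariceli3/multiview package
    # imported above: it takes one feature matrix per view and produces a single
    # joint clustering into k clusters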
    mvsc = multiview.mvsc.MVSC(k=num_clusters)
    latent_z1s, golds = transform(dataset, trn_idx, mvc_encoder, view='v1')
    latent_z2s, _ = transform(dataset, trn_idx, mvc_encoder, view='v2')

    print('running mvsc', end='', flush=True)
    start = time.time()
    preds, eivalues, eivectors, sigmas = mvsc.fit_transform(
        [latent_z1s, latent_z2s], [False] * 2
    )
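    # preds holds one cluster assignment per example; eivalues, eivectors, and sigmas
    # are spectral by-products. The [False] * 2 argument is read here as marking each
    # view as a raw feature matrix rather than a precomputed distance matrix (an
    # assumption about the multiview API, not confirmed by this repo)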
    print('...done')
    mvsc_time = time.time() - start
    print('time taken %.3f' % mvsc_time)
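
    # gold label 0 is presumably the unknown/UNK class (cf. trn_idx_no_unk above);
    # those examples are excluded from the supervised metrics below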
    lgolds, lpreds = [], []
    for g, p in zip(golds, list(preds)):
        if g > 0:
            lgolds.append(g)
            lpreds.append(p)
    prec, rec, f1 = cluster_metrics.calc_prec_rec_f1(
        gnd_assignments=torch.LongTensor(lgolds).to(device),
        pred_assignments=torch.LongTensor(lpreds).to(device))
    acc = cluster_metrics.calc_ACC(
        torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device))
    # note: the unsupervised scores below are computed on the view-1 embeddings only
    silhouette = sklearn.metrics.silhouette_score(latent_z1s, preds, metric='euclidean')
    davies_bouldin = sklearn.metrics.davies_bouldin_score(latent_z1s, preds)

    print(f'{datetime.datetime.now()} pretrain: eval prec={prec:.4f} rec={rec:.4f} f1={f1:.4f} '
          f'acc={acc:.4f} sil={silhouette:.4f}, db={davies_bouldin:.4f}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--max-examples', type=int,
                        help='in case we do not want to cluster the entire dataset')
    parser.add_argument('--mvsc-no-unk', action='store_true',
                        help='only feed non-unk data to MVSC (to avoid OOM)')
    parser.add_argument('--ref', type=str, required=True)
    parser.add_argument('--model-path', type=str, required=True)
    parser.add_argument('--data-path', type=str, default='./data/airlines_500_merged.csv')
    parser.add_argument('--view1-col', type=str, default='view1_col')
    parser.add_argument('--view2-col', type=str, default='view2_col')
    parser.add_argument('--label-col', type=str, default='cluster_id')
    parser.add_argument('--num-clusters', type=int, help='defaults to number of supervised classes')
    parser.add_argument('--num-cluster-samples', type=int, default=10)
    parser.add_argument('--sampling-strategy', type=str,
                        choices=['uniform', 'nearest'], default='nearest')
    parser.add_argument('--out-cluster-samples-file-hier', type=str,
                        default='tmp/cluster_samples_hier_{ref}.txt')
    parser.add_argument('--out-cluster-samples-file', type=str,
                        default='tmp/cluster_samples_{ref}.txt')
args.out_cluster_samples_file = args.out_cluster_samples_file.format(**args.__dict__)
args.out_cluster_samples_file_hier = args.out_cluster_samples_file_hier.format(**args.__dict__)
run(**args.__dict__)
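
# example invocation (values are illustrative, not taken from the repo):
#   python run_mvsc.py --ref mvsc_run1 --model-path tmp/model.pt \
#       --data-path ./data/airlines_500_merged.csv --mvsc-no-unk --max-examples 2000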