Skip to content

Commit a91b4ad

Browse files
committed
Updated for Python 2/3 compatibility
1 parent 8580ada commit a91b4ad

10 files changed

Lines changed: 1091 additions & 493 deletions

File tree

disjunction/bls41.py

Lines changed: 284 additions & 115 deletions
Large diffs are not rendered by default.

embeddedscalars/analysis.py

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,30 @@
11
#!/usr/bin/env python
22

3-
######################################################################
4-
# Functions for analyzing models against experimental data using the
5-
# techniques described in the paper. For examples, see paper.py
6-
######################################################################
3+
"""
4+
Functions for analyzing models against experimental data using the
5+
techniques described in the paper. For examples, see `paper.py`
6+
"""
7+
8+
__author__ = "Christopher Potts"
9+
__version__ = "2.0"
10+
__license__ = "GNU general public license, version 3"
11+
__maintainer__ = "Christopher Potts"
12+
__email__ = "See the author's website"
13+
714

8-
import sys
9-
import csv
1015
from copy import copy
1116
from collections import defaultdict
17+
import csv
1218
from itertools import product
1319
import numpy as np
1420
from scipy.stats import spearmanr, pearsonr
1521
from scipy import stats
16-
from settings import *
17-
sys.path.append('../')
18-
from utils import *
19-
import bootstrap
22+
import sys
23+
24+
from pypragmods.embeddedscalars import bootstrap
25+
from pypragmods.embeddedscalars.settings import *
26+
from pypragmods.utils import *
27+
2028

2129
######################################################################
2230

@@ -70,9 +78,17 @@ def overall_analysis(self, digits=4, nsims=10000):
7078
p_lower, p_upper = self.get_ci(sample_pearsons)
7179
s_lower, s_upper = self.get_ci(sample_spearmans)
7280
e_lower, e_upper = self.get_ci(sample_errs)
73-
rows.append(np.array([pearson, p_lower, p_upper, pearson_p, spearman, s_lower, s_upper, spearman_p, err, e_lower, e_upper]))
74-
labels = ['Pearson', 'PearsonLower', 'PearsonUpper', 'Pearson p', 'Spearman', 'SpearmanLower', 'SpearmanUpper', 'Spearman p', 'MSE', 'MSELower', 'MSEUpper']
75-
display_matrix(np.array(rows), rnames=self.modnames, cnames=labels, digits=digits)
81+
rows.append(np.array([pearson, p_lower, p_upper, pearson_p,
82+
spearman, s_lower, s_upper, spearman_p,
83+
err, e_lower, e_upper]))
84+
labels = ['Pearson', 'PearsonLower', 'PearsonUpper', 'Pearson p',
85+
'Spearman', 'SpearmanLower', 'SpearmanUpper', 'Spearman p',
86+
'MSE', 'MSELower', 'MSEUpper']
87+
display_matrix(
88+
np.array(rows),
89+
rnames=self.modnames,
90+
cnames=labels,
91+
digits=digits)
7692

7793
def get_ci(self, vals, percentiles=[2.5, 97.5]):
7894
return np.percentile(vals, percentiles)
@@ -87,14 +103,15 @@ def numeric_analysis(self):
87103
spearman, spearman_p = spearmanr(expvec, lisvec)
88104
s_lower, s_upper = correlation_coefficient_ci(spearman, n=observation_count)
89105
err = mse(expvec, lisvec)
90-
results[self.modnames[i]] = dict(zip(labels, [pearson, pearson_p, spearman, spearman_p, err]))
106+
results[self.modnames[i]] = dict(list(zip(labels, [pearson, pearson_p, spearman, spearman_p, err])))
91107
return results
92108

93109
def analysis_by_message(self, digits=4):
94110
rows = []
95111
msglen = max([len(x) for x in self.messages])
96112
modlen = max([len(x) for x in self.modnames])
97-
rnames = [msg.rjust(msglen)+" "+ mod.rjust(modlen) for msg, mod in product(self.messages, self.modnames)]
113+
rnames = [msg.rjust(msglen)+" "+ mod.rjust(modlen)
114+
for msg, mod in product(self.messages, self.modnames)]
98115
for i, msg in enumerate(self.messages):
99116
expvec = self.expmat[i]
100117
for j, lis in enumerate(self.listeners):
@@ -104,7 +121,11 @@ def analysis_by_message(self, digits=4):
104121
spearman, spearman_p = spearmanr(expvec, lisvec)
105122
err = mse(expvec, lisvec)
106123
rows.append(np.array([pearson, pearson_p, spearman, spearman_p, err]))
107-
display_matrix(np.array(rows), rnames=rnames, cnames=['Pearson', 'Pearson p', 'Spearman', 'Spearman p', 'MSE'], digits=digits)
124+
display_matrix(
125+
np.array(rows),
126+
rnames=rnames,
127+
cnames=['Pearson', 'Pearson p', 'Spearman', 'Spearman p', 'MSE'],
128+
digits=digits)
108129

109130
def comparison_plot(self, width=0.2, output_filename=None, nrows=None):
110131
# Preferred: human left, then models from best to worse, informally:
@@ -122,16 +143,40 @@ def comparison_plot(self, width=0.2, output_filename=None, nrows=None):
122143
fig.text(0.5, 0.05, 'Probability', ha='center', va='center', fontsize=30)
123144
fig.text(0.08, 0.5, 'World', ha='center', va='center', rotation='vertical', fontsize=30)
124145
# Human column, then model columns:
125-
self.model_comparison_plot(axarray[:,0], self.expmat, width=width, color=colors[0], modname='Human', left=True, right=False, nrows=nrows)
146+
self.model_comparison_plot(
147+
axarray[:,0],
148+
self.expmat,
149+
width=width,
150+
color=colors[0],
151+
modname='Human',
152+
left=True,
153+
right=False,
154+
nrows=nrows)
126155
for i, lis in enumerate(listeners):
127-
self.model_comparison_plot(axarray[: , i+1], lis, width=width, color=colors[-(i+1)], modname=modnames[i], left=False, right=i==ncols-2, nrows=nrows)
156+
self.model_comparison_plot(
157+
axarray[: , i+1],
158+
lis,
159+
width=width,
160+
color=colors[-(i+1)],
161+
modname=modnames[i],
162+
left=False,
163+
right=i==ncols-2,
164+
nrows=nrows)
128165
# Output:
129166
if output_filename:
130167
plt.savefig(output_filename, bbox_inches='tight')
131168
else:
132169
plt.show()
133170

134-
def model_comparison_plot(self, axarray, modmat, width=1.0, color='black', modname=None, left=False, right=False, nrows=None):
171+
def model_comparison_plot(self,
172+
axarray,
173+
modmat,
174+
width=1.0,
175+
color='black',
176+
modname=None,
177+
left=False,
178+
right=False,
179+
nrows=None):
135180
# Preferred ordering puts the embedded 'some' sentences last:
136181
message_ordering_indices = [0,3,6,1,4,7,2,5,8]
137182
if nrows:
@@ -158,16 +203,27 @@ def model_comparison_plot(self, axarray, modmat, width=1.0, color='black', modna
158203
msg = msgs[j]
159204
row = modmat[j]
160205
row = row[::-1] # Reversal for preferred ordering.
161-
ax.tick_params(axis='both', which='both', bottom='off', left='off', top='off', right='off')
206+
ax.tick_params(
207+
axis='both',
208+
which='both',
209+
bottom='off',
210+
left='off',
211+
top='off',
212+
right='off')
162213
ax.barh(pos, row, width, color=color)
163214
# title as model name:
164215
if j == 0:
165-
ax.set_title(r"\textbf{%s}" % modname, fontsize=title_size, color=color, fontweight='bold')
216+
ax.set_title(
217+
r"\textbf{%s}" % modname,
218+
fontsize=title_size,
219+
color=color,
220+
fontweight='bold')
166221
# x-axis
167222
ax.set_xlim(xlim)
168223
ax.set_xticks(xticks)
169224
if j == len(axarray)-1:
170-
ax.set_xticklabels(xtick_labels, fontsize=xtick_labelsize, color='black')
225+
ax.set_xticklabels(
226+
xtick_labels, fontsize=xtick_labelsize, color='black')
171227
else:
172228
ax.set_xticklabels([])
173229
# y-axis:
@@ -177,7 +233,8 @@ def model_comparison_plot(self, axarray, modmat, width=1.0, color='black', modna
177233
ax.set_ylim(ylim)
178234
ax.set_yticks(yticks)
179235
if left:
180-
ax.set_yticklabels(ytick_labels, fontsize=ytick_labelsize, color='black')
236+
ax.set_yticklabels(
237+
ytick_labels, fontsize=ytick_labelsize, color='black')
181238
else:
182239
ax.set_yticklabels([])
183240

0 commit comments

Comments
 (0)