11#!/usr/bin/env python
22
3- ######################################################################
4- # Functions for analyzing models against experimental data using the
5- # techniques described in the paper. For examples, see paper.py
6- ######################################################################
3+ """
4+ Functions for analyzing models against experimental data using the
5+ techniques described in the paper. For examples, see `paper.py`
6+ """
7+
8+ __author__ = "Christopher Potts"
9+ __version__ = "2.0"
10+ __license__ = "GNU general public license, version 3"
11+ __maintainer__ = "Christopher Potts"
12+ __email__ = "See the author's website"
13+
714
8- import sys
9- import csv
1015from copy import copy
1116from collections import defaultdict
17+ import csv
1218from itertools import product
1319import numpy as np
1420from scipy .stats import spearmanr , pearsonr
1521from scipy import stats
16- from settings import *
17- sys .path .append ('../' )
18- from utils import *
19- import bootstrap
22+ import sys
23+
24+ from pypragmods .embeddedscalars import bootstrap
25+ from pypragmods .embeddedscalars .settings import *
26+ from pypragmods .utils import *
27+
2028
2129######################################################################
2230
@@ -70,9 +78,17 @@ def overall_analysis(self, digits=4, nsims=10000):
7078 p_lower , p_upper = self .get_ci (sample_pearsons )
7179 s_lower , s_upper = self .get_ci (sample_spearmans )
7280 e_lower , e_upper = self .get_ci (sample_errs )
73- rows .append (np .array ([pearson , p_lower , p_upper , pearson_p , spearman , s_lower , s_upper , spearman_p , err , e_lower , e_upper ]))
74- labels = ['Pearson' , 'PearsonLower' , 'PearsonUpper' , 'Pearson p' , 'Spearman' , 'SpearmanLower' , 'SpearmanUpper' , 'Spearman p' , 'MSE' , 'MSELower' , 'MSEUpper' ]
75- display_matrix (np .array (rows ), rnames = self .modnames , cnames = labels , digits = digits )
81+ rows .append (np .array ([pearson , p_lower , p_upper , pearson_p ,
82+ spearman , s_lower , s_upper , spearman_p ,
83+ err , e_lower , e_upper ]))
84+ labels = ['Pearson' , 'PearsonLower' , 'PearsonUpper' , 'Pearson p' ,
85+ 'Spearman' , 'SpearmanLower' , 'SpearmanUpper' , 'Spearman p' ,
86+ 'MSE' , 'MSELower' , 'MSEUpper' ]
87+ display_matrix (
88+ np .array (rows ),
89+ rnames = self .modnames ,
90+ cnames = labels ,
91+ digits = digits )
7692
def get_ci(self, vals, percentiles=(2.5, 97.5)):
    """Return the values of `vals` at the requested percentiles.

    Parameters
    ----------
    vals : array-like
        Sample values, passed straight to `np.percentile`.
    percentiles : sequence of float
        Percentiles to compute, on the 0-100 scale. The default pair
        (2.5, 97.5) bounds the central 95% of `vals`.

    Returns
    -------
    np.ndarray
        One value per entry of `percentiles`, in the same order, so
        the default can be unpacked as `lower, upper`.
    """
    # The default is an immutable tuple rather than a list so that no
    # caller can accidentally mutate the shared default object.
    return np.percentile(vals, percentiles)
@@ -87,14 +103,15 @@ def numeric_analysis(self):
87103 spearman , spearman_p = spearmanr (expvec , lisvec )
88104 s_lower , s_upper = correlation_coefficient_ci (spearman , n = observation_count )
89105 err = mse (expvec , lisvec )
90- results [self .modnames [i ]] = dict (zip (labels , [pearson , pearson_p , spearman , spearman_p , err ]))
106+ results [self .modnames [i ]] = dict (list ( zip (labels , [pearson , pearson_p , spearman , spearman_p , err ]) ))
91107 return results
92108
93109 def analysis_by_message (self , digits = 4 ):
94110 rows = []
95111 msglen = max ([len (x ) for x in self .messages ])
96112 modlen = max ([len (x ) for x in self .modnames ])
97- rnames = [msg .rjust (msglen )+ " " + mod .rjust (modlen ) for msg , mod in product (self .messages , self .modnames )]
113+ rnames = [msg .rjust (msglen )+ " " + mod .rjust (modlen )
114+ for msg , mod in product (self .messages , self .modnames )]
98115 for i , msg in enumerate (self .messages ):
99116 expvec = self .expmat [i ]
100117 for j , lis in enumerate (self .listeners ):
@@ -104,7 +121,11 @@ def analysis_by_message(self, digits=4):
104121 spearman , spearman_p = spearmanr (expvec , lisvec )
105122 err = mse (expvec , lisvec )
106123 rows .append (np .array ([pearson , pearson_p , spearman , spearman_p , err ]))
107- display_matrix (np .array (rows ), rnames = rnames , cnames = ['Pearson' , 'Pearson p' , 'Spearman' , 'Spearman p' , 'MSE' ], digits = digits )
124+ display_matrix (
125+ np .array (rows ),
126+ rnames = rnames ,
127+ cnames = ['Pearson' , 'Pearson p' , 'Spearman' , 'Spearman p' , 'MSE' ],
128+ digits = digits )
108129
109130 def comparison_plot (self , width = 0.2 , output_filename = None , nrows = None ):
110131 # Preferred: human left, then models from best to worse, informally:
@@ -122,16 +143,40 @@ def comparison_plot(self, width=0.2, output_filename=None, nrows=None):
122143 fig .text (0.5 , 0.05 , 'Probability' , ha = 'center' , va = 'center' , fontsize = 30 )
123144 fig .text (0.08 , 0.5 , 'World' , ha = 'center' , va = 'center' , rotation = 'vertical' , fontsize = 30 )
124145 # Human column, then model columns:
125- self .model_comparison_plot (axarray [:,0 ], self .expmat , width = width , color = colors [0 ], modname = 'Human' , left = True , right = False , nrows = nrows )
146+ self .model_comparison_plot (
147+ axarray [:,0 ],
148+ self .expmat ,
149+ width = width ,
150+ color = colors [0 ],
151+ modname = 'Human' ,
152+ left = True ,
153+ right = False ,
154+ nrows = nrows )
126155 for i , lis in enumerate (listeners ):
127- self .model_comparison_plot (axarray [: , i + 1 ], lis , width = width , color = colors [- (i + 1 )], modname = modnames [i ], left = False , right = i == ncols - 2 , nrows = nrows )
156+ self .model_comparison_plot (
157+ axarray [: , i + 1 ],
158+ lis ,
159+ width = width ,
160+ color = colors [- (i + 1 )],
161+ modname = modnames [i ],
162+ left = False ,
163+ right = i == ncols - 2 ,
164+ nrows = nrows )
128165 # Output:
129166 if output_filename :
130167 plt .savefig (output_filename , bbox_inches = 'tight' )
131168 else :
132169 plt .show ()
133170
134- def model_comparison_plot (self , axarray , modmat , width = 1.0 , color = 'black' , modname = None , left = False , right = False , nrows = None ):
171+ def model_comparison_plot (self ,
172+ axarray ,
173+ modmat ,
174+ width = 1.0 ,
175+ color = 'black' ,
176+ modname = None ,
177+ left = False ,
178+ right = False ,
179+ nrows = None ):
135180 # Preferred ordering puts the embedded 'some' sentences last:
136181 message_ordering_indices = [0 ,3 ,6 ,1 ,4 ,7 ,2 ,5 ,8 ]
137182 if nrows :
@@ -158,16 +203,27 @@ def model_comparison_plot(self, axarray, modmat, width=1.0, color='black', modna
158203 msg = msgs [j ]
159204 row = modmat [j ]
160205 row = row [::- 1 ] # Reversal for preferred ordering.
161- ax .tick_params (axis = 'both' , which = 'both' , bottom = 'off' , left = 'off' , top = 'off' , right = 'off' )
206+ ax .tick_params (
207+ axis = 'both' ,
208+ which = 'both' ,
209+ bottom = 'off' ,
210+ left = 'off' ,
211+ top = 'off' ,
212+ right = 'off' )
162213 ax .barh (pos , row , width , color = color )
163214 # title as model name:
164215 if j == 0 :
165- ax .set_title (r"\textbf{%s}" % modname , fontsize = title_size , color = color , fontweight = 'bold' )
216+ ax .set_title (
217+ r"\textbf{%s}" % modname ,
218+ fontsize = title_size ,
219+ color = color ,
220+ fontweight = 'bold' )
166221 # x-axis
167222 ax .set_xlim (xlim )
168223 ax .set_xticks (xticks )
169224 if j == len (axarray )- 1 :
170- ax .set_xticklabels (xtick_labels , fontsize = xtick_labelsize , color = 'black' )
225+ ax .set_xticklabels (
226+ xtick_labels , fontsize = xtick_labelsize , color = 'black' )
171227 else :
172228 ax .set_xticklabels ([])
173229 # y-axis:
@@ -177,7 +233,8 @@ def model_comparison_plot(self, axarray, modmat, width=1.0, color='black', modna
177233 ax .set_ylim (ylim )
178234 ax .set_yticks (yticks )
179235 if left :
180- ax .set_yticklabels (ytick_labels , fontsize = ytick_labelsize , color = 'black' )
236+ ax .set_yticklabels (
237+ ytick_labels , fontsize = ytick_labelsize , color = 'black' )
181238 else :
182239 ax .set_yticklabels ([])
183240
0 commit comments