Commit 09bc558

committed May 12, 2020
adding scripts for WMT'20.
1 parent 098fe47 commit 09bc558

8 files changed: +782 -2 lines changed
 

‎README.md

+3 -2
@@ -1,2 +1,3 @@
-# qe-eval-scripts
-Scripts to process & score QE predictions into WMT format.
+# Scripts to process & score QE predictions
+
+This repository contains scripts to score QE outputs in preparation for a submission to WMT.

‎wmt20/README.md

+37
@@ -0,0 +1,37 @@
# WMT'20 QE shared task

All information about the shared task is [here](http://www.statmt.org/wmt20/quality-estimation-task.html).

### Scoring programs

All scripts are written in Python; use `requirements.txt` to install the required modules.
Then,

for **Task 1**, use the following scripts:
* sentence-level DA: `python sent_evaluate.py -h`
* sentence-level DA **multilingual**: `python sent-multi_evaluate.py -h`

for **Task 2**, use the following scripts:
* sentence-level HTER: `python sent_evaluate.py -h`
* word-level HTER: `python word_evaluate.py -h`

for **Task 3**, use the following scripts[^1]:
* MQM **score**: `python eval_document_mqm.py -h`
* MQM **annotations**: `python eval_document_annotations.py -h`

Once you have checked that your system output on the dev data is correctly read by the right script (see the example run below), you can submit it on the CODALAB page corresponding to your subtask.
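For example, a Task 3 annotations submission can be checked locally against the dev gold annotations before uploading (the file names here are placeholders for your own prediction and gold files):

```
python eval_document_annotations.py my_annotations.tsv dev_gold_annotations.tsv -v
```

With `-v`, the script prints each document ID and its score before the final macro-averaged F1 (`Final F1: ...`).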
### Submission platforms

Predictions should be submitted to a CODALAB page for each subtask:

Task 1, [sentence-level DA](https://competitions.codalab.org/competitions/24447)
Task 1, [sentence-level DA **multilingual**](https://competitions.codalab.org/competitions/24447)

Task 2, [sentence-level HTER](https://competitions.codalab.org/competitions/24515)
Task 2, [word-level HTER](https://competitions.codalab.org/competitions/24728)

Task 3, [doc-level MQM **score**](https://competitions.codalab.org/competitions/24762)
Task 3, [doc-level **annotations**](https://competitions.codalab.org/competitions/24763)

[^1]: under MIT License (source: [Deep-Spin](https://github.com/deep-spin/qe-evaluation))

‎wmt20/eval_document_annotations.py

+331
@@ -0,0 +1,331 @@
"""Script to evaluate document-level QE as in the WMT19 shared task."""

import argparse
from collections import defaultdict
import numpy as np


class OverlappingSpans(ValueError):
    pass


class Span(object):
    def __init__(self, segment, start, end):
        """A contiguous span of text in a particular segment.
        Args:
            segment: ID of the segment (0-based).
            start: Character-based start position.
            end: The position right after the last character.
        """
        self.segment = segment
        self.start = start
        self.end = end
        assert self.end >= self.start

    def __len__(self):
        """Returns the length of the span in characters.
        Note: by convention, an empty span has length 1."""
        return max(1, self.end - self.start)

    def count_overlap(self, span):
        """Given another span, returns the number of matched characters.
        Zero if the two spans are in different segments.
        Args:
            span: another Span object.
        Returns:
            The number of matched characters.
        """
        if self.segment != span.segment:
            return 0
        start = max(self.start, span.start)
        end = min(self.end, span.end)
        if end >= start:
            if span.start == span.end or self.start == self.end:
                assert start == end
                return 1  # By convention, the overlap with empty spans is 1.
            else:
                return end - start
        else:
            return 0


class Annotation(object):
    def __init__(self, severity=None, spans=None):
        """An annotation, which has a severity level (minor, major, or critical)
        and consists of one or more non-overlapping spans.

        Args:
            severity: 'minor', 'major', or 'critical'.
            spans: A list of Span objects.
        """
        # make sure that there is no overlap
        spans = sorted(spans, key=lambda span: (span.segment, span.start))
        segment = -1
        for span in spans:
            if span.segment != segment:
                # first span in this segment
                segment = span.segment
                last_end = span.end
            else:
                # second or later span
                if span.start < last_end:
                    raise OverlappingSpans()
                last_end = span.end

        self.severity = severity
        self.spans = spans

    def __len__(self):
        """Returns the sum of the span lengths (in characters)."""
        return sum([len(span) for span in self.spans])

    def count_overlap(self, annotation, severity_match=None):
        """Given another annotation with the same severity, returns the number
        of matched characters. If the severities are different, the result is
        penalized according to a severity match matrix.

        Args:
            annotation: another Annotation object.
            severity_match: a dictionary of dictionaries containing match
                penalties for severity pairs.
        Returns:
            The number of matched characters, possibly penalized by a severity
            mismatch.
        """
        # TODO: Maybe normalize by annotation length (e.g. intersection over
        # union)?
        # Note: since we're summing the matches, this won't work as expected
        # if there are overlapping spans (which we assume there aren't).
        matched = 0
        for span in self.spans:
            for annotation_span in annotation.spans:
                matched += span.count_overlap(annotation_span)
        # Scale overlap by a coefficient that takes into account mispredictions
        # of the severity. For example, predicting "major" when the error is
        # "critical" gives some partial credit. If None, give zero credit unless
        # the severity is correct.
        if severity_match:
            matched *= severity_match[self.severity][annotation.severity]
        else:
            matched *= (self.severity == annotation.severity)
        return matched

    @classmethod
    def from_fields(cls, fields):
        """Creates an Annotation object by loading from a list of string fields.

        Args:
            fields: a list of strings containing annotation information. They
                are:
                - segment_id
                - annotation_start
                - annotation_length
                - severity

            The first three fields may contain several integers separated by
            whitespace, in case there are multiple spans.
            Any additional fields (e.g. notes) are ignored.
            Example: "13 13 229 214 7 4 minor"
        """
        segments = list(map(int, fields[0].split(' ')))
        starts = list(map(int, fields[1].split(' ')))
        lengths = list(map(int, fields[2].split(' ')))
        assert len(segments) == len(starts) == len(lengths)
        severity = fields[3]
        spans = [Span(segment, start, start + length)
                 for segment, start, length in zip(segments, starts, lengths)]
        return cls(severity, spans)

    @classmethod
    def from_string(cls, line):
        """Creates an Annotation object by loading from a string.
        Args:
            line: tab-separated line containing the annotation information. The
                fields are:
                - document_id
                - segment_id
                - annotation_start
                - annotation_length
                - severity

            Segment id, annotation start and length may contain several
            integers separated by whitespace, in case there are multiple
            spans.
            Example: "A0034 13 13 229 214 7 4 minor"
        """
        # Ignore the document id field.
        fields = line.split('\t')
        assert len(fields) == 5
        return cls.from_fields(fields[1:])

    def to_string(self):
        """Return a string representation of this annotation.

        This is the representation expected in the output file, without notes."""
        segments = []
        starts = []
        lengths = []
        for span in self.spans:
            segments.append(str(span.segment))
            starts.append(str(span.start))
            lengths.append(str(span.end - span.start))

        segment_string = ' '.join(segments)
        start_string = ' '.join(starts)
        length_string = ' '.join(lengths)
        return '\t'.join([segment_string, start_string, length_string,
                          self.severity])


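# Example use of the Annotation API above (hypothetical values): building an
# Annotation from the tab-separated fields documented in from_string, and
# counting character overlap when the two severities differ.
#
#   >>> ann = Annotation.from_string('A0034\t13 13\t229 214\t7 4\tminor')
#   >>> len(ann)                     # 7 + 4 characters over the two spans
#   11
#   >>> other = Annotation('major', [Span(13, 230, 235)])
#   >>> ann.count_overlap(other)     # severities differ, no match matrix given
#   0
#   >>> ann.count_overlap(other, severity_match={'minor': {'major': 0.5}})
#   2.5

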
class Evaluator(object):
    def __init__(self):
        """A document-level QE evaluator."""
        # The severity match matrix will give some credit when the
        # severity is slightly mispredicted ("minor" <> "major" and
        # "major" <> "critical"), but not for extreme mispredictions
        # ("minor" <> "critical").
        self.severity_match = {'minor': {'minor': 1.0,
                                         'major': 0.5,
                                         'critical': 0.0},
                               'major': {'minor': 0.5,
                                         'major': 1.0,
                                         'critical': 0.5},
                               'critical': {'minor': 0.0,
                                            'major': 0.5,
                                            'critical': 1.0}}

    def run(self, system, reference, verbose=False):
        """Given system and reference documents, computes the macro-averaged F1
        across all documents.

        Args:
            system: a dictionary mapping names (doc id's) to lists of
                Annotations produced by a QE system.
            reference: a dictionary mapping names (doc id's) to lists of
                reference Annotations.
        Returns:
            The macro-averaged F1 score.
        """
        total_f1 = 0.
        for doc_id in system:
            # both dicts are defaultdicts, returning an empty list if there
            # are no annotations for that doc_id
            reference_annotations = reference[doc_id]
            system_annotations = system[doc_id]
            f1 = self._compare_document(system_annotations,
                                        reference_annotations)
            if verbose:
                print(doc_id)
                print(f1)
            total_f1 += f1
        total_f1 /= len(system)
        return total_f1

    def _compare_document(self, system, reference):
        """Compute the F1 score for a single document, given a system output
        and a reference. This is done by computing a precision according to the
        best possible matching of annotations from the system's perspective,
        and a recall according to the best possible matching of annotations
        from the reference perspective. Gives some partial credit to
        annotations that match with the wrong severity.
        Args:
            system: list of system Annotations for this document
            reference: list of reference Annotations for this document
        Returns:
            The F1 score of a single document.
        """
        all_matched = np.zeros((len(system), len(reference)))
        for i, system_annotation in enumerate(system):
            for j, reference_annotation in enumerate(reference):
                matched = reference_annotation.count_overlap(
                    system_annotation,
                    severity_match=self.severity_match)
                all_matched[i, j] = matched

        lengths_sys = np.array([len(annotation) for annotation in system])
        lengths_ref = np.array([len(annotation) for annotation in reference])

        if lengths_sys.sum() == 0:
            # no system annotations
            precision = 1.
        elif lengths_ref.sum() == 0:
            # there were no references
            precision = 0.
        else:
            # normalize by annotation length
            precision_by_annotation = all_matched.max(1) / lengths_sys
            precision = precision_by_annotation.mean()

        # same as above, for recall now
        if lengths_ref.sum() == 0:
            recall = 1.
        elif lengths_sys.sum() == 0:
            recall = 0.
        else:
            recall_by_annotation = all_matched.max(0) / lengths_ref
            recall = recall_by_annotation.mean()

        if not precision + recall:
            f1 = 0.
        else:
            f1 = 2*precision*recall / (precision + recall)
        assert(0. <= f1 <= 1.)

        return f1


def load_annotations(file_path):
    """Loads a file containing annotations for multiple documents.

    The file should contain lines with the following format:
    <DOCUMENT ID> <LINES> <SPAN START POSITIONS> <SPAN LENGTHS> <SEVERITY>

    Fields are separated by tabs; LINES, SPAN START POSITIONS and SPAN LENGTHS
    can have a list of values separated by white space.

    Args:
        file_path: path to the file.
    Returns:
        a dictionary mapping document id's to a list of annotations.
    """
    annotations = defaultdict(list)

    with open(file_path, 'r', encoding='utf8') as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                continue

            fields = line.split('\t')
            doc_id = fields[0]

            try:
                annotation = Annotation.from_fields(fields[1:])
            except OverlappingSpans:
                msg = 'Overlapping spans when reading line %d of file %s '
                msg %= (i, file_path)
                print(msg)
                continue

            annotations[doc_id].append(annotation)

    return annotations


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('system', help='System annotations')
    parser.add_argument('ref', help='Reference annotations')
    parser.add_argument('-v', help='Show score by document',
                        action='store_true', dest='verbose')
    args = parser.parse_args()

    system = load_annotations(args.system)
    reference = load_annotations(args.ref)
    evaluator = Evaluator()
    f1 = evaluator.run(system, reference, args.verbose)
    print('Final F1:', f1)


if __name__ == '__main__':
    main()
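As a sanity check of the matching logic in `_compare_document`, the evaluator can also be driven from Python rather than the command line. A minimal sketch, assuming the file above is importable as `eval_document_annotations` and using made-up document IDs and spans:

```python
from collections import defaultdict

from eval_document_annotations import Annotation, Evaluator, Span

# One reference error ("critical", characters 10-20 of segment 0 in doc D1)
# and one system prediction that partially overlaps it with severity "major".
reference = defaultdict(list)
system = defaultdict(list)
reference['D1'].append(Annotation('critical', [Span(0, 10, 20)]))
system['D1'].append(Annotation('major', [Span(0, 12, 20)]))

# 8 overlapping characters, scaled by 0.5 for the severity mismatch:
# precision = 4/8 = 0.5, recall = 4/10 = 0.4, so F1 is roughly 0.44.
print(Evaluator().run(system, reference, verbose=True))
```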

0 commit comments
