Skip to content

Commit b7449bd

Browse files
authored
Merge pull request #21 from INGEOTEC/develop
EncExp
2 parents 3bf368b + cda9ef9 commit b7449bd

4 files changed

Lines changed: 222 additions & 112 deletions

File tree

encexp/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515
import sys
1616

1717
if not '-m' in sys.argv:
18-
from encexp.text_repr import EncExp
18+
from encexp.text_repr import EncExp, EncExpT, SeqTM, TM
1919

20-
__version__ = "0.0.18"
20+
__version__ = "0.0.19"

encexp/download.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
def download_seqtm(lang, voc_size_exponent: int=13,
2424
output=None, voc_source='noGeo',
25+
prefix='seqtm',
2526
prefix_suffix: bool=True):
2627
"""Download SeqTM vocabulary"""
2728
if not isdir(MODELS):
@@ -32,7 +33,7 @@ def download_seqtm(lang, voc_size_exponent: int=13,
3233
for flag in [voc_source]:
3334
if flag is not None:
3435
flags.append(flag)
35-
voc_fname = f'seqtm_{"_".join(flags)}_{lang}_{voc_size_exponent}.json.gz'
36+
voc_fname = f'{prefix}_{"_".join(flags)}_{lang}_{voc_size_exponent}.json.gz'
3637
if output is None:
3738
output = join(MODELS, voc_fname)
3839
if isfile(output):

encexp/tests/test_text_repr.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,16 @@
1919
from encexp.tests.test_utils import samples
2020
from encexp.utils import compute_b4msa_vocabulary, compute_seqtm_vocabulary
2121
from encexp.build_encexp import build_encexp
22-
from encexp.text_repr import SeqTM, EncExp
22+
from encexp.text_repr import SeqTM, EncExp, TM, EncExpT
2323
from sklearn.base import clone
2424

2525

26+
def test_tm():
    """Indexing a TM built with voc_source='mix' yields 13 elements
    for a short (misspelled) Spanish phrase"""
    model = TM(voc_source='mix')
    encoded = model['buenos dias mxeico']
    assert len(encoded) == 13
31+
2632
def test_seqtm():
2733
"""Test SeqTM"""
2834

@@ -377,9 +383,73 @@ def test_EncExp_build_tailored():
377383
enc = EncExp(lang='es',
378384
tailored=True)
379385
w = enc.weights
380-
enc.build_tailored(mx + ar)
386+
enc.build_tailored(mx + ar, load=True)
381387
assert isfile(enc.tailored)
388+
assert hasattr(enc, '_tailored_built')
382389
enc = EncExp(lang='es',
383390
tailored=enc.tailored).fit(mx + ar, y)
384391
assert np.fabs(w - enc.weights).sum() != 0
392+
enc2 = clone(enc)
393+
assert hasattr(enc2, '_tailored_built')
394+
assert hasattr(enc2, '_estimator')
385395
# os.unlink(enc.tailored)
396+
397+
def test_pipeline_tm():
    """Grid-search over SeqTM and TM as the first step of a
    LinearSVC pipeline on the es-mx / es-ar samples"""
    # Fetch each sample file before reading it.
    samples()
    mx_tweets = list(tweet_iterator('es-mx-sample.json'))
    samples(filename='es-ar-sample.json.zip')
    ar_tweets = list(tweet_iterator('es-ar-sample.json'))
    corpus = mx_tweets + ar_tweets
    labels = ['mx'] * len(mx_tweets) + ['ar'] * len(ar_tweets)

    from sklearn.pipeline import Pipeline
    from sklearn.svm import LinearSVC
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import StratifiedShuffleSplit

    # The 'bow' step is a placeholder; the grid swaps in each text model.
    classifier = LinearSVC(class_weight='balanced')
    pipeline = Pipeline([('bow', 'passthrough'),
                         ('cl', classifier)])
    text_models = [SeqTM(lang='es', voc_source='mix'),
                   TM(lang='es', voc_source='mix')]
    search_space = {'cl__C': [0.01, 0.1, 1, 10],
                    'bow': text_models}
    splitter = StratifiedShuffleSplit(random_state=0,
                                      n_splits=1,
                                      test_size=0.3)

    search = GridSearchCV(pipeline,
                          param_grid=search_space,
                          cv=splitter,
                          n_jobs=-1,
                          scoring='f1_macro')
    search.fit(corpus, labels)
    assert search.best_score_ > 0.7
426+
427+
428+
def test_pipeline_encexp():
    """Grid-search over EncExpT's voc_source inside a LinearSVC
    pipeline on the es-mx / es-ar samples"""
    from sklearn.pipeline import Pipeline
    from sklearn.svm import LinearSVC
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import StratifiedShuffleSplit

    # Fetch each sample file before reading it.
    samples()
    mx_tweets = list(tweet_iterator('es-mx-sample.json'))
    samples(filename='es-ar-sample.json.zip')
    ar_tweets = list(tweet_iterator('es-ar-sample.json'))
    corpus = mx_tweets + ar_tweets
    labels = ['mx'] * len(mx_tweets) + ['ar'] * len(ar_tweets)

    steps = [('encexp', EncExpT(lang='es')),
             ('cl', LinearSVC(class_weight='balanced'))]
    pipeline = Pipeline(steps)
    # Both the classifier's C and the representation's vocabulary
    # source are searched.
    search_space = {'cl__C': [0.01, 0.1, 1, 10],
                    'encexp__voc_source': ['mix', 'noGeo']}
    splitter = StratifiedShuffleSplit(random_state=0,
                                      n_splits=1,
                                      test_size=0.3)

    search = GridSearchCV(pipeline,
                          param_grid=search_space,
                          cv=splitter,
                          n_jobs=1,
                          scoring='f1_macro')
    search.fit(corpus, labels)
    assert search.best_score_ > 0.7

0 commit comments

Comments
 (0)