|
19 | 19 | from encexp.tests.test_utils import samples |
20 | 20 | from encexp.utils import compute_b4msa_vocabulary, compute_seqtm_vocabulary |
21 | 21 | from encexp.build_encexp import build_encexp |
22 | | -from encexp.text_repr import SeqTM, EncExp |
| 22 | +from encexp.text_repr import SeqTM, EncExp, TM, EncExpT |
23 | 23 | from sklearn.base import clone |
24 | 24 |
|
25 | 25 |
|
def test_tm():
    """Indexing a TM built with the 'mix' vocabulary yields 13 elements"""
    model = TM(voc_source='mix')
    output = model['buenos dias mxeico']
    assert len(output) == 13
| 31 | + |
26 | 32 | def test_seqtm(): |
27 | 33 | """Test SeqTM""" |
28 | 34 |
|
@@ -377,9 +383,73 @@ def test_EncExp_build_tailored(): |
377 | 383 | enc = EncExp(lang='es', |
378 | 384 | tailored=True) |
379 | 385 | w = enc.weights |
380 | | - enc.build_tailored(mx + ar) |
| 386 | + enc.build_tailored(mx + ar, load=True) |
381 | 387 | assert isfile(enc.tailored) |
| 388 | + assert hasattr(enc, '_tailored_built') |
382 | 389 | enc = EncExp(lang='es', |
383 | 390 | tailored=enc.tailored).fit(mx + ar, y) |
384 | 391 | assert np.fabs(w - enc.weights).sum() != 0 |
| 392 | + enc2 = clone(enc) |
| 393 | + assert hasattr(enc2, '_tailored_built') |
| 394 | + assert hasattr(enc2, '_estimator') |
385 | 395 | # os.unlink(enc.tailored) |
| 396 | + |
def test_pipeline_tm():
    """Grid search over SeqTM/TM followed by LinearSVC on the mx/ar samples"""
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import StratifiedShuffleSplit
    from sklearn.pipeline import Pipeline
    from sklearn.svm import LinearSVC

    # Fetch each sample file before reading it.
    samples()
    mx = list(tweet_iterator('es-mx-sample.json'))
    samples(filename='es-ar-sample.json.zip')
    ar = list(tweet_iterator('es-ar-sample.json'))
    texts = mx + ar
    labels = ['mx'] * len(mx) + ['ar'] * len(ar)

    # The 'bow' step is a placeholder; the grid substitutes each
    # candidate text model into it.
    steps = [('bow', 'passthrough'),
             ('cl', LinearSVC(class_weight='balanced'))]
    pipeline = Pipeline(steps)
    search_space = {
        'cl__C': [0.01, 0.1, 1, 10],
        'bow': [SeqTM(lang='es', voc_source='mix'),
                TM(lang='es', voc_source='mix')],
    }
    # A single stratified 70/30 split with a fixed seed.
    splitter = StratifiedShuffleSplit(n_splits=1,
                                      test_size=0.3,
                                      random_state=0)

    search = GridSearchCV(pipeline,
                          param_grid=search_space,
                          cv=splitter,
                          n_jobs=-1,
                          scoring='f1_macro')
    search.fit(texts, labels)
    assert search.best_score_ > 0.7
| 426 | + |
| 427 | + |
def test_pipeline_encexp():
    """Grid search over EncExpT's voc_source followed by LinearSVC"""
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import StratifiedShuffleSplit
    from sklearn.pipeline import Pipeline
    from sklearn.svm import LinearSVC

    # Fetch each sample file before reading it.
    samples()
    mx = list(tweet_iterator('es-mx-sample.json'))
    samples(filename='es-ar-sample.json.zip')
    ar = list(tweet_iterator('es-ar-sample.json'))
    texts = mx + ar
    labels = ['mx'] * len(mx) + ['ar'] * len(ar)

    steps = [('encexp', EncExpT(lang='es')),
             ('cl', LinearSVC(class_weight='balanced'))]
    pipeline = Pipeline(steps)
    # Tune the classifier's C together with the encoder's vocabulary source.
    search_space = {
        'cl__C': [0.01, 0.1, 1, 10],
        'encexp__voc_source': ['mix', 'noGeo'],
    }
    # A single stratified 70/30 split with a fixed seed.
    splitter = StratifiedShuffleSplit(n_splits=1,
                                      test_size=0.3,
                                      random_state=0)

    search = GridSearchCV(pipeline,
                          param_grid=search_space,
                          cv=splitter,
                          n_jobs=1,
                          scoring='f1_macro')
    search.fit(texts, labels)
    assert search.best_score_ > 0.7
0 commit comments