diff --git a/.travis.yml b/.travis.yml
index 75cc0ab..1833fc7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -5,6 +5,7 @@ python:
 install:
 - pip install ".[testing]"
 - pip install ".[nlp]"
+- pip install ".[topic_modeling]"
 - python -m nltk.downloader punkt stopwords wordnet
 script: pytest
 deploy:
diff --git a/quantgov/estimator/candidate_sets.py b/quantgov/estimator/candidate_sets.py
index 97978fa..38e6813 100644
--- a/quantgov/estimator/candidate_sets.py
+++ b/quantgov/estimator/candidate_sets.py
@@ -17,9 +17,20 @@
 import sklearn.multioutput
 import sklearn.pipeline
 import sklearn.feature_extraction
+from . import structures
+
+# Topic modeling is optional: it needs both gensim and spacy. If either is
+# missing, both names are set to None so the candidate set below is skipped.
+try:
+    import gensim
+    import spacy
+except ImportError:
+    gensim = None
+    spacy = None
 
 import quantgov.estimator
 
+
 classification = [
     quantgov.estimator.CandidateModel(
         name="Random Forests",
@@ -69,3 +80,19 @@
         }
     ),
 ]
+
+if gensim and spacy:
+    # NOTE(review): GensimLda accepts eta, passes and num_topics as fit()
+    # keyword arguments, not constructor parameters -- confirm the search
+    # that consumes this grid supplies them that way.
+    topic_modeling = [
+        quantgov.estimator.CandidateModel(
+            name="LDA",
+            model=structures.GensimLda(),
+            parameters={
+                'eta': [0.1, 0.05, 0.01],
+                'passes': [1, 2, 3],
+                'num_topics': [10, 50, 100]
+            }
+        ),
+    ]
diff --git a/quantgov/estimator/structures.py b/quantgov/estimator/structures.py
index 8ef59ea..d88d570 100644
--- a/quantgov/estimator/structures.py
+++ b/quantgov/estimator/structures.py
@@ -5,6 +5,29 @@
 """
 import collections
 import joblib as jl
+from sklearn.base import BaseEstimator, TransformerMixin
+from six import iteritems
+from decorator import decorator
+import re
+
+try:
+    from gensim.corpora import Dictionary
+    from gensim import sklearn_api
+    import gensim
+except ImportError:
+    gensim = None
+
+
+from sklearn.feature_extraction import stop_words
+STOP_WORDS = stop_words.ENGLISH_STOP_WORDS
+
+
+@decorator
+def check_gensim(func, *args, **kwargs):
+    """Run ``func``, or raise RuntimeError when gensim is not installed."""
+    if gensim is None:
+        raise RuntimeError('Must install gensim to use {}'.format(func))
+    return func(*args, **kwargs)
 
 
 class _PersistanceMixin(object):
@@ -85,3 +108,83 @@ class CandidateModel(
         parameter values to test as values
     """
     pass
+
+
+class GensimLda(BaseEstimator, TransformerMixin):
+    """
+    LDA topic model built on gensim's scikit-learn ``LdaTransformer``.
+
+    ``fit`` and ``transform`` take a driver; documents are read from
+    ``driver.stream()`` and tokens are extracted from each ``doc.text``.
+    """
+
+    @check_gensim
+    def __init__(self, word_pattern=r'\b[A-Za-z]{2,}\b', stop_words='en'):
+        """
+        :param word_pattern: regular expression matching a single token
+        :param stop_words: 'en' for scikit-learn's English stop word list,
+            a false value for no stop word removal, or a collection of words
+        """
+        if stop_words == 'en':
+            self.stop_words = STOP_WORDS
+        elif not stop_words:
+            self.stop_words = None
+        else:
+            self.stop_words = stop_words
+
+        self.word_pattern = re.compile(word_pattern)
+
+    def _tokenize(self, text):
+        """Return the lowercased ``word_pattern`` matches found in ``text``."""
+        return [match.group(0).lower()
+                for match in self.word_pattern.finditer(text)]
+
+    def transform(self, driver):
+        """Return ``self.model.transform`` applied to the driver's corpus."""
+        self.test_corpus = self.create_corpus(driver)
+        return self.model.transform(self.test_corpus)
+
+    def create_corpus(self, driver):
+        """Return one bag-of-words per document streamed by ``driver``."""
+        return [self.dictionary.doc2bow(self._tokenize(doc.text))
+                for doc in driver.stream()]
+
+    def show_topics(self):
+        """Return ``show_topics()`` of the underlying gensim model."""
+        return self.model.gensim_model.show_topics()
+
+    def fit(self, driver, alpha=None, eta=None, num_topics=1,
+            passes=1, min_wf=1):
+        """
+        Build the dictionary from the driver's documents and fit the model.
+
+        :param driver: object whose ``stream()`` yields docs with ``text``
+        :param alpha: passed through to ``LdaTransformer``
+        :param eta: passed through to ``LdaTransformer``
+        :param num_topics: passed through to ``LdaTransformer``
+        :param passes: passed through to ``LdaTransformer``
+        :param min_wf: tokens whose count in ``dictionary.dfs`` is at or
+            below this value are dropped from the dictionary
+
+        :returns: self
+        """
+        self.dictionary = Dictionary(
+            [self._tokenize(doc.text) for doc in driver.stream()])
+        token2id = self.dictionary.token2id
+        # stop_words is None when stop word removal is disabled
+        stop_ids = [token2id[stopword]
+                    for stopword in (self.stop_words or ())
+                    if stopword in token2id]
+        rare_ids = [tokenid for tokenid, docfreq
+                    in iteritems(self.dictionary.dfs) if docfreq <= min_wf]
+        self.dictionary.filter_tokens(stop_ids + rare_ids)
+        self.corpus = self.create_corpus(driver)
+        self.model = sklearn_api.ldamodel.LdaTransformer(
+            alpha=alpha,
+            eta=eta,
+            num_topics=num_topics,
+            passes=passes,
+            id2word=self.dictionary
+        )
+        self.model.fit(self.corpus)
+        return self
diff --git a/setup.py b/setup.py
index 3d424b1..910fde8 100644
--- a/setup.py
+++ b/setup.py
@@ -65,6 +65,10 @@ def find_version(*file_paths):
         'nlp': [
             'textblob',
             'nltk',
+        ],
+        'topic_modeling': [
+            'gensim',
+            'spacy'
         ]
     },
     entry_points={
diff --git a/tests/test_estimators.py b/tests/test_estimators.py
new file mode 100644
index 0000000..2b52c5d
--- /dev/null
+++ b/tests/test_estimators.py
@@ -0,0 +1,28 @@
+# import pytest
+import subprocess
+import quantgov.estimator
+import quantgov
+
+from pathlib import Path
+
+PSEUDO_CORPUS_PATH = Path(__file__).resolve().parent.joinpath('pseudo_corpus')
+driver = quantgov.load_driver(PSEUDO_CORPUS_PATH)
+
+
+def test_topic_model():
+    sample = quantgov.estimator.structures.GensimLda()
+    sample.fit(driver, num_topics=2)
+    sample.transform(driver)
+
+
+def test_topic_model_without_stop_words():
+    sample = quantgov.estimator.structures.GensimLda(stop_words=None)
+    sample.fit(driver, num_topics=2)
+    sample.transform(driver)
+
+
+def check_output(cmd):
+    return (
+        subprocess.check_output(cmd, universal_newlines=True)
+        .replace('\n\n', '\n')
+    )