-
Notifications
You must be signed in to change notification settings - Fork 13
Topic modeling #37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Topic modeling #37
Changes from 12 commits
b07786a
257131d
265cc83
bb5f036
2a15b32
c3faebf
c6b764d
a17ed34
d5592f0
58513f7
aac7d98
ef9853f
c97c6df
aaffd94
58db43a
30af4c2
31a35ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,34 @@ | |
| """ | ||
| import collections | ||
| import joblib as jl | ||
| from sklearn.base import BaseEstimator, TransformerMixin | ||
| from six import iteritems | ||
| from decorator import decorator | ||
| import re | ||
|
|
||
| try: | ||
| from spacy.lang.en.stop_words import STOP_WORDS | ||
| from gensim.corpora import Dictionary | ||
| from gensim import sklearn_api | ||
| import gensim | ||
| spacy = True | ||
| except ImportError: | ||
| spacy = None | ||
| gensim = None | ||
|
|
||
|
|
||
| @decorator | ||
| def check_spacy(func, *args, **kwargs): | ||
| if spacy is None: | ||
| raise RuntimeError('Must install spacy to use {}'.format(func)) | ||
| return func(*args, **kwargs) | ||
|
|
||
|
|
||
| @decorator | ||
| def check_gensim(func, *args, **kwargs): | ||
| if gensim is None: | ||
| raise RuntimeError('Must install gensim to use {}'.format(func)) | ||
| return func(*args, **kwargs) | ||
|
|
||
|
|
||
| class _PersistanceMixin(object): | ||
|
|
@@ -85,3 +113,41 @@ class CandidateModel( | |
| parameter values to test as values | ||
| """ | ||
| pass | ||
|
|
||
|
|
||
| class QGLdaModel(BaseEstimator, TransformerMixin): | ||
|
||
| @check_gensim | ||
| @check_spacy | ||
| def __init__(self, word_regex=r'\b[A-z]{2,}\b', stop_words=STOP_WORDS): | ||
|
||
| self.stop_words = stop_words | ||
| self.word_regex = re.compile(word_regex) | ||
|
|
||
| def transform(self, driver): | ||
| self.test_corpus = self.create_corpus(driver) | ||
| return self.model.transform(self.test_corpus) | ||
|
|
||
| def create_corpus(self, driver): | ||
| return [self.dictionary.doc2bow([i.group(0).lower() | ||
| for i in self.word_regex.finditer(doc.text)]) | ||
| for doc in driver.stream()] | ||
|
|
||
| def fit(self, driver, alpha=None, eta=None, num_topics=1, passes=1): | ||
| self.dictionary = Dictionary([[i.group(0).lower() | ||
| for i in self.word_regex | ||
| .finditer(doc.text)] | ||
| for doc in driver.stream()]) | ||
| stop_ids = [self.dictionary.token2id[stopword] for stopword | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't it be better to only pass the dictionary words that aren't in stop_words? |
||
| in self.stop_words if stopword in self.dictionary.token2id] | ||
| once_ids = [tokenid for tokenid, docfreq in | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are we doing this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Filtering out words that only occur once was recommended in the Gensim documentation - beyond that, I don't know if it actually improves the performance of the model. |
||
| iteritems(self.dictionary.dfs) if docfreq == 1] | ||
| self.dictionary.filter_tokens(stop_ids + once_ids) | ||
| self.corpus = self.create_corpus(driver) | ||
| self.model = sklearn_api.ldamodel.LdaTransformer( | ||
| alpha=alpha, | ||
| eta=eta, | ||
| num_topics=num_topics, | ||
| passes=passes, | ||
| id2word=self.dictionary | ||
| ) | ||
| self.model.fit(self.corpus) | ||
| return self | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,10 @@ def find_version(*file_paths): | |
| 'nlp': [ | ||
| 'textblob', | ||
| 'nltk', | ||
| ], | ||
| 'topic_modeling': [ | ||
| 'gensim', | ||
| 'spacy' | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we not need a spacy corpus as well? |
||
| ] | ||
| }, | ||
| entry_points={ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| # import pytest | ||
| import subprocess | ||
| import quantgov.estimator | ||
| import quantgov | ||
|
|
||
| from pathlib import Path | ||
|
|
||
| PSEUDO_CORPUS_PATH = Path(__file__).resolve().parent.joinpath('pseudo_corpus') | ||
| driver = quantgov.load_driver(PSEUDO_CORPUS_PATH) | ||
|
|
||
|
|
||
| def test_topic_model(): | ||
| sample = quantgov.estimator.structures.QGLdaModel() | ||
| sample.fit(driver, num_topics=2) | ||
| sample.transform(driver) | ||
|
|
||
|
|
||
| def check_output(cmd): | ||
| return ( | ||
| subprocess.check_output(cmd, universal_newlines=True) | ||
| .replace('\n\n', '\n') | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we're literally only using spacy here for the stopwords, can't we somehow find the sklearn stopwords used in the
CountVectorizer? That's got to be importable from somewhere.