Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions smart_importer/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,15 @@ class StringVectorizer(CountVectorizer): # type: ignore[misc]
"""Subclass of CountVectorizer that handles empty data."""

def __init__(
self, tokenizer: Callable[[str], list[str]] | None = None
self,
tokenizer: Callable[[str], list[str]] | None = None,
token_pattern: None | str = r"(?u)\b\w\w+\b",
) -> None:
super().__init__(ngram_range=(1, 3), tokenizer=tokenizer)
super().__init__(
ngram_range=(1, 3),
tokenizer=tokenizer,
token_pattern=token_pattern,
)

def fit_transform(self, raw_documents: list[str], y: None = None) -> Any:
try:
Expand All @@ -86,7 +92,9 @@ def transform(self, raw_documents: list[str], _y: None = None) -> Any:


def get_pipeline(
attribute: str, tokenizer: Callable[[str], list[str]] | None
attribute: str,
tokenizer: Callable[[str], list[str]] | None,
token_pattern: None | str = r"(?u)\b\w\w+\b",
) -> Any:
"""Make a pipeline for a given entry attribute."""

Expand All @@ -95,5 +103,6 @@ def get_pipeline(

# Treat all other attributes as strings.
return make_pipeline(
AttrGetter(attribute, default=""), StringVectorizer(tokenizer)
AttrGetter(attribute, default=""),
StringVectorizer(tokenizer, token_pattern=token_pattern),
)
15 changes: 13 additions & 2 deletions smart_importer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class EntryPredictor:
string_tokenizer: Tokenizer that can let smart_importer support more
languages. This parameter should be a callable that takes a
string and returns a list of tokens.
string_token_pattern: Regex for tokenizing text when no custom
tokenizer is provided. Set to None to disable.
denylist_accounts: Transactions with any of these accounts will be
removed from the training data.
"""
Expand All @@ -52,11 +54,12 @@ class EntryPredictor:
weights: dict[str, float] = {}
attribute: str | None = None

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments,too-many-arguments
self,
predict: bool = True,
overwrite: bool = False,
string_tokenizer: Callable[[str], list[str]] | None = None,
string_token_pattern: str | None = r"(?u)\b\w\w+\b",
denylist_accounts: list[str] | None = None,
) -> None:
super().__init__()
Expand All @@ -69,6 +72,7 @@ def __init__(
self.predict = predict
self.overwrite = overwrite
self.string_tokenizer = string_tokenizer
self.string_token_pattern = string_token_pattern

def wrap(self, importer: Importer) -> ImporterWrapper:
"""Wrap an existing importer with smart importer logic.
Expand Down Expand Up @@ -191,7 +195,14 @@ def define_pipeline(self) -> None:
"""Defines the machine learning pipeline based on given weights."""

transformers = [
(attribute, get_pipeline(attribute, self.string_tokenizer))
(
attribute,
get_pipeline(
attribute,
tokenizer=self.string_tokenizer,
token_pattern=self.string_token_pattern,
),
)
for attribute in self.weights
]

Expand Down