diff --git a/beangulp/extract.py b/beangulp/extract.py
index 58224bc..1d5c211 100644
--- a/beangulp/extract.py
+++ b/beangulp/extract.py
@@ -1,6 +1,7 @@
 __copyright__ = "Copyright (C) 2016-2017 Martin Blais"
 __license__ = "GNU GPLv2"
 
+from collections import defaultdict
 import io
 import bisect
 import datetime
@@ -8,7 +9,7 @@
 import textwrap
 import warnings
 
-from typing import Callable
+from typing import Callable, Dict, Iterator
 from typing import TYPE_CHECKING
 from typing import Tuple
 from typing import List
@@ -199,18 +200,37 @@ def mark_duplicate_entries(
     # of each newly extracted entry requires the existing entries
     # to be sorted by date.
     existing.sort(key=operator.attrgetter("date"))
-    dates = [entry.date for entry in existing]
+    # Bucket the existing entries by directive type so that an entry is
+    # only ever compared against entries of its own type. Every bucket
+    # keeps the date ordering established by the sort above.
+    existing_by_type: Dict[type, List[data.Directive]] = defaultdict(list)
+    for entry in existing:
+        existing_by_type[type(entry)].append(entry)
+    # Keep a parallel list of dates per bucket to bisect on: the key=
+    # argument of the bisect functions requires Python 3.10 or later.
+    dates_by_type: Dict[type, List[datetime.date]] = {
+        entry_type: [entry.date for entry in bucket]
+        for entry_type, bucket in existing_by_type.items()
+    }
 
-    def entries_date_window_iterator(date):
+    def entries_date_window_iterator(
+        entry_type: type, date: datetime.date
+    ) -> Iterator[data.Directive]:
+        # Look the type up with .get(): indexing the defaultdict would
+        # insert an empty bucket for every type absent from existing.
+        bucket = existing_by_type.get(entry_type, [])
+        dates = dates_by_type.get(entry_type, [])
         lo = bisect.bisect_left(dates, date - window)
         hi = bisect.bisect_right(dates, date + window)
         for i in range(lo, hi):
-            yield existing[i]
+            yield bucket[i]
 
     for entry in entries:
-        for target in entries_date_window_iterator(entry.date):
+        # Candidates already share the type of entry, so no explicit
+        # type comparison is needed before calling compare().
+        for target in entries_date_window_iterator(type(entry), entry.date):
             if compare(entry, target):
                 entry.meta[DUPLICATE] = target
 
 
 def print_extracted_entries(extracted: List[ExtractedEntry], output: io.TextIOBase) -> None:
diff --git a/beangulp/extract_test.py b/beangulp/extract_test.py
index ceaa443..39f5056 100644
--- a/beangulp/extract_test.py
+++ b/beangulp/extract_test.py
@@ -63,6 +63,7 @@ def test_mark_duplicate_entries(self):
             1970-01-02 * "Test"
               Assets:Tests  20.00 USD
 
+            1970-01-03 balance Assets:Tests  20.00 USD
             """)
         )
         compare = similar.heuristic_comparator()
@@ -70,6 +71,8 @@ def test_mark_duplicate_entries(self):
         self.assertTrue(entries[0].meta[extract.DUPLICATE])
         self.assertNotIn(extract.DUPLICATE, entries[1].meta)
 
+        self.assertNotIn(extract.DUPLICATE, entries[2].meta)
+
 
 class TestPrint(unittest.TestCase):
     def test_print_extracted_entries(self):