diff --git a/beanprice/price.py b/beanprice/price.py index 38fb04a..898555a 100644 --- a/beanprice/price.py +++ b/beanprice/price.py @@ -6,6 +6,7 @@ import argparse import collections import datetime +from decimal import Decimal import functools from os import path import shelve @@ -38,10 +39,13 @@ # module: A Python module, the module to be called to create a price source. # symbol: A ticker symbol in the universe of the source. # invert: A boolean, true if we need to invert the currency. +# multiplier: A Decimal instance to be multiplied on prices from the source. +# This is useful with sources returning 1.23 USD as 123. class PriceSource(NamedTuple): module: Any symbol: str invert: bool + multiplier: Decimal # A dated price source description. @@ -151,7 +155,7 @@ def parse_single_source(source: str) -> PriceSource: Source specifications follow the syntax: - /[^] + [*]/[^] The is resolved against the Python path, but first looked up under the package where the default price extractors lie. @@ -163,12 +167,15 @@ def parse_single_source(source: str) -> PriceSource: Raises: ValueError: If invalid. """ - match = re.match(r'([a-zA-Z]+[a-zA-Z0-9\._]+)/(\^?)([a-zA-Z0-9:=_\-\.\(\)]+)$', source) + match = re.match(r'(?:([0-9]+(?:\.[0-9]+)?)\*)?' + r'([a-zA-Z]+[a-zA-Z0-9\._]+)/(\^?)([a-zA-Z0-9:=_\-\.\(\)]+)$', source) if not match: raise ValueError('Invalid source name: "{}"'.format(source)) - short_module_name, invert, symbol = match.groups() + multiplier_str, short_module_name, invert, symbol = match.groups() module = import_source(short_module_name) - return PriceSource(module, symbol, bool(invert)) + multiplier = Decimal(multiplier_str) if multiplier_str else Decimal(1) + print(f'{multiplier=} {multiplier_str=}') + return PriceSource(module, symbol, bool(invert), multiplier) def import_source(module_name: str): @@ -323,7 +330,7 @@ def get_price_jobs_at_date(entries: data.Entries, # If there are no sources, create a default one. if not psources: - psources = [PriceSource(default_source, base, False)] + psources = [PriceSource(default_source, base, False, Decimal(1))] jobs.append(DatedPrice(base, quote, date, psources)) return sorted(jobs) @@ -599,7 +606,7 @@ def fetch_price(dprice: DatedPrice, swap_inverted: bool = False) -> Optional[dat base = dprice.base quote = dprice.quote or srcprice.quote_currency - price = srcprice.price + price = srcprice.price * psource.multiplier # Invert the rate if requested. if psource.invert: diff --git a/beanprice/price_test.py b/beanprice/price_test.py index 944d3cf..63ac6ce 100644 --- a/beanprice/price_test.py +++ b/beanprice/price_test.py @@ -26,6 +26,7 @@ PS = price.PriceSource +ONE = Decimal(1) def run_with_args(function, args, runner_file=None): @@ -225,7 +226,7 @@ def test_expressions(self): self.assertEqual( [price.DatedPrice( 'AAPL', 'USD', None, - [price.PriceSource(yahoo, 'AAPL', False)])], jobs) + [PS(yahoo, 'AAPL', False, ONE)])], jobs) class TestClobber(cmptest.TestCase): @@ -292,7 +293,7 @@ def test_fetch_price__naive_time_no_timeozne(self, fetch_cached): dprice = price.DatedPrice('JPY', 'USD', datetime.date(2015, 11, 22), None) with self.assertRaises(ValueError): price.fetch_price(dprice._replace(sources=[ - price.PriceSource(yahoo, 'USDJPY', False)]), False) + PS(yahoo, 'USDJPY', False, ONE)]), False) class TestInverted(unittest.TestCase): @@ -309,23 +310,41 @@ def setUp(self): def test_fetch_price__normal(self): entry = price.fetch_price(self.dprice._replace(sources=[ - price.PriceSource(yahoo, 'USDJPY', False)]), False) + PS(yahoo, 'USDJPY', False, ONE)]), False) self.assertEqual(('JPY', 'USD'), (entry.currency, entry.amount.currency)) self.assertEqual(Decimal('125.00'), entry.amount.number) def test_fetch_price__inverted(self): entry = price.fetch_price(self.dprice._replace(sources=[ - price.PriceSource(yahoo, 'USDJPY', True)]), False) + PS(yahoo, 'USDJPY', True, ONE)]), False) self.assertEqual(('JPY', 'USD'), (entry.currency, entry.amount.currency)) self.assertEqual(Decimal('0.008'), entry.amount.number) def test_fetch_price__swapped(self): entry = price.fetch_price(self.dprice._replace(sources=[ - price.PriceSource(yahoo, 'USDJPY', True)]), True) + PS(yahoo, 'USDJPY', True, ONE)]), True) self.assertEqual(('USD', 'JPY'), (entry.currency, entry.amount.currency)) self.assertEqual(Decimal('125.00'), entry.amount.number) +class TestMultiplier(unittest.TestCase): + + def test_multiplier(self): + fetch_cached = mock.patch('beanprice.price.fetch_cached_price').start() + self.addCleanup(mock.patch.stopall) + fetch_cached.return_value = SourcePrice( + Decimal('16824.00'), datetime.datetime(2023, 1, 1, 16, 0, 0, + tzinfo=tz.tzlocal()), + None) + dprice = price.DatedPrice( + 'GBP', 'XSDR', datetime.date(2023, 1, 1), [ + PS(yahoo, 'XSDR.L', False, Decimal('0.01')), + ]) + entry = price.fetch_price(dprice) + self.assertEqual(('GBP', 'XSDR'), (entry.currency, entry.amount.currency)) + self.assertEqual('168.2400', str(entry.amount.number)) + + class TestImportSource(unittest.TestCase): def test_import_source_valid(self): @@ -352,22 +371,25 @@ def test_source_invalid(self): with self.assertRaises(ImportError): price.parse_single_source('invalid.module.name/NASDAQ:AAPL') - def test_source_valid(self): - psource = price.parse_single_source('yahoo/CNYUSD=X') - self.assertEqual(PS(yahoo, 'CNYUSD=X', False), psource) - # Make sure that an invalid name at the tail doesn't succeed. with self.assertRaises(ValueError): psource = price.parse_single_source('yahoo/CNYUSD&X') + def test_source_valid(self): + psource = price.parse_single_source('yahoo/CNYUSD=X') + self.assertEqual(PS(yahoo, 'CNYUSD=X', False, ONE), psource) + psource = price.parse_single_source('beanprice.sources.yahoo/AAPL') - self.assertEqual(PS(yahoo, 'AAPL', False), psource) + self.assertEqual(PS(yahoo, 'AAPL', False, ONE), psource) + + psource = price.parse_single_source('0.01*yahoo/XSDR.L') + self.assertEqual(PS(yahoo, 'XSDR.L', False, Decimal('0.01')), psource) class TestParseSourceMap(unittest.TestCase): def _clean_source_map(self, smap): - return {currency: [PS(s[0].__name__, s[1], s[2]) for s in sources] + return {currency: [PS(s[0].__name__, s[1], s[2], s[3]) for s in sources] for currency, sources in smap.items()} def test_source_map_invalid(self): @@ -378,39 +400,50 @@ def test_source_map_invalid(self): def test_source_map_onecur_single(self): smap = price.parse_source_map('USD:yahoo/AAPL') self.assertEqual( - {'USD': [PS('beanprice.sources.yahoo', 'AAPL', False)]}, + {'USD': [PS('beanprice.sources.yahoo', 'AAPL', False, ONE)]}, self._clean_source_map(smap)) def test_source_map_onecur_multiple(self): smap = price.parse_source_map('USD:oanda/USDCAD,yahoo/CAD=X') self.assertEqual( - {'USD': [PS('beanprice.sources.oanda', 'USDCAD', False), - PS('beanprice.sources.yahoo', 'CAD=X', False)]}, + {'USD': [PS('beanprice.sources.oanda', 'USDCAD', False, ONE), + PS('beanprice.sources.yahoo', 'CAD=X', False, ONE)]}, self._clean_source_map(smap)) def test_source_map_manycur_single(self): smap = price.parse_source_map('USD:yahoo/USDCAD ' 'CAD:yahoo/CAD=X') self.assertEqual( - {'USD': [PS('beanprice.sources.yahoo', 'USDCAD', False)], - 'CAD': [PS('beanprice.sources.yahoo', 'CAD=X', False)]}, + {'USD': [PS('beanprice.sources.yahoo', 'USDCAD', False, ONE)], + 'CAD': [PS('beanprice.sources.yahoo', 'CAD=X', False, ONE)]}, self._clean_source_map(smap)) def test_source_map_manycur_multiple(self): smap = price.parse_source_map('USD:yahoo/GBPUSD,oanda/GBPUSD ' 'CAD:yahoo/GBPCAD') self.assertEqual( - {'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', False), - PS('beanprice.sources.oanda', 'GBPUSD', False)], - 'CAD': [PS('beanprice.sources.yahoo', 'GBPCAD', False)]}, + {'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', False, ONE), + PS('beanprice.sources.oanda', 'GBPUSD', False, ONE)], + 'CAD': [PS('beanprice.sources.yahoo', 'GBPCAD', False, ONE)]}, self._clean_source_map(smap)) def test_source_map_inverse(self): smap = price.parse_source_map('USD:yahoo/^GBPUSD') self.assertEqual( - {'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', True)]}, + {'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', True, ONE)]}, self._clean_source_map(smap)) + def test_source_map_multiplier(self): + smap = price.parse_source_map( + 'GBP:0.01*yahoo/XSDR.L;GBP:yahoo/XSDR;USD:1000*yahoo/mXSDRUSD') + print(smap) + self.assertEqual({ + 'GBP': [PS('beanprice.sources.yahoo', 'XSDR.L', False, Decimal('0.01')), + PS('beanprice.sources.yahoo', 'XSDR', False, ONE)], + 'USD': [PS('beanprice.sources.yahoo', 'mXSDRUSD', False, Decimal(1000))], + + }, self._clean_source_map(smap)) + class TestFilters(unittest.TestCase):