Commits

Virgil Dupras committed df78b9c

currency.RatesDB can now have configurable exchange rate providers.

Comments (0)

Files changed (1)

 # which should be included with this package. The terms are also available at 
 # http://www.hardcoded.net/licenses/bsd_license
 
-from datetime import datetime, date
+from datetime import datetime, date, timedelta
 import logging
 import sqlite3 as sqlite
+import threading
+from queue import Queue, Empty
 
 from . import io
 from .path import Path
 XPF = Currency.register('XPF', 'CFP franc',
     exponent=0, start_date=date(1998, 1, 2), start_rate=0.01299, latest_rate=0.01114)
 
+class CurrencyNotSupportedException(Exception):
+    """The current exchange rate provider doesn't support the requested currency."""
+
 class RatesDB:
     """Stores exchange rates for currencies.
     
     The currencies are identified with ISO 4217 code (USD, CAD, EUR, etc.).
     The rates are represented as float and represent the value of the currency in CAD.
     """
-    def __init__(self, db_or_path=':memory:'):
+    def __init__(self, db_or_path=':memory:', async=True):
         self._cache = {} # {(date, currency): CAD value
         self.db_or_path = db_or_path
         if isinstance(db_or_path, (str, Path)):
         else:
             self.con = db_or_path
         self._execute("select * from rates where 1=2")
+        self._rate_providers = []
+        self.async = async
+        self._fetched_values = Queue()
+        self._fetched_ranges = {} # a currency --> (start, end) map
     
     def _execute(self, *args, **kwargs):
         def create_tables():
                 return row[0]
         return seek('<=', 'desc') or seek('>=', '') or Currency(currency_code).latest_rate
     
+    def _save_fetched_rates(self):
+        while True:
+            try:
+                rates, currency = self._fetched_values.get_nowait()
+                for rate_date, rate in rates:
+                    self.set_CAD_value(rate_date, currency, rate)
+            except Empty:
+                break
+    
     def clear_cache(self):
         self._cache = {}
     
         The rate of the nearest date that is smaller than 'date' is returned. If
         there is none, a seek for a rate with a higher date will be made.
         """
+        # We want to check self._fetched_values for rates to add.
+        if not self._fetched_values.empty():
+            self._save_fetched_rates()
         # This method is a bottleneck and has been optimized for speed.
         value1 = None
         value2 = None
         sql = "replace into rates(date, currency, rate) values(?, ?, ?)"
         self._execute(sql, [str_date, currency_code, value])
         self.con.commit()
+    
+    def register_rate_provider(self, rate_provider):
+        """Adds `rate_provider` to the list of providers supported by this DB.
+        
+        A provider if a function(currency, start_date, end_date) that returns a list of
+        (rate_date, float_rate) as a result. This function will be called asyncronously, so it's ok
+        if it takes a long time to return.
+        
+        The rates returned must be the value of 1 `currency` in CAD (Canadian Dollars) at the
+        specified date.
+        
+        The provider can be asked for any currency. If it doesn't support it, it has to raise
+        CurrencyNotSupportedException.
+        
+        If we support the currency but that there is no rate available for the specified range,
+        simply return an empty list or None.
+        """
+        self._rate_providers.append(rate_provider)
+    
+    def ensure_rates(self, start_date, currencies):
+        """Ensures that the DB has all the rates it needs for 'currencies' between 'start_date' and today
+        
+        If there is any rate missing, a request will be made to the currency server. The requests
+        are made asynchronously.
+        """
+        def do():
+            for currency, fetch_start, fetch_end in currencies_and_range:
+                for rate_provider in self._rate_providers:
+                    try:
+                        values = rate_provider(currency, fetch_start, fetch_end)
+                    except CurrencyNotSupportedException:
+                        continue
+                    else:
+                        if values:
+                            self._fetched_values.put((values, currency))
+        
+        currencies_and_range = []
+        for currency in currencies:
+            if currency == 'CAD':
+                continue
+            try:
+                cached_range = self._fetched_ranges[currency]
+            except KeyError:
+                cached_range = self.date_range(currency)
+            range_start = start_date
+            range_end = date.today()
+            if cached_range is not None:
+                cached_start, cached_end = cached_range
+                if range_start >= cached_start:
+                    # Make a forward fetch
+                    range_start = cached_end + timedelta(days=1)
+                else:
+                    # Make a backward fetch
+                    range_end = cached_start - timedelta(days=1)
+            if range_start <= range_end:
+                currencies_and_range.append((currency, range_start, range_end))
+            self._fetched_ranges[currency] = (start_date, date.today())
+        if self.async:
+            threading.Thread(target=do).start()
+        else:
+            do()
+    
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.