Commits

Matt Chaput committed fd5348f

Added Decimal support to NUMERIC field type via the decimal_places keyword arg.

Comments (0)

Files changed (2)

src/whoosh/fields.py

 """
 
 import datetime, re, struct
+from decimal import Decimal
 
 from whoosh.analysis import (IDAnalyzer, RegexAnalyzer, KeywordAnalyzer,
                              StandardAnalyzer, NgramAnalyzer, Tokenizer,
 
 class NUMERIC(FieldType):
     """Special field type that lets you index int, long, or floating point
-    numbers. The field converts the number to sortable text for you before
-    indexing.
+    numbers in relatively short fixed-width terms. The field converts numbers
+    to sortable text for you before indexing.
     
-    You can specify the type of the field when you create the NUMERIC object.
-    The default is int.
+    You specify the numeric type of the field when you create the NUMERIC
+    object. The default is ``int``.
     
     >>> schema = Schema(path=STORED, position=NUMERIC(long))
     >>> ix = storage.create_index(schema)
     >>> w = ix.writer()
     >>> w.add_document(path="/a", position=5820402204)
     >>> w.commit()
+    
+    You can also use the NUMERIC field to store Decimal instances by specifying
+    a type of ``int`` or ``long`` and the ``decimal_places`` keyword argument.
+    This simply multiplies each number by ``(10 ** decimal_places)`` before
+    storing it as an integer. Of course this may throw away decimal precision
+    (by truncating, not rounding) and imposes the same maximum value limits as
+    ``int``/``long``, but these may be acceptable for certain applications.
+    
+    >>> position = NUMERIC(int, decimal_places=4)
     """
     
     def __init__(self, type=int, stored=False, unique=False, field_boost=1.0,
-                 small=True):
+                 decimal_places=0):
         """
         :param type: the type of numbers that can be stored in this field: one
-            of ``int``, ``long``, or ``float``.
+            of ``int``, ``long``, ``float``, or ``Decimal``.
         :param stored: Whether the value of this field is stored with the
             document.
         :param unique: Whether the value of this field is unique per-document.
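+        :param decimal_places: if nonzero, each number is multiplied by
+            ``10 ** decimal_places`` before being converted to ``type``
+            (normally ``int`` or ``long``); this specifies the number of
+            decimal places to save.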
         """
         
         self.type = type
         
         self.stored = stored
         self.unique = unique
-        self.small = small
+        self.decimal_places = decimal_places
         self.format = Existence(analyzer=IDAnalyzer(), field_boost=field_boost)
     
     def index(self, num):
-        _to_text = self._to_text
+        to_text = self.to_text
         # word, freq, weight, valuestring
-        return [(_to_text(num), 1, 1.0, '')]
+        return [(to_text(num), 1, 1.0, '')]
     
     def to_text(self, x):
+        if self.decimal_places:
+            x = Decimal(x)
+            x *= 10 ** self.decimal_places
         return self._to_text(self.type(x))
     
+    def from_text(self, t):
+        n = self._from_text(t)
+        if self.decimal_places:
+            s = str(n)
+            n = Decimal(s[:-4] + "." + s[-4:])
+        return n
+    
     def process_text(self, text, **kwargs):
         return (self.to_text(text),)
     
     
     def parse_query(self, fieldname, qstring, boost=1.0):
         from whoosh import query
+        
         return query.Term(fieldname, self.to_text(qstring), boost=boost)
     
 

tests/test_fields.py

     def test_numeric(self):
         schema = fields.Schema(id=fields.ID(stored=True),
                                integer=fields.NUMERIC(int),
-                               decimal=fields.NUMERIC(float))
-        st = RamStorage()
-        ix = st.create_index(schema)
+                               floating=fields.NUMERIC(float))
+        ix = RamStorage().create_index(schema)
         
         w = ix.writer()
-        w.add_document(id=u"a", integer=5820, decimal=1.2)
-        w.add_document(id=u"b", integer=22, decimal=2.3)
-        w.add_document(id=u"c", integer=78, decimal=3.4)
-        w.add_document(id=u"d", integer=13, decimal=4.5)
-        w.add_document(id=u"e", integer=9, decimal=5.6)
+        w.add_document(id=u"a", integer=5820, floating=1.2)
+        w.add_document(id=u"b", integer=22, floating=2.3)
+        w.add_document(id=u"c", integer=78, floating=3.4)
+        w.add_document(id=u"d", integer=13, floating=4.5)
+        w.add_document(id=u"e", integer=9, floating=5.6)
         w.commit()
         
         s = ix.searcher()
         self.assertEqual(sorted(d["id"] for d in r), ["b", "d"])
         
         s = ix.searcher()
-        r = s.search(qp.parse("decimal:4.5"))
+        r = s.search(qp.parse("floating:4.5"))
         self.assertEqual(len(r), 1)
         self.assertEqual(r[0]["id"], "d")
         
-        r = s.search(qp.parse("decimal:[1.4 TO 4]"))
+        r = s.search(qp.parse("floating:[1.4 TO 4]"))
         self.assertEqual(len(r), 2)
         self.assertEqual(sorted(d["id"] for d in r), ["b", "c"])
     
+    def test_decimal_numeric(self):
+        from decimal import Decimal
+        f = fields.NUMERIC(int, decimal_places=4)
+        schema = fields.Schema(id=fields.ID(stored=True), deci=f)
+        ix = RamStorage().create_index(schema)
+        
+        self.assertEqual(f.from_text(f.to_text(Decimal("123.56"))),
+                         Decimal("123.56"))
+        
+        w = ix.writer()
+        w.add_document(id=u"a", deci=Decimal("123.56"))
+        w.add_document(id=u"b", deci=Decimal("0.536255"))
+        w.add_document(id=u"c", deci=Decimal("2.5255"))
+        w.add_document(id=u"d", deci=Decimal("58"))
+        w.commit()
+        
+        s = ix.searcher()
+        qp = qparser.QueryParser("deci", schema=schema)
+        
+        self.assertEqual([f.from_text(t) for t in s.lexicon("deci")],
+                         [Decimal("0.5362"), Decimal("2.5255"),
+                          Decimal("58.0000"), Decimal("123.5600")])
+        
+        r = s.search(qp.parse("123.56"))
+        self.assertEqual(r[0]["id"], "a")
+        
+        r = s.search(qp.parse("0.536255"))
+        self.assertEqual(r[0]["id"], "b")
+    
     def test_datetime(self):
         schema = fields.Schema(id=fields.ID(stored=True),
                                date=fields.DATETIME(stored=True))
 
 
 
-
 if __name__ == '__main__':
     unittest.main()