Commits

Maciej Fijalkowski committed f87a57e

getitem/setitem on bytearray

Comments (0)

Files changed (5)

pypy/annotation/binaryop.py

             result.const = b1.const + b2.const
         return result
 
+class __extend__(pairtype(SomeByteArray, SomeInteger)):
+    def getitem((s_b, s_i)):
+        return SomeInteger()
+
+    def setitem((s_b, s_i), s_i2):
+        assert isinstance(s_i2, SomeInteger)
+
 class __extend__(pairtype(SomeString, SomeByteArray),
                  pairtype(SomeByteArray, SomeString),
                  pairtype(SomeChar, SomeByteArray),

pypy/annotation/test/test_annrpython.py

         a = self.RPythonAnnotator()
         assert isinstance(a.build_types(f, [annmodel.SomeChar()]),
                           annmodel.SomeByteArray)
-        
+
+    def test_bytearray_setitem_getitem(self):
+        def f(b, i, c):
+            b[i] = c
+            return b[i + 1]
+
+        a = self.RPythonAnnotator()
+        assert isinstance(a.build_types(f, [annmodel.SomeByteArray(),
+                                            int, int]),
+                          annmodel.SomeInteger)
 
 def g(n):
     return [0,1,2,n]

pypy/rpython/lltypesystem/rbytearray.py

 
 from pypy.rpython.rbytearray import AbstractByteArrayRepr
 from pypy.rpython.lltypesystem import lltype, rstr
+from pypy.rlib.debug import ll_assert
 
 BYTEARRAY = lltype.GcForwardReference()
 
     'length': rstr.LLHelpers.ll_length,
 }))
 
+class LLHelpers(rstr.LLHelpers):
+    @classmethod
+    def ll_strsetitem(cls, s, i, item):
+        if i < 0:
+            i += s.length()
+        cls.ll_strsetitem_nonneg(s, i, item)
+
+    def ll_strsetitem_nonneg(s, i, item):
+        chars = s.chars
+        ll_assert(i >= 0, "negative str getitem index")
+        ll_assert(i < len(chars), "str getitem index out of bound")
+        chars[i] = chr(item)
+
+    def ll_stritem_nonneg(s, i):
+        return ord(rstr.LLHelpers.ll_stritem_nonneg(s, i))
+
 class ByteArrayRepr(AbstractByteArrayRepr):
     lowleveltype = lltype.Ptr(BYTEARRAY)
 
     def __init__(self, *args):
         AbstractByteArrayRepr.__init__(self, *args)
-        self.ll = rstr.LLHelpers
+        self.ll = LLHelpers
+        self.repr = self
 
     def convert_const(self, value):
         if value is None:

pypy/rpython/rbytearray.py

 
-from pypy.rpython.rmodel import Repr
 from pypy.annotation import model as annmodel
 from pypy.tool.pairtype import pairtype
 from pypy.rpython.rstr import AbstractStringRepr
+from pypy.rpython.rmodel import IntegerRepr
+from pypy.rpython.lltypesystem import lltype
 
-class AbstractByteArrayRepr(Repr):
+class AbstractByteArrayRepr(AbstractStringRepr):
     pass
 
 class __extend__(pairtype(AbstractByteArrayRepr, AbstractByteArrayRepr)):
         v_str1, v_b2 = hop.inputargs(str_repr, r_b2)
         return hop.gendirectcall(r_b2.ll.ll_strconcat, v_str1, v_b2)
 
+class __extend__(pairtype(AbstractByteArrayRepr, IntegerRepr)):
+    def rtype_setitem((r_b, r_int), hop, checkidx=False):
+        bytearray_repr = r_b.repr
+        v_str, v_index, v_item = hop.inputargs(bytearray_repr, lltype.Signed,
+                                               lltype.Signed)
+        if checkidx:
+            if hop.args_s[1].nonneg:
+                llfn = r_b.ll.ll_strsetitem_nonneg_checked
+            else:
+                llfn = r_b.ll.ll_strsetitem_checked
+        else:
+            if hop.args_s[1].nonneg:
+                llfn = r_b.ll.ll_strsetitem_nonneg
+            else:
+                llfn = r_b.ll.ll_strsetitem
+        if checkidx:
+            hop.exception_is_here()
+        else:
+            hop.exception_cannot_occur()
+        return hop.gendirectcall(llfn, v_str, v_index, v_item)
+
 class __extend__(annmodel.SomeByteArray):
     def rtyper_makekey(self):
         return self.__class__,

pypy/rpython/test/test_rbytearray.py

 
         ll_res = self.interpret(f3, [llstr("def")])
         assert hlbytearray(ll_res) == "defa"
+
+    def test_getitem_setitem(self):
+        def f(s, i, c):
+            b = bytearray(hlstr(s))
+            b[i] = c
+            return b[i] + b[i + 1] * 255
+
+        ll_res = self.interpret(f, [llstr("abc"), 1, ord('d')])
+        assert ll_res == ord('d') + ord('c') * 255