Commits

Maciej Fijalkowski committed 0f26de8

progress towards addition

Comments (0)

Files changed (4)

pypy/rpython/lltypesystem/rbytearray.py

 
 from pypy.rpython.rbytearray import AbstractByteArrayRepr
-from pypy.rpython.lltypesystem import lltype
+from pypy.rpython.lltypesystem import lltype, rstr
 
-BYTEARRAY = lltype.GcArray(lltype.Char)
+BYTEARRAY = lltype.GcForwardReference()
+
+def mallocbytearray(size):
+    return lltype.malloc(BYTEARRAY, size)
+
+copy_bytearray_contents = rstr._new_copy_contents_fun(BYTEARRAY, BYTEARRAY,
+                                                      lltype.Char,
+                                                      'bytearray')
+copy_bytearray_contents_from_str = rstr._new_copy_contents_fun(rstr.STR,
+                                                               BYTEARRAY,
+                                                               lltype.Char,
+                                                               'bytearray')
+
+BYTEARRAY.become(lltype.GcStruct('rpy_bytearray',
+                 ('chars', lltype.Array(lltype.Char)), adtmeths={
+    'malloc' : lltype.staticAdtMethod(mallocbytearray),
+    'copy_contents' : lltype.staticAdtMethod(copy_bytearray_contents),
+    'copy_contents_from_str': lltype.staticAdtMethod(
+                                         copy_bytearray_contents_from_str),
+    'length': rstr.LLHelpers.ll_length,
+}))
 
 class ByteArrayRepr(AbstractByteArrayRepr):
     lowleveltype = lltype.Ptr(BYTEARRAY)
 
+    def __init__(self, *args):
+        AbstractByteArrayRepr.__init__(self, *args)
+        self.ll = rstr.LLHelpers
+
     def convert_const(self, value):
         if value is None:
             return lltype.nullptr(BYTEARRAY)
         p = lltype.malloc(BYTEARRAY, len(value))
         for i, c in enumerate(value):
-            p[i] = chr(c)
+            p.chars[i] = chr(c)
         return p
 
 bytearray_repr = ByteArrayRepr()
 
 def hlbytearray(ll_b):
     b = bytearray()
-    for i in range(len(ll_b)):
-        b.append(ll_b[i])
+    for i in range(ll_b.length()):
+        b.append(ll_b.chars[i])
     return b

pypy/rpython/lltypesystem/rstr.py

 def emptyunicodefun():
     return emptyunicode
 
-def _new_copy_contents_fun(TP, CHAR_TP, name):
-    def _str_ofs(item):
-        return (llmemory.offsetof(TP, 'chars') +
-                llmemory.itemoffsetof(TP.chars, 0) +
+def _new_copy_contents_fun(SRC_TP, DST_TP, CHAR_TP, name):
+    def _str_ofs_src(item):
+        return (llmemory.offsetof(SRC_TP, 'chars') +
+                llmemory.itemoffsetof(SRC_TP.chars, 0) +
+                llmemory.sizeof(CHAR_TP) * item)
+
+    def _str_ofs_dst(item):
+        return (llmemory.offsetof(DST_TP, 'chars') +
+                llmemory.itemoffsetof(DST_TP.chars, 0) +
                 llmemory.sizeof(CHAR_TP) * item)
 
     @jit.oopspec('stroruni.copy_contents(src, dst, srcstart, dststart, length)')
         assert srcstart >= 0
         assert dststart >= 0
         assert length >= 0
-        src = llmemory.cast_ptr_to_adr(src) + _str_ofs(srcstart)
-        dst = llmemory.cast_ptr_to_adr(dst) + _str_ofs(dststart)
+        src = llmemory.cast_ptr_to_adr(src) + _str_ofs_src(srcstart)
+        dst = llmemory.cast_ptr_to_adr(dst) + _str_ofs_dst(dststart)
         llmemory.raw_memcopy(src, dst, llmemory.sizeof(CHAR_TP) * length)
         keepalive_until_here(src)
         keepalive_until_here(dst)
     copy_string_contents._always_inline_ = True
     return func_with_new_name(copy_string_contents, 'copy_%s_contents' % name)
 
-copy_string_contents = _new_copy_contents_fun(STR, Char, 'string')
-copy_unicode_contents = _new_copy_contents_fun(UNICODE, UniChar, 'unicode')
+copy_string_contents = _new_copy_contents_fun(STR, STR, Char, 'string')
+copy_unicode_contents = _new_copy_contents_fun(UNICODE, UNICODE, UniChar,
+                                               'unicode')
 
 CONST_STR_CACHE = WeakValueDictionary()
 CONST_UNICODE_CACHE = WeakValueDictionary()
         lgt = len(str.chars)
         b = malloc(BYTEARRAY, lgt)
         for i in range(lgt):
-            b[i] = str.chars[i]
+            b.chars[i] = str.chars[i]
         return b
 
     @jit.elidable
             s.hash = x
         return x
 
+    def ll_length(s):
+        return len(s.chars)
+
     def ll_strfasthash(s):
         return s.hash     # assumes that the hash is already computed
 
     @jit.elidable
     def ll_strconcat(s1, s2):
-        len1 = len(s1.chars)
-        len2 = len(s2.chars)
+        len1 = s1.length()
+        len2 = s2.length()
         # a single '+' like this is allowed to overflow: it gets
         # a negative result, and the gc will complain
-        newstr = s1.malloc(len1 + len2)
-        s1.copy_contents(s1, newstr, 0, 0, len1)
-        s1.copy_contents(s2, newstr, 0, len1, len2)
+        # the typechecks below are if TP == BYTEARRAY
+        if typeOf(s1) == STR:
+            newstr = s2.malloc(len1 + len2)
+            newstr.copy_contents_from_str(s1, newstr, 0, 0, len1)
+        else:
+            newstr = s1.malloc(len1 + len2)            
+            newstr.copy_contents(s1, newstr, 0, 0, len1)
+        if typeOf(s2) == STR:
+            newstr.copy_contents_from_str(s2, newstr, 0, len1, len2)
+        else:
+            newstr.copy_contents(s2, newstr, 0, len1, len2)
         return newstr
     ll_strconcat.oopspec = 'stroruni.concat(s1, s2)'
 
                     adtmeths={'malloc' : staticAdtMethod(mallocstr),
                               'empty'  : staticAdtMethod(emptystrfun),
                               'copy_contents' : staticAdtMethod(copy_string_contents),
-                              'gethash': LLHelpers.ll_strhash}))
+                              'copy_contents_from_str' : staticAdtMethod(copy_string_contents),
+                              'gethash': LLHelpers.ll_strhash,
+                              'length': LLHelpers.ll_length}))
 UNICODE.become(GcStruct('rpy_unicode', ('hash', Signed),
                         ('chars', Array(UniChar, hints={'immutable': True})),
                         adtmeths={'malloc' : staticAdtMethod(mallocunicode),
                                   'empty'  : staticAdtMethod(emptyunicodefun),
                                   'copy_contents' : staticAdtMethod(copy_unicode_contents),
-                                  'gethash': LLHelpers.ll_strhash}
+                                  'copy_contents_from_str' : staticAdtMethod(copy_unicode_contents),
+                                  'gethash': LLHelpers.ll_strhash,
+                                  'length': LLHelpers.ll_length}
                         ))
 
 

pypy/rpython/rbytearray.py

 
 from pypy.rpython.rmodel import Repr
 from pypy.annotation import model as annmodel
+from pypy.tool.pairtype import pairtype
+from pypy.rpython.rstr import AbstractStringRepr
 
 class AbstractByteArrayRepr(Repr):
     pass
 
+class __extend__(pairtype(AbstractByteArrayRepr, AbstractByteArrayRepr)):
+    def rtype_add((r_b1, r_b2), hop):
+        xxx
+
+class __extend__(pairtype(AbstractByteArrayRepr, AbstractStringRepr)):
+    def rtype_add((r_b1, r_s2), hop):
+        str_repr = r_s2.repr
+        if hop.s_result.is_constant():
+            return hop.inputconst(r_b1, hop.s_result.const)
+        v_b1, v_str2 = hop.inputargs(r_b1, str_repr)
+        return hop.gendirectcall(r_b1.ll.ll_strconcat, v_b1, v_str2)
+
+class __extend__(pairtype(AbstractStringRepr, AbstractByteArrayRepr)):
+    def rtype_add((r_s2, r_b1), hop):
+        xxx
+
 class __extend__(annmodel.SomeByteArray):
     def rtyper_makekey(self):
         return self.__class__,

pypy/rpython/test/test_rbytearray.py

 
 from pypy.rpython.test.tool import BaseRtypingTest, LLRtypeMixin
 from pypy.rpython.lltypesystem.rbytearray import hlbytearray
+from pypy.rpython.annlowlevel import llstr, hlstr
 
 class TestByteArray(BaseRtypingTest, LLRtypeMixin):
     def test_bytearray_creation(self):
         assert hlbytearray(ll_res) == "def"
         ll_res = self.interpret(f, [1])
         assert hlbytearray(ll_res) == "1"
+
+    def test_addition(self):
+        def f(x):
+            return bytearray("a") + hlstr(x)
+
+        ll_res = self.interpret(f, [llstr("def")])
+        assert hlbytearray(ll_res) == "adef"
+
+        def f2(x):
+            return hlstr(x) + bytearray("a")
+
+        ll_res = self.interpret(f2, [llstr("def")])
+        assert hlbytearray(ll_res) == "defa"