Commits

Amaury Forgeot d'Arc committed 22acace

Reject NUL characters in some socket functions (e.g. host names)

  • Participants
  • Parent commits 3dc1f91

Comments (0)

Files changed (5)

File pypy/annotation/unaryop.py

 
     def str(obj):
         getbookkeeper().count('str', obj)
+        if isinstance(obj, SomeString):
+            return obj
+        if isinstance(obj, (SomeInteger, SomeFloat)):
+            return SomeString(no_nul=True)
         return SomeString()
 
     def unicode(obj):

File pypy/module/_socket/interp_func.py

         raise converted_error(space, e)
     return space.wrap(res)
 
-@unwrap_spec(host=str)
+@unwrap_spec(host='str0')
 def gethostbyname(space, host):
     """gethostbyname(host) -> address
 
                            space.newlist(aliases),
                            space.newlist(address_list)])
 
-@unwrap_spec(host=str)
+@unwrap_spec(host='str0')
 def gethostbyname_ex(space, host):
     """gethostbyname_ex(host) -> (name, aliaslist, addresslist)
 
         raise converted_error(space, e)
     return common_wrapgethost(space, res)
 
-@unwrap_spec(host=str)
+@unwrap_spec(host='str0')
 def gethostbyaddr(space, host):
     """gethostbyaddr(host) -> (name, aliaslist, addresslist)
 
         raise converted_error(space, e)
     return common_wrapgethost(space, res)
 
-@unwrap_spec(name=str)
+@unwrap_spec(name='str0')
 def getservbyname(space, name, w_proto=None):
     """getservbyname(servicename[, protocolname]) -> integer
 
     if space.is_w(w_proto, space.w_None):
         proto = None
     else:
-        proto = space.str_w(w_proto)
+        proto = space.str0_w(w_proto)
     try:
         port = rsocket.getservbyname(name, proto)
     except SocketError, e:
     if space.is_w(w_proto, space.w_None):
         proto = None
     else:
-        proto = space.str_w(w_proto)
+        proto = space.str0_w(w_proto)
 
     if port < 0 or port > 0xffff:
         raise OperationError(space.w_ValueError, space.wrap(
         raise converted_error(space, e)
     return space.wrap(service)
 
-@unwrap_spec(name=str)
+@unwrap_spec(name='str0')
 def getprotobyname(space, name):
     """getprotobyname(name) -> integer
 
     """
     return space.wrap(rsocket.htonl(x))
 
-@unwrap_spec(ip=str)
+@unwrap_spec(ip='str0')
 def inet_aton(space, ip):
     """inet_aton(string) -> packed 32-bit IP representation
 
         raise converted_error(space, e)
     return space.wrap(ip)
 
-@unwrap_spec(family=int, ip=str)
+@unwrap_spec(family=int, ip='str0')
 def inet_pton(space, family, ip):
     """inet_pton(family, ip) -> packed IP address string
 
     if space.is_w(w_host, space.w_None):
         host = None
     elif space.is_true(space.isinstance(w_host, space.w_str)):
-        host = space.str_w(w_host)
+        host = space.str0_w(w_host)
     elif space.is_true(space.isinstance(w_host, space.w_unicode)):
         w_shost = space.call_method(w_host, "encode", space.wrap("idna"))
-        host = space.str_w(w_shost)
+        host = space.str0_w(w_shost)
     else:
         raise OperationError(space.w_TypeError,
                              space.wrap(
     elif space.is_true(space.isinstance(w_port, space.w_int)):
         port = str(space.int_w(w_port))
     elif space.is_true(space.isinstance(w_port, space.w_str)):
-        port = space.str_w(w_port)
+        port = space.str0_w(w_port)
     else:
         raise OperationError(space.w_TypeError,
                              space.wrap("Int or String expected"))

File pypy/rlib/_rsocket_rffi.py

 from pypy.rpython.lltypesystem import rffi
 from pypy.rpython.lltypesystem import lltype
 from pypy.rpython.tool import rffi_platform as platform
-from pypy.rpython.lltypesystem.rffi import CCHARP
+from pypy.rpython.lltypesystem.rffi import CCHARP, CCHARP0
 from pypy.rlib.rposix import get_errno as geterrno
 from pypy.translator.tool.cbuild import ExternalCompilationInfo
 from pypy.translator.platform import platform as target_platform
 
 socketconnect = external('connect', [socketfd_type, sockaddr_ptr, socklen_t], rffi.INT)
 
-getaddrinfo = external('getaddrinfo', [CCHARP, CCHARP,
+getaddrinfo = external('getaddrinfo', [CCHARP0, CCHARP0,
                         addrinfo_ptr,
                         lltype.Ptr(rffi.CArray(addrinfo_ptr))], rffi.INT)
 freeaddrinfo = external('freeaddrinfo', [addrinfo_ptr], lltype.Void)
 ntohs = external('ntohs', [rffi.USHORT], rffi.USHORT, threadsafe=False)
 
 if _POSIX:
-    inet_aton = external('inet_aton', [CCHARP, lltype.Ptr(in_addr)],
-                                rffi.INT)
+    inet_aton = external('inet_aton', [CCHARP0, lltype.Ptr(in_addr)],
+                         rffi.INT)
 
 inet_ntoa = external('inet_ntoa', [in_addr], rffi.CCHARP)
 
 if _POSIX:
-    inet_pton = external('inet_pton', [rffi.INT, rffi.CCHARP,
-                                              rffi.VOIDP], rffi.INT)
+    inet_pton = external('inet_pton', [rffi.INT, CCHARP0, rffi.VOIDP],
+                         rffi.INT)
 
     inet_ntop = external('inet_ntop', [rffi.INT, rffi.VOIDP, CCHARP,
                                               socklen_t], CCHARP)
 
-inet_addr = external('inet_addr', [rffi.CCHARP], rffi.UINT)
+inet_addr = external('inet_addr', [rffi.CCHARP0], rffi.UINT)
 socklen_t_ptr = lltype.Ptr(rffi.CFixedArray(socklen_t, 1))
 socketaccept = external('accept', [socketfd_type, sockaddr_ptr,
                               socklen_t_ptr], socketfd_type)
                                     sockaddr_ptr, socklen_t], ssize_t)
 socketshutdown = external('shutdown', [socketfd_type, rffi.INT], rffi.INT)
 gethostname = external('gethostname', [rffi.CCHARP, rffi.INT], rffi.INT)
-gethostbyname = external('gethostbyname', [rffi.CCHARP],
+gethostbyname = external('gethostbyname', [rffi.CCHARP0],
                                 lltype.Ptr(cConfig.hostent))
 gethostbyaddr = external('gethostbyaddr', [rffi.VOIDP, rffi.INT, rffi.INT], lltype.Ptr(cConfig.hostent))
-getservbyname = external('getservbyname', [rffi.CCHARP, rffi.CCHARP], lltype.Ptr(cConfig.servent))
-getservbyport = external('getservbyport', [rffi.INT, rffi.CCHARP], lltype.Ptr(cConfig.servent))
-getprotobyname = external('getprotobyname', [rffi.CCHARP], lltype.Ptr(cConfig.protoent))
+getservbyname = external('getservbyname', [rffi.CCHARP0, rffi.CCHARP0], lltype.Ptr(cConfig.servent))
+getservbyport = external('getservbyport', [rffi.INT, rffi.CCHARP0], lltype.Ptr(cConfig.servent))
+getprotobyname = external('getprotobyname', [rffi.CCHARP0], lltype.Ptr(cConfig.protoent))
 
 if _POSIX:
     fcntl = external('fcntl', [socketfd_type, rffi.INT, rffi.INT], rffi.INT)

File pypy/rlib/rsocket.py

     def from_object(space, w_address):
         # Parse an app-level object representing an AF_INET address
         w_host, w_port = space.unpackiterable(w_address, 2)
-        host = space.str_w(w_host)
+        host = space.str0_w(w_host)
         port = space.int_w(w_port)
         port = Address.make_ushort_port(space, port)
         return INETAddress(host, port)
         if not (2 <= len(pieces_w) <= 4):
             raise TypeError("AF_INET6 address must be a tuple of length 2 "
                                "to 4, not %d" % len(pieces_w))
-        host = space.str_w(pieces_w[0])
+        host = space.str0_w(pieces_w[0])
         port = space.int_w(pieces_w[1])
         port = Address.make_ushort_port(space, port)
         if len(pieces_w) > 2: flowinfo = space.uint_w(pieces_w[2])
     return result, klass.maxlen
 
 def ipaddr_from_object(space, w_sockaddr):
-    host = space.str_w(space.getitem(w_sockaddr, space.wrap(0)))
+    host = space.str0_w(space.getitem(w_sockaddr, space.wrap(0)))
     addr = makeipaddr(host)
     addr.fill_from_object(space, w_sockaddr)
     return addr

File pypy/rpython/lltypesystem/rffi.py

 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib.objectmodel import Symbolic, CDefinedIntSymbolic
 from pypy.rlib.objectmodel import keepalive_until_here
-from pypy.rlib import rarithmetic, rgc
+from pypy.rlib import rarithmetic, rgc, rstring
 from pypy.rpython.extregistry import ExtRegistryEntry
 from pypy.rlib.unroll import unrolling_iterable
 from pypy.rpython.tool.rfficache import platform
                     arg = lltype.nullptr(CCHARP.TO)   # None => (char*)NULL
                     freeme = arg
                 elif isinstance(arg, str):
+                    if TARGET is CCHARP0:
+                        rstring.check_str0(arg)
                     arg = str2charp(arg)
                     # XXX leaks if a str2charp() fails with MemoryError
                     # and was not the first in this function
 
 # char *
 CCHARP = lltype.Ptr(lltype.Array(lltype.Char, hints={'nolength': True}))
+CCHARP0 = lltype.Ptr(lltype.Array(lltype.Char, hints={'nolength': True}),
+                     use_cache=False)
+assert CCHARP0 is not CCHARP
+assert CCHARP0 == CCHARP
 
 # wchar_t *
 CWCHARP = lltype.Ptr(lltype.Array(lltype.UniChar, hints={'nolength': True}))