Commits

Anonymous committed dedd44e

Complete broadcast support (both raw and via port mapper CALLIT)

Comments (0)

Files changed (1)

-# Implement (a subset of) Sun RPC, version 2 -- RFC1057.
+# Sun RPC version 2 -- RFC1057.
 
 # XXX There should be separate exceptions for the various reasons why
 # XXX an RPC can fail, rather than using RuntimeError for everything
 		self.port = port
 		self.makesocket() # Assigns to self.sock
 		self.bindsocket()
-		self.sock.connect((host, port))
-		self.lastxid = 0
+		self.connsocket()
+		self.lastxid = 0 # XXX should be more random?
 		self.addpackers()
 		self.cred = None
 		self.verf = None
 		# This MUST be overridden
 		raise RuntimeError, 'makesocket not defined'
 
+	def connsocket(self):
+		# Override this if you don't want/need a connection
+		self.sock.connect((self.host, self.port))
+
 	def bindsocket(self):
 		# Override this to bind to a different port (e.g. reserved)
 		self.sock.bind(('', 0))
 		self.packer = Packer().init()
 		self.unpacker = Unpacker().init('')
 
+	def make_call(self, proc, args, pack_func, unpack_func):
+		# Don't normally override this (but see Broadcast)
+		if pack_func is None and args is not None:
+			raise TypeError, 'non-null args with null pack_func'
+		self.start_call(proc)
+		if pack_func:
+			pack_func(args)
+		self.do_call()
+		if unpack_func:
+			result = unpack_func()
+		else:
+			result = None
+		self.unpacker.done()
+		return result
+
 	def start_call(self, proc):
 		# Don't override this
 		self.lastxid = xid = self.lastxid + 1
 		p.reset()
 		p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf)
 
-	def do_call(self, *rest):
+	def do_call(self):
 		# This MUST be overridden
 		raise RuntimeError, 'do_call not defined'
 
-	def end_call(self):
-		# Don't override this
-		self.unpacker.done()
-
 	def mkcred(self):
 		# Override this to use more powerful credentials
 		if self.cred == None:
 		return self.verf
 
 	def Null(self):			# Procedure 0 is always like this
-		self.start_call(0)
-		self.do_call(0)
-		self.end_call()
+		return self.make_call(0, None, None, None)
 
 
 # Record-Marking standard support
 	raise RuntimeError, 'can\'t assign reserved port'
 
 
-# Raw TCP-based client
+# Client using TCP to a specific port
 
 class RawTCPClient(Client):
 
 	def makesocket(self):
 		self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
-	def start_call(self, proc):
-		self.lastxid = xid = self.lastxid + 1
-		cred = self.mkcred()
-		verf = self.mkverf()
-		p = self.packer
-		p.reset()
-		p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf)
-
-	def do_call(self, *rest):
-		# rest is used for UDP buffer size; ignored for TCP
+	def do_call(self):
 		call = self.packer.get_buf()
 		sendrecord(self.sock, call)
 		reply = recvrecord(self.sock)
 			raise RuntimeError, 'wrong xid in reply ' + `xid` + \
 				' instead of ' + `self.lastxid`
 
-	def end_call(self):
-		self.unpacker.done()
 
-
-# Raw UDP-based client
+# Client using UDP to a specific port
 
 class RawUDPClient(Client):
 
 	def makesocket(self):
 		self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 
-	def start_call(self, proc):
-		self.lastxid = xid = self.lastxid + 1
-		cred = self.mkcred()
-		verf = self.mkverf()
-		p = self.packer
-		p.reset()
-		p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf)
-
-	def do_call(self, *rest):
+	def do_call(self):
+		call = self.packer.get_buf()
+		self.sock.send(call)
 		try:
 			from select import select
 		except ImportError:
 			print 'WARNING: select not found, RPC may hang'
 			select = None
-		if len(rest) == 0:
-			bufsize = 8192
-		elif len(rest) > 1:
-			raise TypeError, 'too many args'
-		else:
-			bufsize = rest[0] + 512
-		call = self.packer.get_buf()
+		BUFSIZE = 8192 # Max UDP buffer size
 		timeout = 1
 		count = 5
-		self.sock.send(call)
 		while 1:
 			r, w, x = [self.sock], [], []
 			if select:
 ##				print 'RESEND', timeout, count
 				self.sock.send(call)
 				continue
-			reply = self.sock.recv(bufsize)
+			reply = self.sock.recv(BUFSIZE)
 			u = self.unpacker
 			u.reset(reply)
 			xid, verf = u.unpack_replyheader()
 				continue
 			break
 
-	def end_call(self):
-		self.unpacker.done()
+
+# Client using UDP broadcast to a specific port
+
+class RawBroadcastUDPClient(RawUDPClient):
+
+	def init(self, bcastaddr, prog, vers, port):
+		self = RawUDPClient.init(self, bcastaddr, prog, vers, port)
+		self.reply_handler = None
+		self.timeout = 30
+		return self
+
+	def connsocket(self):
+		# Don't connect -- use sendto
+		self.sock.allowbroadcast(1)
+
+	def set_reply_handler(self, reply_handler):
+		self.reply_handler = reply_handler
+
+	def set_timeout(self, timeout):
+		self.timeout = timeout # Use None for infinite timeout
+
+	def make_call(self, proc, args, pack_func, unpack_func):
+		if pack_func is None and args is not None:
+			raise TypeError, 'non-null args with null pack_func'
+		self.start_call(proc)
+		if pack_func:
+			pack_func(args)
+		call = self.packer.get_buf()
+		self.sock.sendto(call, (self.host, self.port))
+		try:
+			from select import select
+		except ImportError:
+			print 'WARNING: select not found, broadcast will hang'
+			select = None
+		BUFSIZE = 8192 # Max UDP buffer size (for reply)
+		replies = []
+		if unpack_func is None:
+			def dummy(): pass
+			unpack_func = dummy
+		while 1:
+			r, w, x = [self.sock], [], []
+			if select:
+				if self.timeout is None:
+					r, w, x = select(r, w, x)
+				else:
+					r, w, x = select(r, w, x, self.timeout)
+			if self.sock not in r:
+				break
+			reply, fromaddr = self.sock.recvfrom(BUFSIZE)
+			u = self.unpacker
+			u.reset(reply)
+			xid, verf = u.unpack_replyheader()
+			if xid <> self.lastxid:
+##				print 'BAD xid'
+				continue
+			reply = unpack_func()
+			self.unpacker.done()
+			replies.append((reply, fromaddr))
+			if self.reply_handler:
+				self.reply_handler(reply, fromaddr)
+		return replies
 
 
 # Port mapper interface
 
-# XXX CALLIT is not implemented
-
 # Program number, version and (fixed!) port number
 PMAP_PROG = 100000
 PMAP_VERS = 2
 	def pack_pmaplist(self, list):
 		self.pack_list(list, self.pack_mapping)
 
+	def pack_call_args(self, ca):
+		prog, vers, proc, args = ca
+		self.pack_uint(prog)
+		self.pack_uint(vers)
+		self.pack_uint(proc)
+		self.pack_opaque(args)
+
 
 class PortMapperUnpacker(Unpacker):
 
 	def unpack_pmaplist(self):
 		return self.unpack_list(self.unpack_mapping)
 
+	def unpack_call_result(self):
+		port = self.unpack_uint()
+		res = self.unpack_opaque()
+		return port, res
+
 
 class PartialPortMapperClient:
 
 		self.unpacker = PortMapperUnpacker().init('')
 
 	def Set(self, mapping):
-		self.start_call(PMAPPROC_SET)
-		self.packer.pack_mapping(mapping)
-		self.do_call()
-		res = self.unpacker.unpack_uint()
-		self.end_call()
-		return res
+		return self.make_call(PMAPPROC_SET, mapping, \
+			self.packer.pack_mapping, \
+			self.unpacker.unpack_uint)
 
 	def Unset(self, mapping):
-		self.start_call(PMAPPROC_UNSET)
-		self.packer.pack_mapping(mapping)
-		self.do_call()
-		res = self.unpacker.unpack_uint()
-		self.end_call()
-		return res
+		return self.make_call(PMAPPROC_UNSET, mapping, \
+			self.packer.pack_mapping, \
+			self.unpacker.unpack_uint)
 
 	def Getport(self, mapping):
-		self.start_call(PMAPPROC_GETPORT)
-		self.packer.pack_mapping(mapping)
-		self.do_call(4)
-		port = self.unpacker.unpack_uint()
-		self.end_call()
-		return port
+		return self.make_call(PMAPPROC_GETPORT, mapping, \
+			self.packer.pack_mapping, \
+			self.unpacker.unpack_uint)
 
 	def Dump(self):
-		self.start_call(PMAPPROC_DUMP)
-		self.do_call(8192-512)
-		list = self.unpacker.unpack_pmaplist()
-		self.end_call()
-		return list
+		return self.make_call(PMAPPROC_DUMP, None, \
+			None, \
+			self.unpacker.unpack_pmaplist)
+
+	def Callit(self, ca):
+		return self.make_call(PMAPPROC_CALLIT, ca, \
+			self.packer.pack_call_args, \
+			self.unpacker.unpack_call_result)
 
 
 class TCPPortMapperClient(PartialPortMapperClient, RawTCPClient):
 			host, PMAP_PROG, PMAP_VERS, PMAP_PORT)
 
 
+class BroadcastUDPPortMapperClient(PartialPortMapperClient, \
+				   RawBroadcastUDPClient):
+
+	def init(self, bcastaddr):
+		return RawBroadcastUDPClient.init(self, \
+			bcastaddr, PMAP_PROG, PMAP_VERS, PMAP_PORT)
+
+
+# Generic clients that find their server through the Port mapper
+
 class TCPClient(RawTCPClient):
 
 	def init(self, host, prog, vers):
 		return RawUDPClient.init(self, host, prog, vers, port)
 
 
+class BroadcastUDPClient(Client):
+
+	def init(self, bcastaddr, prog, vers):
+		self.pmap = BroadcastUDPPortMapperClient().init(bcastaddr)
+		self.pmap.set_reply_handler(self.my_reply_handler)
+		self.prog = prog
+		self.vers = vers
+		self.user_reply_handler = None
+		self.addpackers()
+		return self
+
+	def close(self):
+		self.pmap.close()
+
+	def set_reply_handler(self, reply_handler):
+		self.user_reply_handler = reply_handler
+
+	def set_timeout(self, timeout):
+		self.pmap.set_timeout(timeout)
+
+	def my_reply_handler(self, reply, fromaddr):
+		port, res = reply
+		self.unpacker.reset(res)
+		result = self.unpack_func()
+		self.unpacker.done()
+		self.replies.append((result, fromaddr))
+		if self.user_reply_handler is not None:
+			self.user_reply_handler(result, fromaddr)
+
+	def make_call(self, proc, args, pack_func, unpack_func):
+		self.packer.reset()
+		if pack_func:
+			pack_func(args)
+		if unpack_func is None:
+			def dummy(): pass
+			self.unpack_func = dummy
+		else:
+			self.unpack_func = unpack_func
+		self.replies = []
+		packed_args = self.packer.get_buf()
+		dummy_replies = self.pmap.Callit( \
+			(self.prog, self.vers, proc, packed_args))
+		return self.replies
+
+
 # Server classes
 
 # These are not symmetric to the Client classes
 # Simple test program -- dump local portmapper status
 
 def test():
-	import T
-	T.TSTART()
 	pmap = UDPPortMapperClient().init('')
-	T.TSTOP()
 	pmap.Null()
-	T.TSTOP()
 	list = pmap.Dump()
-	T.TSTOP()
 	list.sort()
 	for prog, vers, prot, port in list:
 		print prog, vers,
 		print port
 
 
-# Server and client test program.
+# Test program for broadcast operation -- dump everybody's portmapper status
+
+def testbcast():
+	import sys
+	if sys.argv[1:]:
+		bcastaddr = sys.argv[1]
+	else:
+		bcastaddr = '<broadcast>'
+	def rh(reply, fromaddr):
+		host, port = fromaddr
+		print host + '\t' + `reply`
+	pmap = BroadcastUDPPortMapperClient().init(bcastaddr)
+	pmap.set_reply_handler(rh)
+	pmap.set_timeout(5)
+	replies = pmap.Getport((100002, 1, IPPROTO_UDP, 0))
+
+
+# Test program for server, with corresponding client
 # On machine A: python -c 'import rpc; rpc.testsvr()'
 # On machine B: python -c 'import rpc; rpc.testclt()' A
 # (A may be == B)
 	# Client for above server
 	class C(UDPClient):
 		def call_1(self, arg):
-			self.start_call(1)
-			self.packer.pack_string(arg)
-			self.do_call()
-			reply = self.unpacker.unpack_string()
-			self.end_call()
-			return reply
+			return self.make_call(1, arg, \
+				self.packer.pack_string, \
+				self.unpacker.unpack_string)
 	c = C().init(host, 0x20000000, 1)
 	print 'making call...'
 	reply = c.call_1('hello, world, ')
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.