Commits

Anonymous committed 782c2fc

almost rewritten from scratch. Added some tests, should support several clients at a time

  • Participants
  • Parent commits e1dea49

Comments (0)

Files changed (2)

 package main
 
 import (
+	"bytes"
+	"io"
+	"log"
 	"net"
-	"io/ioutil"
 	"os"
-	"log"
 	"strings"
 )
 
 	MODE_OCTET = "octet"
 )
 
-// Default TFTP chunk size
-const CHUNK_SIZE = 512
+// Other TFTP constants
+const (
+	DEFAULT_CHUNK_SIZE int = 512
+	MAX_TFTP_PACKET_SIZE = (1 << 16)
+)
 
+//
 // Single TFTP file data chunk
+//
 type TftpChunk struct {
 	id   int
 	data []byte
 	ack  bool
 }
 
-// Converts data chunk to TFTP data packet
+// Construct new chunk from array of bytes
+func NewTftpChunk(id int, data []byte) (c *TftpChunk) {
+	c = new(TftpChunk)
+	c.id = id
+	c.data = data
+	return
+}
+
+// Get chunk size
+func (c *TftpChunk) Len() int {
+	return len(c.data)
+}
+
+// Convert data chunk into TFTP data packet
 func (c *TftpChunk) Packet() (p []byte) {
 	p = make([]byte, len(c.data)+4)
 	p[0] = 0
 	p[2] = uint8(c.id / 256)
 	p[3] = uint8(c.id % 256)
 	copy(p[4:], c.data[:])
-	log.Println("packet id =", c.id, "len =", (len(p) - 4))
+	//log.Println("packet id =", c.id, "len =", (len(p) - 4))
 	return
 }
 
-// TFTP file represented as array of chunks
+//
+// TFTP file (represented as array of chunks)
+//
 type TftpFile struct {
 	chunks []TftpChunk
 }
 
 // Create new TFTP file object from local file
-func NewTftpFile(name string) (f *TftpFile, err os.Error) {
-	var data []byte
+func NewTftpFile(r io.Reader, chunkSize int) (f *TftpFile, err os.Error) {
+	f = new(TftpFile)
+	f.chunks = make([]TftpChunk, 0)
+	for {
+		var n int
+		buf := make([]byte, chunkSize)
+		if n, err = r.Read(buf); err != nil && err != os.EOF {
+			return
+		}
+		err = nil
 
-	if data, err = ioutil.ReadFile(name); err != nil {
-		return
+		id := len(f.chunks) + 1
+		c := NewTftpChunk(id, buf[:n])
+		f.chunks = append(f.chunks, *c)
+
+		if n != len(buf) {
+			break
+		}
 	}
-
-	emptyChunk := (len(data)%CHUNK_SIZE == 0)
-	numChunks := len(data)/CHUNK_SIZE + 1
-
-	f = new(TftpFile)
-	if emptyChunk {
-		f.chunks = make([]TftpChunk, numChunks+1)
-		f.chunks[numChunks].ack = false
-		f.chunks[numChunks].data = nil
-		f.chunks[numChunks].id = numChunks + 1
-	} else {
-		f.chunks = make([]TftpChunk, numChunks)
-	}
-
-	for i := 0; i < numChunks; i++ {
-		//f.chunks[i] = new(TftpChunk)
-		f.chunks[i].ack = false
-		if i == numChunks-1 {
-			f.chunks[i].data = data[i*CHUNK_SIZE : len(data)]
-		} else {
-			f.chunks[i].data = data[i*CHUNK_SIZE : (i+1)*CHUNK_SIZE]
-		}
-		f.chunks[i].id = i + 1
-	}
-
 	return
 }
 
 	return nil, false
 }
 
-// Mark chink as sent
-func (f *TftpFile) AckChunk(i int) {
-	f.chunks[i-1].ack = true
+// Mark chunk as sent
+func (f *TftpFile) AckChunk(i uint16) {
+	f.chunks[i].ack = true
 }
 
-func getString(buf []byte) (s string, length int) {
-	for i := 0; i < len(buf); i++ {
-		if buf[i] == 0 {
-			return string(buf[0:i]), i
+//
+// Helper routines
+//
+
+func makeUint16(hi, lo byte) uint16 {
+	return uint16(lo) + (uint16(hi) << 8)
+}
+
+func readUint16(r io.Reader) (uint16, os.Error) {
+	var b []byte = make([]byte, 2)
+	if n, err := r.Read(b); err != nil {
+		return 0, err
+	} else {
+		if n != len(b) {
+			return 0, os.NewError("Unexpected end of data")
 		}
 	}
-	return
+	return makeUint16(b[0], b[1]), nil
+}
+
+func readString(r io.Reader) (s string, err os.Error) {
+	var b []byte = make([]byte, 1)
+	var str []byte = make([]byte, 0)
+
+	for {
+		var n int
+		if n, err = r.Read(b); err != nil {
+			return
+		} else {
+			if n != 1 {
+				err = os.NewError("string is not null-terminated")
+				return
+			}
+			if b[0] == 0 {
+				break
+			}
+			str = append(str, b[0])
+		}
+	}
+	return string(str), nil
 }
 
 func main() {
 	// listen on that port
 	conn, err := net.ListenUDP("udp", laddr)
 	if err != nil {
-		log.Fatal("failed to create listening UDP connection: ", err)
+		log.Fatal("failed to create listening UDP connection:", err)
 	}
 
-	buf := make([]byte, 512)
-
-	var f *TftpFile
+	clients := make(map[string]*TftpFile)
+	udpbuf := make([]byte, MAX_TFTP_PACKET_SIZE)
 
 	for {
-		// read single UDP packet
-		n, raddr, err := conn.ReadFromUDP(buf)
+		n, raddr, err := conn.ReadFromUDP(udpbuf)
 		if err != nil {
-			log.Println("failed to receive UDP packet: ", err)
+			log.Println("failed to receive UDP tftpPacket:", err)
 			continue
 		}
-		packet := buf[0:n]
+		if _, ok := clients[raddr.String()]; ok == false {
+			log.Println("new connection from", raddr)
+			clients[raddr.String()] = nil
+		}
 
-		rq := (int(packet[0]) << 8) | (int(packet[1]))
+		tftpPacket := bytes.NewBuffer(udpbuf[0:n])
 
-		var size int
-		var filename, mode string
-		if rq == RRQ {
-			filename, size = getString(packet[2:])
-			if size == 0 {
-				log.Println("failed to retrieve file name from the packet")
+		req, err := readUint16(tftpPacket)
+		if err != nil {
+			log.Println("malformed packet")
+			continue
+		}
+
+		switch req {
+
+		case RRQ:
+			fileName, err := readString(tftpPacket)
+			if err != nil {
+				log.Println(err)
 				continue
 			}
-
-			mode, size = getString(packet[size+3:])
-			if size == 0 {
-				log.Println("failed to retrieve file mode from the packet")
-				continue
-			}
-
-			mode = strings.ToLower(mode)
-
-			if mode != MODE_OCTET && mode != MODE_ASCII {
-				log.Println("mode is not supported:", mode)
-				//continue
-			}
-		}
-
-		switch rq {
-		case RRQ:
-			log.Println("RRQ", mode, filename, "from", raddr.String())
-			f, err = NewTftpFile(filename)
+			fileMode, err := readString(tftpPacket)
 			if err != nil {
 				log.Println(err)
 				continue
 			}
 
+			fileMode = strings.ToLower(fileMode)
+
+			log.Println("RRQ", fileName, "from", raddr, "as", fileMode)
+			if fileMode != MODE_OCTET {
+				log.Println("mode is not supported:", fileMode)
+				continue
+			}
+
+			osFile, err := os.Open(fileName)
+			f, err := NewTftpFile(osFile, DEFAULT_CHUNK_SIZE)
+
+			clients[raddr.String()] = f
+
 			if c, ok := f.GetNextChunk(); ok {
 				conn.WriteTo(c.Packet(), raddr)
 			}
+
 		case ACK:
-			id := ((int(packet[2]) << 8) + (int(packet[3])))
-			log.Println("ACK", id, "from", raddr.String())
-			f.AckChunk(id)
+			tftpChunkId, _ := readUint16(tftpPacket)
+			log.Println("ACK", tftpChunkId, "from", raddr)
+
+			f, ok := clients[raddr.String()]
+			if !ok {
+				log.Println("Unexpected ACK from", raddr)
+				continue
+			}
+
+			f.AckChunk(tftpChunkId - 1)
 			if c, ok := f.GetNextChunk(); ok {
 				conn.WriteTo(c.Packet(), raddr)
 			}
+
 		default:
-			log.Println("Unsupported request ", rq, "from", raddr.String())
+			log.Println("Unsupported request", req, "from", raddr)
 		}
 	}
 }
+package main
+
+import (
+	"bytes"
+	"testing"
+)
+
+func TestChunk(t *testing.T) {
+	c := new(TftpChunk)
+	c.id = 31
+	c.data = []byte{1, 2, 3, 4, 5}
+
+	// test packet length
+	if c.Len() != 5 {
+		t.Fail()
+	}
+
+	p := c.Packet()
+	// test opcode
+	if makeUint16(p[0], p[1]) != DATA {
+		t.Fail()
+	}
+	// test id
+	if makeUint16(p[2], p[3]) != 31 {
+		t.Fail()
+	}
+	// test data
+	if bytes.Compare(p[4:], c.data) != 0 {
+		t.Fail()
+	}
+}
+
+func TestFile(t *testing.T) {
+	data := bytes.NewBuffer([]byte{1, 2, 3})
+	f, err := NewTftpFile(data, DEFAULT_CHUNK_SIZE)
+	if err != nil {
+		t.Fail()
+	}
+	if len(f.chunks) != 1 {
+		t.Fail()
+	}
+	if f.chunks[0].Len() != 3 {
+		t.Fail()
+	}
+}
+
+func Test4KFile(t *testing.T) {
+	block := make([]byte, 4096)
+	for i := 0; i < 4096; i++ {
+		block[i] = byte(i & 0xff)
+	}
+	data := bytes.NewBuffer(block)
+	f, err := NewTftpFile(data, DEFAULT_CHUNK_SIZE)
+	if err != nil {
+		t.Fail()
+	}
+	if len(f.chunks) != 9 {
+		t.Fail()
+	}
+	if f.chunks[8].Len() != 0 {
+		t.Fail()
+	}
+
+	f.AckChunk(0)
+	f.AckChunk(2)
+	f.AckChunk(4)
+	f.AckChunk(6)
+
+	if _, ok := f.GetNextChunk(); !ok {
+		t.Fail()
+	}
+
+	f.AckChunk(1)
+	f.AckChunk(3)
+	f.AckChunk(5)
+	f.AckChunk(7)
+
+	if _, ok := f.GetNextChunk(); !ok {
+		t.Fail()
+	}
+
+	f.AckChunk(8)
+	if _, ok := f.GetNextChunk(); ok {
+		t.Fail()
+	}
+}