Commits

Miki Tebeka committed 82929f0 Merge

Merge mutli

Comments (0)

Files changed (4)

+2012-09-06 version 0.2.0
+    * Support multiple backends [issue #3]
+
 2012-09-03 version 0.1.2
     * Synchronize backend change [issue #2] (Thanks Jonathan Amsterdam)
 
-`seamless` is a TCP proxy that allow you to deploy new code then switch traffic
-to new backend without downtime.
+`seameless` is a TCP proxy that allow you to deploy new code then switch traffic
+to it without downtime.
 
-Switching backends is done with HTTP interface (*on a different port*) with the
-following API:
+It does "round robin" between the list of current active backends.
 
-    `/switch?backend=address` 
-        switch traffic to new backend
+Switching server is done with HTTP interface with the following API:
 
-    `/current` 
-        return (in plain text) current server
+/set?backends=host:port,host:port
+    set list of backends
+
+/add?backend=host:port
+    add a backend
+
+/remove?backend=host:port
+    remove a backend
+
+/get
+    return host:port,host:port
 
 Process
 =======
-* Start first backend at port 4444
-* Run
-  ::
+* Start `seamleass` with list of active backends::
 
     seamless 8080 localhost:4444
 * Direct all traffic to port 8080 on local machine.
+* When you need to add/remove the backend, use the HTTP API on port 6777
+  different port, say 4445)::
 
-When you need to upgrade the backend, start a new one (with new code on a
-different port, say 4445). Then::
+    curl http://localhost:6777/add?backend=localhost:4445
+    curl http://localhost:6777/remove?backend=localhost:4444
 
-    curl http://localhost:6777/switch?backend=localhost:4445
+  Or::
 
-
-(Note that management port is different from the one we proxy).
+        curl http://localhost:6777/set?backends=localhost:4445
+    
+New traffic will be directed to new backend(s).
 
 Installing
 ==========
 /* A TCP proxy that allow you to deploy new code then switch traffic to it
    without downtime.
 
+   It does "round robin" between the list of current active backends.
+
    Switching server is done with HTTP interface with the following API:
-   /switch?backend=address - will switch traffic to new backend
-   /current - will return (in plain text) current server
+   /set?backends=host:port,host:port - will set list of backends
+   /add?backend=host:port - will add a backend
+   /remove?backend=host:port - will remove a backend
+   /get - will return host:port,host:port
 
    Work flow:
 	   Start first backend at port 4444
 
 	   When you need to upgrade the backend, start a new one (with new code on
 	   a different port, say 4445).
-	   The `curl http://localhost:6777/switch?backend=localhost:4445. 
-	   New traffic will be directed to new server.
+	   Then
+			* `curl http://localhost:6777/add?backend=localhost:4445`
+			* `curl http://localhost:6777/remove?backend=localhost:4444`
+	   Or
+		`curl http://localhost:6777/set?backends=localhost:4445`
+
+	   New traffic will be directed to new server(s).
 
 Original forward code by Roger Peppe (see http://bit.ly/Oc1YtF)
 */
 	"net"
 	"net/http"
 	"os"
+	"regexp"
+	"strings"
 	"sync"
 )
 
 const (
-	Version = "0.1.2"
+	Version = "0.2.0"
 )
 
+// List of backends
+var backends []string
+
 // Current backend
-var backend string
+var currentBackend int
 
 // Sync backend changes
-var backendLock sync.RWMutex
+var backendsLock sync.RWMutex
 
-// currentBackend returns the current value of the backend in atomic format.
-// (Uses backendLock.RLock)
-func currentBackend() (reply string) {
-	backendLock.RLock()
-	reply = backend
-	backendLock.RUnlock()
+// backend regular expression
+var backendRe *regexp.Regexp = regexp.MustCompile("^[^:]+:[0-9]+$")
 
-	return
+// isValidBackend returns true if backend is in "host:port" format
+func isValidBackend(backend string) bool {
+	return backendRe.MatchString(backend)
 }
 
-func main() {
-	flag.Usage = func() {
-		fmt.Fprintf(os.Stderr, "usage: seamless LISTEN_PORT BACKEND\n")
-		fmt.Fprintf(os.Stderr, "command line switches:\n")
-		flag.PrintDefaults()
-	}
-	port := flag.Int("httpPort", 6777, "http interface port")
-	version := flag.Bool("version", false, "show version and exit")
-	flag.Parse()
+// nextBackend returns the next backend to use (uses backendsLock.RLock)
+func nextBackend() (string, error) {
+	backendsLock.RLock()
+	defer backendsLock.RUnlock()
 
-	if *version {
-		fmt.Printf("seamless %s\n", Version)
-		os.Exit(0)
+	if len(backends) == 0 {
+		return "", fmt.Errorf("No backends")
 	}
 
-	if flag.NArg() != 2 {
-		flag.Usage()
-		os.Exit(1)
-	}
-	localAddr := fmt.Sprintf(":%s", flag.Arg(0))
-	backend = flag.Arg(1)
+	currentBackend = (currentBackend + 1) % len(backends)
+	backend := backends[currentBackend]
+	return backend, nil
+}
 
-	local, err := net.Listen("tcp", localAddr)
-	if local == nil {
-		die("cannot listen: %v", err)
+// parseBackends parses string in format "host:port,host:port" and return list of backends
+func parseBackends(str string) ([]string, error) {
+	backends := strings.Split(str, ",")
+	if len(backends) == 0 {
+		return nil, fmt.Errorf("no backends")
 	}
 
-	go func() {
-		if err := startHttpServer(*port); err != nil {
-			die("cannot listen on %d: %v", *port, err)
+	for i, v := range backends {
+		backends[i] = strings.TrimSpace(v)
+		if !isValidBackend(backends[i]) {
+			return nil, fmt.Errorf("'%s' is not valid network address", backends[i])
 		}
-	}()
+	}
 
-	for {
-		conn, err := local.Accept()
-		if conn == nil {
-			die("accept failed: %v", err)
-		}
-		go forward(conn, currentBackend())
-	}
+	return backends, nil
 }
 
 // forward proxies traffic between local socket and remote backend
 
 // startHttpServer start the HTTP server interface in a given port
 func startHttpServer(port int) error {
-	http.HandleFunc("/switch", switchHandler)
-	http.HandleFunc("/current", currentHandler)
+	http.HandleFunc("/set", setHandler)
+	http.HandleFunc("/get", getHandler)
+	http.HandleFunc("/add", addHandler)
+	http.HandleFunc("/remove", removeHandler)
+
 	return http.ListenAndServe(fmt.Sprintf(":%d", port), nil)
 }
 
-// switchHandler handler /switch and switches backend
-func switchHandler(w http.ResponseWriter, req *http.Request) {
-	newBackend := req.FormValue("backend")
-	if len(newBackend) == 0 {
+// getHandler handles /current and return the current backend
+func getHandler(w http.ResponseWriter, req *http.Request) {
+	w.Header().Set("Content-Type", "text/plain")
+	fmt.Fprintf(w, "%s\n", strings.Join(backends, ","))
+}
+
+// setBackends sets the current list of backends and sets currentBackend to 0
+func setBackends(newBackends []string) {
+	backendsLock.Lock()
+	defer backendsLock.Unlock()
+
+	backends = newBackends
+	currentBackend = 0
+}
+
+// setHandler handler /set and sets backends
+func setHandler(w http.ResponseWriter, req *http.Request) {
+	newBackends, err := parseBackends(req.FormValue("backends"))
+	if err != nil {
+		msg := fmt.Sprintf("error: %s", err)
+		log.Println(msg)
+		http.Error(w, msg, http.StatusBadRequest)
+		return
+	}
+
+	setBackends(newBackends)
+	getHandler(w, req)
+}
+
+// addHandler handles /add to add a new backend
+func addHandler(w http.ResponseWriter, req *http.Request) {
+	backend := req.FormValue("backend")
+	if len(backend) == 0 {
 		msg := "error: missing 'backend' parameter"
 		log.Println(msg)
 		http.Error(w, msg, http.StatusBadRequest)
 		return
 	}
 
-	backendLock.Lock()
-	backend = newBackend
-	backendLock.Unlock()
-	currentHandler(w, req)
+	backendsLock.Lock()
+	defer backendsLock.Unlock()
+	backends = append(backends, backend)
+	getHandler(w, req)
 }
 
-// currentHandler handles /current and return the current backend
-func currentHandler(w http.ResponseWriter, req *http.Request) {
-	w.Header().Set("Content-Type", "text/plain")
-	fmt.Fprintf(w, "%s\n", backend)
+// remove removes all items matching 'item' from items.
+func remove(items []string, item string) []string {
+	i := 0
+	for i < len(items) {
+		if items[i] == item {
+			items = append(items[:i], items[i+1:]...)
+		} else {
+			i++
+		}
+	}
+
+	return items
 }
+
+// removeHandler handles /remove and remove a backend
+func removeHandler(w http.ResponseWriter, req *http.Request) {
+	err := ""
+
+	defer func() {
+		if len(err) != 0 {
+			log.Println(err)
+			http.Error(w, err, http.StatusBadRequest)
+			return
+		} else {
+			getHandler(w, req)
+		}
+	}()
+
+	backend := req.FormValue("backend")
+	if len(backend) == 0 {
+		err = "error: missing 'backend' parameter"
+		return
+	}
+
+	backendsLock.Lock()
+	defer backendsLock.Unlock()
+	newBackends := remove(backends, backend)
+	if len(newBackends) == len(backends) {
+		err = fmt.Sprintf("error: backend '%s' not found", backend)
+		return
+	}
+
+	backends = newBackends
+}
+
+// seamless launches the HTTP API and then start proxying
+func seamless(localAddr string, apiPort int, backends []string, out chan error) {
+	local, err := net.Listen("tcp", localAddr)
+	if local == nil {
+		out <- fmt.Errorf("cannot listen: %v", err)
+		return
+	}
+
+	go func() {
+		if err := startHttpServer(apiPort); err != nil {
+			out <- fmt.Errorf("cannot listen on %d: %v", apiPort, err)
+		}
+	}()
+
+	for {
+		conn, err := local.Accept()
+		if conn == nil {
+			die("accept failed: %v", err)
+		}
+		backend, err := nextBackend()
+		if err != nil {
+			log.Printf("error: can't get next backend %v\n", err)
+			conn.Close()
+		}
+		go forward(conn, backend)
+	}
+}
+
+func main() {
+	flag.Usage = func() {
+		fmt.Fprintf(os.Stderr, "usage: seamless LISTEN_PORT BACKENDS\n")
+		fmt.Fprintf(os.Stderr, "command line switches:\n")
+		flag.PrintDefaults()
+	}
+	port := flag.Int("httpPort", 6777, "http interface port")
+	version := flag.Bool("version", false, "show version and exit")
+	flag.Parse()
+
+	if *version {
+		fmt.Printf("seamless %s\n", Version)
+		os.Exit(0)
+	}
+
+	if flag.NArg() != 2 {
+		flag.Usage()
+		os.Exit(1)
+	}
+	localAddr := fmt.Sprintf(":%s", flag.Arg(0))
+
+	var err error
+	backends, err = parseBackends(flag.Arg(1))
+	if err != nil {
+		die(fmt.Sprintf("%s", err))
+	}
+
+	out := make(chan error)
+	go seamless(localAddr, *port, backends, out)
+
+	err = <-out
+	if err != nil {
+		die("%s", err)
+	}
+}
 	"time"
 )
 
-func TestHttp(t *testing.T) {
-	backend = "hello"
-	port := 6777
-	go startHttpServer(port)
+var apiPort int = 6777
+var numBackends int = 3
+var proxyPort = 6888
+
+type testHandler int
+
+// startBackend spawns an HTTP server the listens on 6700+i port and replies with i to requests.
+func (h testHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	w.Header().Set("Content-Type", "text/plain")
+	fmt.Fprintf(w, "%d", h)
+}
+
+func startBackend(i int) {
+	handler := testHandler(i)
+	port := 6700 + i
+	server := http.Server{Handler: handler, Addr: fmt.Sprintf(":%d", port)}
+	go server.ListenAndServe()
+}
+
+func init() {
+	for i := 0; i < numBackends; i++ {
+		startBackend(i)
+	}
+
+	out := make(chan error)
+	go seamless(fmt.Sprintf(":%d", proxyPort), apiPort, backends, out)
 
 	time.Sleep(1 * time.Second)
+}
 
-	resp, err := http.Get(fmt.Sprintf("http://localhost:%d/current", port))
+func call(url string) (string, error) {
+	// We really don't want keep alive or caching :)
+	client := &http.Client{Transport: &http.Transport{DisableKeepAlives: true}}
+	resp, err := client.Get(url)
 	if err != nil {
-		t.Fatalf("error connecting to /current: %v\n", err)
+		return "", fmt.Errorf("can't GET %s: %v\n", url, err)
 	}
 	defer resp.Body.Close()
 
 	reply, err := ioutil.ReadAll(resp.Body)
 	if err != nil {
-		t.Fatalf("error reading reply: %v\n", err)
+		return "", fmt.Errorf("error reading reply: %v\n", err)
 	}
 
-	if string(reply) != fmt.Sprintf("%s\n", backend) {
+	return string(reply), nil
+}
+
+func callAPI(suffix string) (string, error) {
+	url := fmt.Sprintf("http://localhost:%d/%s", apiPort, suffix)
+	return call(url)
+}
+
+func TestHTTPGet(t *testing.T) {
+	setBackends([]string{"localhost:8080"})
+	reply, err := callAPI("get")
+	if err != nil {
+		t.Fatalf("%s", err)
+	}
+
+	if reply != fmt.Sprintf("%s\n", backends[0]) {
 		t.Fatalf("bad reply: %s\n", string(reply))
 	}
 }
+
+func TestHTTPAdd(t *testing.T) {
+	setBackends([]string{"localhost:8888"})
+	backend := "localhost:8887"
+
+	reply, err := callAPI(fmt.Sprintf("add?backend=%s", backend))
+	if err != nil {
+		t.Fatalf("%s", err)
+	}
+
+	if len(backends) != 2 {
+		t.Fatalf("bad number of backends (%d)\nreply: %s", len(backends), reply)
+	}
+
+	if backends[1] != backend {
+		t.Fatalf("bad backend - %s", backends[0])
+	}
+
+	if reply != fmt.Sprintf("%s,%s\n", backends[0], backends[1]) {
+		t.Fatalf("bad reply - %s\n", reply)
+	}
+}
+
+func TestHTTPRemove(t *testing.T) {
+	backend1, backend2 := "localhost:8888", "localhost:8887"
+	setBackends([]string{backend1, backend2})
+	reply, err := callAPI(fmt.Sprintf("remove?backend=%s", backend1))
+	if err != nil {
+		t.Fatalf("%s", err)
+	}
+
+	if len(backends) != 1 {
+		t.Fatalf("bad number of backends (%d)\nreply: %s", len(backends), reply)
+	}
+
+	if backends[0] != backend2 {
+		t.Fatalf("bad backend left - %s", backends[0])
+	}
+}
+
+func Test_isValidBackend(t *testing.T) {
+	names := map[bool]string{
+		true:  "valid",
+		false: "invalid",
+	}
+
+	cases := []struct {
+		value string
+		valid bool
+	}{
+		{"localhost:7", true},
+		{"foo.com:8080", true},
+		{"", false},
+		{"foo.com", false},
+		{"localhost", false},
+		{"foo.com:", false},
+	}
+
+	for _, c := range cases {
+		if isValidBackend(c.value) != c.valid {
+			t.Fatalf("`%s` should be %s", c.value, names[c.valid])
+		}
+	}
+}
+
+func Test_nextBackend(t *testing.T) {
+	backend1, backend2 := "localhost:8888", "localhost:8887"
+	setBackends([]string{backend1, backend2})
+
+	for i, expected := range []string{backend2, backend1, backend2} {
+		next, _ := nextBackend()
+		if next != expected {
+			t.Fatalf("backend should be %s at %d (was %s)", expected, i, next)
+		}
+	}
+
+	backends = []string{backend1, backend2}
+	_, err := nextBackend()
+	if err != nil {
+		t.Fatalf("managed to get backend from empty list")
+	}
+}
+
+func arreq(a, b []string) bool {
+	if len(a) != len(b) {
+		return false
+	}
+
+	for i, v := range a {
+		if b[i] != v {
+			return false
+		}
+	}
+
+	return true
+}
+
+func Test_parseBackends(t *testing.T) {
+	cases := []struct {
+		backends string
+		expected []string
+		ok       bool
+	}{
+		{"localhost:8080", []string{"localhost:8080"}, true},
+		{"localhost:8080,localhost:8887", []string{"localhost:8080", "localhost:8887"}, true},
+		{"", []string{}, false},
+		{"foo", []string{}, false},
+		{"localhost:8080,localhost", []string{}, false},
+	}
+
+	for _, c := range cases {
+		value, err := parseBackends(c.backends)
+		ok := err == nil
+
+		if ok != c.ok {
+			t.Fatalf("bad error for %v", c.backends)
+		}
+
+		if !arreq(value, c.expected) {
+			t.Fatalf("go %v for %s (expected %v)", value, c.backends, c.expected)
+		}
+
+	}
+}
+
+func backendAddr(i int) string {
+	return fmt.Sprintf("localhost:%d", 6700+i)
+}
+
+func callProxy() (string, error) {
+	url := fmt.Sprintf("http://localhost:%d", proxyPort)
+	return call(url)
+}
+
+func TestProxy(t *testing.T) {
+	setBackends([]string{backendAddr(0), backendAddr(1)})
+
+	for i := 0; i < 7; i++ {
+		reply, err := callProxy()
+		if err != nil {
+			t.Fatalf("can't call proxy - %v", err)
+		}
+		expected := fmt.Sprintf("%d", (i+1)%len(backends))
+		if reply != expected {
+			t.Fatalf("bad backend for i=%d: got %s instead of %s", i, reply, expected)
+		}
+	}
+}
+
+func TestProxyRemove(t *testing.T) {
+	setBackends([]string{backendAddr(0), backendAddr(1)})
+	suffix := fmt.Sprintf("remove?backend=%s", backendAddr(0))
+	if _, err := callAPI(suffix); err != nil {
+		t.Fatalf("can't remove %s - %s", backendAddr(0), err)
+	}
+
+	for i := 0; i < 7; i++ {
+		reply, err := callProxy()
+		if err != nil {
+			t.Fatalf("can't call proxy - %v", err)
+		}
+		if reply != "1" {
+			t.Fatalf("bad reply %s (expected 1)", reply)
+		}
+	}
+}
+
+func TestProxyAdd(t *testing.T) {
+	setBackends([]string{backendAddr(0), backendAddr(1)})
+
+	suffix := fmt.Sprintf("add?backend=%s", backendAddr(2))
+	if _, err := callAPI(suffix); err != nil {
+		t.Fatalf("can't remove %s - %s", backendAddr(0), err)
+	}
+
+	for i := 0; i < 7; i++ {
+		reply, err := callProxy()
+		if err != nil {
+			t.Fatalf("can't call proxy - %v", err)
+		}
+		expected := fmt.Sprintf("%d", (i+1)%len(backends))
+		if reply != expected {
+			t.Fatalf("bad reply %s (expected %s)", reply, expected)
+		}
+	}
+}