Commits

Miki Tebeka committed 6ea8fa0

Move code to backends.go

Comments (0)

Files changed (4)

+package main
+
+import (
+	"fmt"
+	"strings"
+	"sync"
+)
+
+// List of backends, support (thread safe) adding, removing and getting next in list (circular)
+type Backends struct {
+	backends []string
+	current  int
+	lock     sync.Mutex
+}
+
+// Set sets the current list of backends
+func (bs *Backends) Set(backends []string) {
+	bs.lock.Lock()
+	defer bs.lock.Unlock()
+
+	bs.backends = backends
+	bs.current = 0
+}
+
+// Next returns the next back in circular fashion
+func (bs *Backends) Next() (string, error) {
+	bs.lock.Lock()
+	defer bs.lock.Unlock()
+
+	if len(bs.backends) == 0 {
+		return "", fmt.Errorf("empty backends")
+	}
+
+	// We advance first to make sure we're in bounds
+	bs.current = (bs.current + 1) % len(bs.backends)
+	backend := bs.backends[bs.current]
+	return backend, nil
+}
+
+// Add adds a new backend
+func (bs *Backends) Add(backend string) {
+	bs.lock.Lock()
+	defer bs.lock.Unlock()
+
+	bs.backends = append(bs.backends, backend)
+}
+
+// Remove removes all occurrences of backend from list of backends, returns the number of items removed
+func (bs *Backends) Remove(backend string) int {
+	bs.lock.Lock()
+	defer bs.lock.Unlock()
+
+	i, count := 0, 0
+	for i < len(bs.backends) {
+		if bs.backends[i] == backend {
+			count++
+			bs.backends = append(bs.backends[:i], bs.backends[i+1:]...)
+		} else {
+			i++
+		}
+	}
+
+	return count
+}
+
+// String is string representation of backends
+func (bs *Backends) String() string {
+	return strings.Join(bs.backends, ",")
+}
+package main
+
+import (
+	"testing"
+)
+
+// FIXME: Sync tests
+
+func TestBackendsSet(t *testing.T) {
+	bs := &Backends{}
+	if len(bs.backends) != 0 {
+		t.Fatalf("should be empty")
+	}
+
+	bs.current = 107
+	bs.Set([]string{"a", "b"})
+	if len(bs.backends) != 2 {
+		t.Fatalf("len should 2 (was %d)", len(bs.backends))
+	}
+
+	if bs.current != 0 {
+		t.Fatalf("current should be 0 (is %d)", bs.current)
+	}
+}
+
+func TestBackendsNext(t *testing.T) {
+	backend1, backend2 := "localhost:8888", "localhost:8887"
+	bs := &Backends{backends: []string{backend1, backend2}}
+
+	for i, expected := range []string{backend2, backend1, backend2} {
+		next, _ := bs.Next()
+		if next != expected {
+			t.Fatalf("backend should be %s at %d (was %s)", expected, i, next)
+		}
+	}
+
+	bs.Set([]string{})
+	_, err := bs.Next()
+	if err == nil {
+		t.Fatalf("managed to get backend from empty list")
+	}
+}
+
+func TestBackendsAdd(t *testing.T) {
+	bs := &Backends{}
+	bs.Add("a")
+	if len(bs.backends) != 1 {
+		t.Fatalf("not added")
+	}
+}
+
+func TestBackendsRemove(t *testing.T) {
+	count := -1
+
+	bs := &Backends{}
+	count = bs.Remove("a")
+	if count != 0 {
+		t.Fatalf("removed from empty list")
+	}
+
+	bs.Set([]string{"a", "b", "a"})
+	count = bs.Remove("a")
+	if count != 2 {
+		t.Fatalf("should remove 2 (was %d)", count)
+	}
+	if len(bs.backends) != 1 {
+		t.Fatalf("should have left one (have %d)", len(bs.backends))
+	}
+
+	if bs.backends[0] != "b" {
+		t.Fatalf("wrong one left (%s)", bs.backends[0])
+	}
+}
+
+func TestBackendsString(t *testing.T) {
+	bs := &Backends{}
+	if bs.String() != "" {
+		t.Fatalf("bad empty str: %s", bs.String())
+	}
+
+	bs.Set([]string{"a", "b"})
+	if bs.String() != "a,b" {
+		t.Fatalf("bad str %s (expected a,b)", bs.String())
+	}
+}
 	"os"
 	"regexp"
 	"strings"
-	"sync"
 )
 
 const (
 )
 
 // List of backends
-var backends []string
+var backends *Backends = &Backends{}
 
-// Current backend
-var currentBackend int
-
-// Sync backend changes
-var backendsLock sync.RWMutex
-
-// backend regular expression
+// backend regular expression (<host>:<port>)
 var backendRe *regexp.Regexp = regexp.MustCompile("^[^:]+:[0-9]+$")
 
 // isValidBackend returns true if backend is in "host:port" format
 	return backendRe.MatchString(backend)
 }
 
-// nextBackend returns the next backend to use (uses backendsLock.RLock)
-func nextBackend() (string, error) {
-	backendsLock.RLock()
-	defer backendsLock.RUnlock()
-
-	if len(backends) == 0 {
-		return "", fmt.Errorf("No backends")
-	}
-
-	currentBackend = (currentBackend + 1) % len(backends)
-	backend := backends[currentBackend]
-	return backend, nil
-}
-
 // parseBackends parses string in format "host:port,host:port" and return list of backends
 func parseBackends(str string) ([]string, error) {
 	backends := strings.Split(str, ",")
 // getHandler handles /current and return the current backend
 func getHandler(w http.ResponseWriter, req *http.Request) {
 	w.Header().Set("Content-Type", "text/plain")
-	fmt.Fprintf(w, "%s\n", strings.Join(backends, ","))
-}
-
-// setBackends sets the current list of backends and sets currentBackend to 0
-func setBackends(newBackends []string) {
-	backendsLock.Lock()
-	defer backendsLock.Unlock()
-
-	backends = newBackends
-	currentBackend = 0
+	fmt.Fprintf(w, "%s\n", backends)
 }
 
 // setHandler handler /set and sets backends
 		return
 	}
 
-	setBackends(newBackends)
+	backends.Set(newBackends)
 	getHandler(w, req)
 }
 
 		return
 	}
 
-	backendsLock.Lock()
-	defer backendsLock.Unlock()
-	backends = append(backends, backend)
+	backends.Add(backend)
 	getHandler(w, req)
 }
 
-// remove removes all items matching 'item' from items.
-func remove(items []string, item string) []string {
-	i := 0
-	for i < len(items) {
-		if items[i] == item {
-			items = append(items[:i], items[i+1:]...)
-		} else {
-			i++
-		}
-	}
-
-	return items
-}
-
 // removeHandler handles /remove and remove a backend
 func removeHandler(w http.ResponseWriter, req *http.Request) {
 	err := ""
 
 	defer func() {
 		if len(err) != 0 {
-			log.Println(err)
+			log.Printf("error: %s\n", err)
 			http.Error(w, err, http.StatusBadRequest)
 			return
 		} else {
 
 	backend := req.FormValue("backend")
 	if len(backend) == 0 {
-		err = "error: missing 'backend' parameter"
+		err = "missing 'backend' parameter"
 		return
 	}
 
-	backendsLock.Lock()
-	defer backendsLock.Unlock()
-	newBackends := remove(backends, backend)
-	if len(newBackends) == len(backends) {
-		err = fmt.Sprintf("error: backend '%s' not found", backend)
+	count := backends.Remove(backend)
+	if count == 0 {
+		err = fmt.Sprintf("backend '%s' not found", backend)
 		return
 	}
-
-	backends = newBackends
 }
 
 // seamless launches the HTTP API and then start proxying
-func seamless(localAddr string, apiPort int, backends []string, out chan error) {
+func seamless(localAddr string, apiPort int, backendList []string, out chan error) {
 	local, err := net.Listen("tcp", localAddr)
 	if local == nil {
 		out <- fmt.Errorf("cannot listen: %v", err)
 		return
 	}
 
+	backends.Set(backendList)
+
 	go func() {
 		if err := startHttpServer(apiPort); err != nil {
 			out <- fmt.Errorf("cannot listen on %d: %v", apiPort, err)
 		if conn == nil {
 			die("accept failed: %v", err)
 		}
-		backend, err := nextBackend()
+		backend, err := backends.Next()
 		if err != nil {
 			log.Printf("error: can't get next backend %v\n", err)
 			conn.Close()
 	localAddr := fmt.Sprintf(":%s", flag.Arg(0))
 
 	var err error
-	backends, err = parseBackends(flag.Arg(1))
+	backendList, err := parseBackends(flag.Arg(1))
 	if err != nil {
 		die(fmt.Sprintf("%s", err))
 	}
 
 	out := make(chan error)
-	go seamless(localAddr, *port, backends, out)
+	go seamless(localAddr, *port, backendList, out)
 
 	err = <-out
 	if err != nil {
 	}
 
 	out := make(chan error)
-	go seamless(fmt.Sprintf(":%d", proxyPort), apiPort, backends, out)
+	go seamless(fmt.Sprintf(":%d", proxyPort), apiPort, []string{}, out)
 
 	time.Sleep(1 * time.Second)
 }
 }
 
 func TestHTTPGet(t *testing.T) {
-	setBackends([]string{"localhost:8080"})
+	backends.Set([]string{"localhost:8080"})
 	reply, err := callAPI("get")
 	if err != nil {
 		t.Fatalf("%s", err)
 	}
 
-	if reply != fmt.Sprintf("%s\n", backends[0]) {
+	if reply != fmt.Sprintf("%s\n", backends.backends[0]) {
 		t.Fatalf("bad reply: %s\n", string(reply))
 	}
 }
 
 func TestHTTPAdd(t *testing.T) {
-	setBackends([]string{"localhost:8888"})
+	backends.Set([]string{"localhost:8888"})
 	backend := "localhost:8887"
 
 	reply, err := callAPI(fmt.Sprintf("add?backend=%s", backend))
 		t.Fatalf("%s", err)
 	}
 
-	if len(backends) != 2 {
-		t.Fatalf("bad number of backends (%d)\nreply: %s", len(backends), reply)
+	if len(backends.backends) != 2 {
+		t.Fatalf("bad number of backends (%d)\nreply: %s", len(backends.backends), reply)
 	}
 
-	if backends[1] != backend {
-		t.Fatalf("bad backend - %s", backends[0])
+	if backends.backends[1] != backend {
+		t.Fatalf("bad backend - %s", backends.backends[0])
 	}
 
-	if reply != fmt.Sprintf("%s,%s\n", backends[0], backends[1]) {
+	if reply != fmt.Sprintf("%s,%s\n", backends.backends[0], backends.backends[1]) {
 		t.Fatalf("bad reply - %s\n", reply)
 	}
 }
 
 func TestHTTPRemove(t *testing.T) {
 	backend1, backend2 := "localhost:8888", "localhost:8887"
-	setBackends([]string{backend1, backend2})
+	backends.Set([]string{backend1, backend2})
 	reply, err := callAPI(fmt.Sprintf("remove?backend=%s", backend1))
 	if err != nil {
 		t.Fatalf("%s", err)
 	}
 
-	if len(backends) != 1 {
-		t.Fatalf("bad number of backends (%d)\nreply: %s", len(backends), reply)
+	if len(backends.backends) != 1 {
+		t.Fatalf("bad number of backends (%d)\nreply: %s", len(backends.backends), reply)
 	}
 
-	if backends[0] != backend2 {
-		t.Fatalf("bad backend left - %s", backends[0])
+	if backends.backends[0] != backend2 {
+		t.Fatalf("bad backend left - %s", backends.backends[0])
 	}
 }
 
 	}
 }
 
-func Test_nextBackend(t *testing.T) {
-	backend1, backend2 := "localhost:8888", "localhost:8887"
-	setBackends([]string{backend1, backend2})
-
-	for i, expected := range []string{backend2, backend1, backend2} {
-		next, _ := nextBackend()
-		if next != expected {
-			t.Fatalf("backend should be %s at %d (was %s)", expected, i, next)
-		}
-	}
-
-	backends = []string{backend1, backend2}
-	_, err := nextBackend()
-	if err != nil {
-		t.Fatalf("managed to get backend from empty list")
-	}
-}
-
 func arreq(a, b []string) bool {
 	if len(a) != len(b) {
 		return false
 }
 
 func TestProxy(t *testing.T) {
-	setBackends([]string{backendAddr(0), backendAddr(1)})
+	backends.Set([]string{backendAddr(0), backendAddr(1)})
 
 	for i := 0; i < 7; i++ {
 		reply, err := callProxy()
 		if err != nil {
 			t.Fatalf("can't call proxy - %v", err)
 		}
-		expected := fmt.Sprintf("%d", (i+1)%len(backends))
+		expected := fmt.Sprintf("%d", (i+1)%len(backends.backends))
 		if reply != expected {
 			t.Fatalf("bad backend for i=%d: got %s instead of %s", i, reply, expected)
 		}
 }
 
 func TestProxyRemove(t *testing.T) {
-	setBackends([]string{backendAddr(0), backendAddr(1)})
+	backends.Set([]string{backendAddr(0), backendAddr(1)})
 	suffix := fmt.Sprintf("remove?backend=%s", backendAddr(0))
 	if _, err := callAPI(suffix); err != nil {
 		t.Fatalf("can't remove %s - %s", backendAddr(0), err)
 }
 
 func TestProxyAdd(t *testing.T) {
-	setBackends([]string{backendAddr(0), backendAddr(1)})
+	backends.Set([]string{backendAddr(0), backendAddr(1)})
 
 	suffix := fmt.Sprintf("add?backend=%s", backendAddr(2))
 	if _, err := callAPI(suffix); err != nil {
 		if err != nil {
 			t.Fatalf("can't call proxy - %v", err)
 		}
-		expected := fmt.Sprintf("%d", (i+1)%len(backends))
+		expected := fmt.Sprintf("%d", (i+1)%len(backends.backends))
 		if reply != expected {
 			t.Fatalf("bad reply %s (expected %s)", reply, expected)
 		}
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.