Ross Light avatar Ross Light committed 03e206a

Add SQL database functions

Comments (0)

Files changed (2)

+package webapp
+
+import (
+	"database/sql"
+	"errors"
+	"reflect"
+	"strings"
+)
+
+// ScanOneStruct is equivalent to calling rows.Next, ScanStruct, then rows.Close.
+// If there is no next row, sql.ErrNoRows is returned.
+func ScanOneStruct(rows *sql.Rows, val interface{}) error {
+	defer rows.Close()
+	if !rows.Next() {
+		return sql.ErrNoRows
+	}
+	return ScanStruct(rows, val)
+}
+
+// ScanStruct extracts a single row into the struct pointed to by val.
+//
+// Each exported struct field is converted to a column name by using the "sql"
+// tag of that field, or by transforming the field name from upper camel case to
+// lowercase underscore-separated words if there is no tag.  If the field tag is
+// "-", then the field will be ignored.
+//
+// It is not an error to have extra columns or fields.
+func ScanStruct(rows *sql.Rows, val interface{}) error {
+	v := reflect.ValueOf(val)
+	if v.Kind() != reflect.Ptr {
+		return errors.New("db: val is not a pointer")
+	}
+	v = v.Elem()
+	if v.Kind() != reflect.Struct {
+		return errors.New("db: val is not a pointer to a struct")
+	}
+	cols, err := rows.Columns()
+	if err != nil {
+		return err
+	}
+
+	dest := make([]interface{}, len(cols))
+	var placeholder interface{}
+	for i := range dest {
+		dest[i] = &placeholder
+	}
+
+	for fi := 0; fi < v.NumField(); fi++ {
+		name := fieldColumn(v.Type().Field(fi))
+		if name == "" {
+			continue
+		}
+		for i := range cols {
+			if cols[i] == name {
+				dest[i] = v.Field(fi).Addr().Interface()
+				break
+			}
+		}
+	}
+	return rows.Scan(dest...)
+}
+
+// TransactionError is returned by RunInTransaction.
+type TransactionError struct {
+	Err   error // Error surrounding transaction
+	TxErr error // Error during transaction
+}
+
+func (e *TransactionError) Error() string {
+	if e.Err != nil {
+		return "during transaction: " + e.Err.Error()
+	}
+	return "transaction: " + e.TxErr.Error()
+}
+
+// RunInTransaction executes a SQL operation in a transaction.  Any non-nil
+// error will be wrapped in a TransactionError.
+func RunInTransaction(db *sql.DB, f func(*sql.Tx) error) error {
+	tx, err := db.Begin()
+	if err != nil {
+		return &TransactionError{nil, err}
+	}
+	err = f(tx)
+	var txErr error
+	if err == nil {
+		txErr = tx.Commit()
+	} else {
+		txErr = tx.Rollback()
+	}
+	if err != nil || txErr != nil {
+		return &TransactionError{err, txErr}
+	}
+	return nil
+}
+
+// fieldColumn returns the SQL column name for a struct field.  An empty string
+// means this field should not be persisted.
+func fieldColumn(f reflect.StructField) string {
+	if tag := f.Tag.Get("sql"); tag == "-" {
+		return ""
+	} else if tag != "" {
+		return tag
+	}
+	return colname(f.Name)
+}
+
+// colname converts from upper camel case to underscore-separated words.
+func colname(s string) string {
+	words := make([]string, 0)
+	var start int
+	lastIdx, lastRune := -1, rune(0)
+	for i, r := range s {
+		if i == 0 {
+			// don't create a new word, just grab the rune
+		} else if isUpper(lastRune) && isLower(r) && start != lastIdx {
+			words = append(words, strings.ToLower(s[start:lastIdx]))
+			start = lastIdx
+		} else if isLower(lastRune) && isUpper(r) && start != i {
+			words = append(words, strings.ToLower(s[start:i]))
+			start = i
+		}
+		lastIdx, lastRune = i, r
+	}
+	words = append(words, strings.ToLower(s[start:]))
+	return strings.Join(words, "_")
+}
+
+func lower(r rune) rune {
+	return (r - 'A') + 'a'
+}
+
+func isUpper(r rune) bool {
+	return r >= 'A' && r <= 'Z'
+}
+
+func isLower(r rune) bool {
+	return r >= 'a' && r <= 'z'
+}
+package webapp
+
+import (
+	"database/sql"
+	"log"
+	"testing"
+)
+
+func TestColname(t *testing.T) {
+	tests := []struct {
+		s   string
+		out string
+	}{
+		{"", ""},
+		{"a", "a"},
+		{"A", "a"},
+		{"ID", "id"},
+		{"Foo", "foo"},
+		{"FooB", "foo_b"},
+		{"FirstName", "first_name"},
+		{"StudentID", "student_id"},
+		{"studentID", "student_id"},
+		{"FooIDBar", "foo_id_bar"},
+		{"FooID4Bar", "foo_id4_bar"},
+	}
+	for _, test := range tests {
+		result := colname(test.s)
+		if result != test.out {
+			t.Errorf("colname(%q) = %v; want %v", test.s, result, test.out)
+		}
+	}
+}
+
+func ExampleScanStruct() {
+	type Person struct {
+		ID        int    `sql:"ID"`
+		FirstName string // implicitly "first_name"
+		LastName  string // implicitly "last_name"
+		Ignored   bool   `sql:"-"`
+	}
+
+	var person Person
+	db, err := sql.Open("sqlite3", ":memory:")
+	if err != nil {
+		log.Fatal(err)
+	}
+	_, err = db.Exec("CREATE TABLE person ( ID integer, first_name text, last_name text );")
+	if err != nil {
+		log.Fatal(err)
+	}
+	_, err = db.Exec("INSERT INTO person ( ID, first_name, last_name ) VALUES (1, 'John', 'Doe');")
+	if err != nil {
+		log.Fatal(err)
+	}
+	rows, err := db.Query("SELECT * FROM person;")
+	defer rows.Close()
+	if err != nil {
+		log.Fatal(err)
+	}
+	if err := ScanStruct(rows, &person); err != nil {
+		log.Fatal(err)
+	}
+}
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.