From 5161280cf3ccf895865a44382c2d55a4565b4e25 Mon Sep 17 00:00:00 2001 From: tavo Date: Mon, 30 Jun 2025 01:47:53 -0600 Subject: [PATCH] forms parser, in-memory sessions, storage abstraction --- forms/formatters.go | 15 +++ forms/forms.go | 197 ++++++++++++++++++++++++++++++++++++++ forms/forms_test.go | 59 ++++++++++++ forms/validators.go | 41 ++++++++ go.mod | 3 + sessions/sessions.go | 165 +++++++++++++++++++++++++++++++ sessions/sessions_test.go | 45 +++++++++ storage/storage.go | 126 ++++++++++++++++++++++++ 8 files changed, 651 insertions(+) create mode 100644 forms/formatters.go create mode 100644 forms/forms.go create mode 100644 forms/forms_test.go create mode 100644 forms/validators.go create mode 100644 go.mod create mode 100644 sessions/sessions.go create mode 100644 sessions/sessions_test.go create mode 100644 storage/storage.go diff --git a/forms/formatters.go b/forms/formatters.go new file mode 100644 index 0000000..7d31664 --- /dev/null +++ b/forms/formatters.go @@ -0,0 +1,15 @@ +package forms + +import ( + "strings" +) + +func capitalize(s string) string { + words := strings.Fields(s) + for i, word := range words { + if len(word) > 0 { + words[i] = strings.ToUpper(string(word[0])) + strings.ToLower(word[1:]) + } + } + return strings.Join(words, " ") +} diff --git a/forms/forms.go b/forms/forms.go new file mode 100644 index 0000000..09e2fa5 --- /dev/null +++ b/forms/forms.go @@ -0,0 +1,197 @@ +package forms + +import ( + "fmt" + "net/http" + "net/url" + "reflect" + "strconv" + "strings" +) + +type Formatter func(string) string + +var Formatters = map[string]Formatter{ + "trim": strings.TrimSpace, + "lower": strings.ToLower, + "upper": strings.ToUpper, + "capitalize": capitalize, +} + +type Validator func(fieldName string, value any, param string) error + +var Validators = map[string]Validator{ + "nonzero": nonzero, + "minlen": minlen, + "email": email, +} + +func FormToStruct[T any](r *http.Request) (T, error) { + var target T + if err := r.ParseForm(); err != nil { + return target, fmt.Errorf("error parsing form: %v", err) + } + err := UrlValuesToStruct(r.Form, &target) + return target, err +} + +func UrlValuesToStruct(form url.Values, dst any) error { + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + return fmt.Errorf("dst must be a pointer to a struct") + } + + v = v.Elem() + t := v.Type() + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + fieldValue := v.Field(i) + + if !fieldValue.CanSet() { + continue + } + + key := field.Tag.Get("form") + required := field.Tag.Get("req") == "1" + formatters := parseFormatters(field.Tag.Get("fmt")) + validateTags := parseValidators(field.Tag.Get("validate")) + + values, ok := form[key] + if !ok || len(values) == 0 { + if required { + return fmt.Errorf("missing required form field: %s", key) + } + continue + } + + for i := range values { + for _, fmtFunc := range formatters { + values[i] = fmtFunc(values[i]) + } + } + + fieldKind := fieldValue.Kind() + + if fieldKind == reflect.Slice { + elemKind := field.Type.Elem().Kind() + + castedSlice, err := castStringSliceToType(values, elemKind) + if err != nil { + return fmt.Errorf("field '%s': %v", field.Name, err) + } + + sliceValue := reflect.MakeSlice(field.Type, len(castedSlice), len(castedSlice)) + for i, val := range castedSlice { + sliceValue.Index(i).Set(reflect.ValueOf(val).Convert(field.Type.Elem())) + } + fieldValue.Set(sliceValue) + } else { + if len(values) != 1 { + return fmt.Errorf("field '%s' expects a single value", field.Name) + } + + castedVal, err := castStringToType(values[0], fieldKind) + if err != nil { + return fmt.Errorf("field '%s': %v", field.Name, err) + } + fieldValue.Set(reflect.ValueOf(castedVal).Convert(field.Type)) + } + + var finalValue any + if fieldKind == reflect.Slice { + finalValue = fieldValue.Interface() + } else { + finalValue = fieldValue.Interface() + } + + for _, validator := range validateTags { + if fn, ok := Validators[validator.Name]; ok { + if err := fn(field.Name, finalValue, validator.Param); err != nil { + return err + } + } + } + } + + return nil +} + +func castStringSliceToType(input []string, kind reflect.Kind) ([]any, error) { + var output []any + + for _, value := range input { + cast, err := castStringToType(value, kind) + if err != nil { + return nil, err + } + + output = append(output, cast) + } + + return output, nil +} + +func castStringToType(value string, kind reflect.Kind) (any, error) { + switch kind { + case reflect.String: + return value, nil + case reflect.Int, reflect.Int64: + new, err := strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("failed to cast string to integer: %v", err) + } + return new, nil + case reflect.Float32, reflect.Float64: + new, err := strconv.ParseFloat(value, 64) + if err != nil { + return nil, fmt.Errorf("failed to cast string to float64: %v", err) + } + return new, nil + case reflect.Bool: + new, err := strconv.ParseBool(value) + if err != nil { + return nil, fmt.Errorf("failed to cast string to boolean: %v", err) + } + return new, nil + default: + return nil, fmt.Errorf("unsupported kind: %s", kind) + } +} + +func parseFormatters(tag string) []func(string) string { + if tag == "" { + return nil + } + parts := strings.Split(tag, ",") + var fns []func(string) string + for _, p := range parts { + if fn, ok := Formatters[strings.TrimSpace(p)]; ok { + fns = append(fns, fn) + } + } + return fns +} + +func parseValidators(tag string) []struct { + Name string + Param string +} { + if tag == "" { + return nil + } + var result []struct { + Name string + Param string + } + parts := strings.SplitSeq(tag, ",") + for part := range parts { + pair := strings.SplitN(part, ":", 2) + if len(pair) == 2 { + result = append(result, struct{ Name, Param string }{pair[0], pair[1]}) + } else { + result = append(result, struct{ Name, Param string }{pair[0], ""}) + } + } + return result +} diff --git a/forms/forms_test.go b/forms/forms_test.go new file mode 100644 index 0000000..fecd0d9 --- /dev/null +++ b/forms/forms_test.go @@ -0,0 +1,59 @@ +package forms_test + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "git.tavo.one/tavo/axiom/forms" +) + +func TestParseFormToStruct(t *testing.T) { + type MyForm struct { + Name string `form:"name" fmt:"trim" validate:"nonzero,minlen:5"` + Email string `form:"email" fmt:"trim,lower"` + Tags []string `form:"tags" fmt:"trim"` + Enabled bool `form:"enabled"` + Age int `form:"age"` + Scores []float64 `form:"scores"` + } + + data := url.Values{} + data.Set("name", " Alice ") + data.Set("email", "ALICE@EXAMPLE.COM") + data.Add("tags", " go ") + data.Add("tags", "web") + data.Set("enabled", "true") + data.Set("age", "28") + data.Add("scores", "98.6") + data.Add("scores", "80.6") + + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(data.Encode())) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + form, err := forms.FormToStruct[MyForm](r) + if err != nil { + t.Fatalf("FormToStruct returned error: %v", err) + } + + if form.Name != "Alice" { + t.Errorf("expected Name to be 'Alice', got '%s'", form.Name) + } + if form.Email != "alice@example.com" { + t.Errorf("expected Email to be 'alice@example.com', got '%s'", form.Email) + } + if len(form.Tags) != 2 || form.Tags[0] != "go" || form.Tags[1] != "web" { + t.Errorf("expected Tags to be ['go', 'web'], got %v", form.Tags) + } + if !form.Enabled { + t.Errorf("expected Enabled to be true") + } + if form.Age != 28 { + t.Errorf("expected Age to be 28, got %d", form.Age) + } + if len(form.Scores) != 2 || form.Scores[0] != 98.6 || form.Scores[1] != 80.6 { + t.Errorf("expected Scores to be [98.6, 80.6], got %v", form.Scores) + } +} diff --git a/forms/validators.go b/forms/validators.go new file mode 100644 index 0000000..ed9364d --- /dev/null +++ b/forms/validators.go @@ -0,0 +1,41 @@ +package forms + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +func nonzero(field string, value any, _ string) error { + v := reflect.ValueOf(value) + if v.Kind() == reflect.String && v.Len() == 0 { + return fmt.Errorf("field '%s' must not be empty", field) + } + if v.Kind() == reflect.Slice && v.Len() == 0 { + return fmt.Errorf("field '%s' must not be empty", field) + } + if v.Kind() >= reflect.Int && v.Kind() <= reflect.Float64 && v.IsZero() { + return fmt.Errorf("field '%s' must be non-zero", field) + } + return nil +} + +func minlen(field string, value any, param string) error { + min, err := strconv.Atoi(param) + if err != nil { + return fmt.Errorf("invalid minlen param for field '%s'", field) + } + if str, ok := value.(string); ok && len(str) < min { + return fmt.Errorf("field '%s' must be at least %d characters", field, min) + } + return nil +} + +func email(field string, value any, _ string) error { + str, ok := value.(string) + if !ok || !strings.Contains(str, "@") { + return fmt.Errorf("field '%s' must be a valid email address", field) + } + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..fe93397 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.tavo.one/tavo/axiom + +go 1.24.4 diff --git a/sessions/sessions.go b/sessions/sessions.go new file mode 100644 index 0000000..fc276a6 --- /dev/null +++ b/sessions/sessions.go @@ -0,0 +1,165 @@ +package sessions + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "encoding/hex" + "fmt" + "sync" + "time" +) + +type SessionStore[T any] struct { + mu sync.RWMutex + sessions map[string]session[T] + maxSessions int + tokenLength int +} + +type session[T any] struct { + csrfTokenHash string + expires time.Time + data T +} + +func NewStore[T any](tokenLength, maxSessions int) *SessionStore[T] { + return &SessionStore[T]{ + sessions: make(map[string]session[T]), + tokenLength: tokenLength, + maxSessions: maxSessions, + } +} + +// Starts a background goroutine to clean up expired sessions +func (s *SessionStore[T]) StartCleanup(interval time.Duration) { + go func() { + for { + time.Sleep(interval) + s.cleanupExpiredSessions() + } + }() +} + +func (s *SessionStore[T]) cleanupExpiredSessions() { + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + for k, sess := range s.sessions { + if now.After(sess.expires) { + delete(s.sessions, k) + } + } +} + +// Creates a new session and returns the session token and CSRF token. +func (s *SessionStore[T]) New(sessionMaxAge time.Duration, data T) (string, string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if len(s.sessions) >= s.maxSessions { + return "", "", fmt.Errorf("maximum number of sessions reached") + } + + st, err := generateToken(s.tokenLength) + if err != nil { + return "", "", fmt.Errorf("error generating session token: %w", err) + } + + ct, err := generateToken(s.tokenLength) + if err != nil { + return "", "", fmt.Errorf("error generating CSRF token: %w", err) + } + + hst := hash(st) + hct := hash(ct) + + s.sessions[hst] = session[T]{ + csrfTokenHash: hct, + expires: time.Now().Add(sessionMaxAge), + data: data, + } + + return st, ct, nil +} + +// Validates the session and CSRF tokens, and returns new rotated tokens and session data. +func (s *SessionStore[T]) Validate(st, ct string) (string, string, T, error) { + var zero T + + hst := hash(st) + + s.mu.RLock() + sess, ok := s.sessions[hst] + s.mu.RUnlock() + + if !ok { + return "", "", zero, fmt.Errorf("invalid session token") + } + + if time.Now().After(sess.expires) { + s.mu.Lock() + delete(s.sessions, hst) + s.mu.Unlock() + return "", "", zero, fmt.Errorf("session expired") + } + + if subtle.ConstantTimeCompare([]byte(sess.csrfTokenHash), []byte(hash(ct))) != 1 { + return "", "", zero, fmt.Errorf("invalid CSRF token") + } + + // Rotate session tokens + newSt, err := generateToken(s.tokenLength) + if err != nil { + return "", "", zero, fmt.Errorf("error generating new session token: %w", err) + } + newCt, err := generateToken(s.tokenLength) + if err != nil { + return "", "", zero, fmt.Errorf("error generating new CSRF token: %w", err) + } + + newHst := hash(newSt) + newHct := hash(newCt) + + s.mu.Lock() + delete(s.sessions, hst) + s.sessions[newHst] = session[T]{ + csrfTokenHash: newHct, + expires: sess.expires, + data: sess.data, + } + s.mu.Unlock() + + return newSt, newCt, sess.data, nil +} + +// Deletes a session using the session token +func (s *SessionStore[T]) Delete(st string) error { + hst := hash(st) + + s.mu.Lock() + defer s.mu.Unlock() + + if _, ok := s.sessions[hst]; !ok { + return fmt.Errorf("invalid session token") + } + + delete(s.sessions, hst) + return nil +} + +func generateToken(length int) (string, error) { + bytes := make([]byte, length) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(bytes), nil +} + +func hash(token string) string { + sum := sha256.Sum256([]byte(token)) + return hex.EncodeToString(sum[:]) +} diff --git a/sessions/sessions_test.go b/sessions/sessions_test.go new file mode 100644 index 0000000..a2e2ab5 --- /dev/null +++ b/sessions/sessions_test.go @@ -0,0 +1,45 @@ +package sessions_test + +import ( + "testing" + "time" + + "git.tavo.one/tavo/axiom/sessions" +) + +type DummyData struct { + UserID string +} + +func TestMaxSessionLimit(t *testing.T) { + const max = 10 + + store := sessions.NewStore[DummyData](24, max) + + // Try creating maxSessions + for i := 0; i < max; i++ { + _, _, err := store.New(30*time.Minute, DummyData{UserID: "user"}) + if err != nil { + t.Fatalf("unexpected error on session %d: %v", i, err) + } + } + + // Now create one more, which should fail + _, _, err := store.New(30*time.Minute, DummyData{UserID: "extra"}) + if err == nil { + t.Fatal("expected error when exceeding max sessions, but got nil") + } +} + +func BenchmarkSessionCreation(b *testing.B) { + const maxSessions = 1_000_000 + store := sessions.NewStore[DummyData](24, maxSessions) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := store.New(10*time.Minute, DummyData{UserID: "u"}) + if err != nil { + b.Fatalf("failed to create session at iteration %d: %v", i, err) + } + } +} diff --git a/storage/storage.go b/storage/storage.go new file mode 100644 index 0000000..5dbb5af --- /dev/null +++ b/storage/storage.go @@ -0,0 +1,126 @@ +package storage + +import ( + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "strings" +) + +type Client struct { + RootDir string + MaxObjectSize int64 +} + +func New(root string, maxObjectSize int64) (*Client, error) { + if maxObjectSize <= 0 { + return nil, fmt.Errorf("maxObjectSize must be greater than 0") + } + + info, err := os.Stat(root) + if os.IsNotExist(err) { + return nil, fmt.Errorf("root directory does not exist: %s", root) + } + if err != nil { + return nil, fmt.Errorf("failed to stat root directory: %v", err) + } + if !info.IsDir() { + return nil, fmt.Errorf("root path is not a directory: %s", root) + } + + testFile := filepath.Join(root, ".axiom-storage-check") + f, err := os.Create(testFile) + if err != nil { + return nil, fmt.Errorf("root directory is not writable: %v", err) + } + f.Close() + os.Remove(testFile) + + return &Client{ + RootDir: root, + MaxObjectSize: maxObjectSize, + }, nil +} + +func (s *Client) getPath(bucket, key string) string { + return filepath.Join(s.RootDir, bucket, key) +} + +func (s *Client) CreateBucket(bucket string) error { + path := filepath.Join(s.RootDir, bucket) + return os.MkdirAll(path, 0755) +} + +func (s *Client) PutObject(bucket, key string, reader io.Reader) error { + fullPath := s.getPath(bucket, key) + + if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil { + return err + } + + f, err := os.Create(fullPath) + if err != nil { + return err + } + defer f.Close() + + limitedReader := io.LimitReader(reader, s.MaxObjectSize+1) + written, err := io.Copy(f, limitedReader) + if err != nil { + return err + } + + if written > s.MaxObjectSize { + os.Remove(fullPath) + return fmt.Errorf("object size exceeds maximum allowed: %d bytes", s.MaxObjectSize) + } + + return nil +} + +func (s *Client) GetObject(bucket, key string) (*os.File, error) { + fullPath := s.getPath(bucket, key) + return os.Open(fullPath) +} + +func (s *Client) DeleteObject(bucket, key string) error { + fullPath := s.getPath(bucket, key) + return os.Remove(fullPath) +} + +func DetectFileType(file multipart.File) (string, []byte, error) { + const sniffLen = 512 + buffer := make([]byte, sniffLen) + + n, err := file.Read(buffer) + if err != nil && err != io.EOF { + return "", nil, err + } + + if seeker, ok := file.(io.Seeker); ok { + _, _ = seeker.Seek(0, io.SeekStart) + } else { + return "", nil, fmt.Errorf("file is not seekable") + } + + mimeType := http.DetectContentType(buffer[:n]) + return mimeType, buffer[:n], nil +} + +func VerifyFileType(file multipart.File, allowedMimeTypes []string) error { + mimeType, _, err := DetectFileType(file) + if err != nil { + return err + } + + for _, mt := range allowedMimeTypes { + if strings.EqualFold(mimeType, mt) { + return nil + } + } + + return fmt.Errorf("invalid file type: %s", mimeType) +}