package sessions import ( "crypto/rand" "crypto/sha256" "crypto/subtle" "encoding/base64" "encoding/hex" "fmt" "sync" "time" ) type SessionStore[T any] struct { mu sync.RWMutex sessions map[string]session[T] maxSessions int tokenLength int } type session[T any] struct { csrfTokenHash string expires time.Time data T } func NewStore[T any](tokenLength, maxSessions int) *SessionStore[T] { return &SessionStore[T]{ sessions: make(map[string]session[T]), tokenLength: tokenLength, maxSessions: maxSessions, } } // Starts a background goroutine to clean up expired sessions func (s *SessionStore[T]) StartCleanup(interval time.Duration) { go func() { for { time.Sleep(interval) s.cleanupExpiredSessions() } }() } func (s *SessionStore[T]) cleanupExpiredSessions() { now := time.Now() s.mu.Lock() defer s.mu.Unlock() for k, sess := range s.sessions { if now.After(sess.expires) { delete(s.sessions, k) } } } // Creates a new session and returns the session token and CSRF token. func (s *SessionStore[T]) New(sessionMaxAge time.Duration, data T) (string, string, error) { s.mu.Lock() defer s.mu.Unlock() if len(s.sessions) >= s.maxSessions { return "", "", fmt.Errorf("maximum number of sessions reached") } st, err := generateToken(s.tokenLength) if err != nil { return "", "", fmt.Errorf("error generating session token: %w", err) } ct, err := generateToken(s.tokenLength) if err != nil { return "", "", fmt.Errorf("error generating CSRF token: %w", err) } hst := hash(st) hct := hash(ct) s.sessions[hst] = session[T]{ csrfTokenHash: hct, expires: time.Now().Add(sessionMaxAge), data: data, } return st, ct, nil } // Validates the session and CSRF tokens, and returns new rotated tokens and session data. func (s *SessionStore[T]) Validate(st, ct string) (string, string, T, error) { var zero T hst := hash(st) s.mu.RLock() sess, ok := s.sessions[hst] s.mu.RUnlock() if !ok { return "", "", zero, fmt.Errorf("invalid session token") } if time.Now().After(sess.expires) { s.mu.Lock() delete(s.sessions, hst) s.mu.Unlock() return "", "", zero, fmt.Errorf("session expired") } if subtle.ConstantTimeCompare([]byte(sess.csrfTokenHash), []byte(hash(ct))) != 1 { return "", "", zero, fmt.Errorf("invalid CSRF token") } // Rotate session tokens newSt, err := generateToken(s.tokenLength) if err != nil { return "", "", zero, fmt.Errorf("error generating new session token: %w", err) } newCt, err := generateToken(s.tokenLength) if err != nil { return "", "", zero, fmt.Errorf("error generating new CSRF token: %w", err) } newHst := hash(newSt) newHct := hash(newCt) s.mu.Lock() delete(s.sessions, hst) s.sessions[newHst] = session[T]{ csrfTokenHash: newHct, expires: sess.expires, data: sess.data, } s.mu.Unlock() return newSt, newCt, sess.data, nil } // Deletes a session using the session token func (s *SessionStore[T]) Delete(st string) error { hst := hash(st) s.mu.Lock() defer s.mu.Unlock() if _, ok := s.sessions[hst]; !ok { return fmt.Errorf("invalid session token") } delete(s.sessions, hst) return nil } func generateToken(length int) (string, error) { bytes := make([]byte, length) if _, err := rand.Read(bytes); err != nil { return "", err } return base64.URLEncoding.EncodeToString(bytes), nil } func hash(token string) string { sum := sha256.Sum256([]byte(token)) return hex.EncodeToString(sum[:]) }