axiom/sessions/sessions.go
2025-06-30 01:48:35 -06:00

165 lines
3.5 KiB
Go

package sessions
import (
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt"
"sync"
"time"
)
// SessionStore is an in-memory session table that associates a caller-defined
// payload of type T with each session. Entries are keyed by the hash of the
// session token rather than by the token itself.
type SessionStore[T any] struct {
	// tokenLength is the number of random bytes drawn for each generated
	// session or CSRF token.
	tokenLength int
	// maxSessions is the entry count at which New starts refusing to create
	// further sessions.
	maxSessions int

	// mu guards sessions.
	mu       sync.RWMutex
	sessions map[string]session[T]
}
// session is the per-session record kept in a SessionStore.
type session[T any] struct {
	// data is the caller-supplied payload attached to the session.
	data T
	// expires is the instant after which the session is no longer accepted.
	expires time.Time
	// csrfTokenHash is the hash of the CSRF token paired with this session;
	// the CSRF token itself is not stored.
	csrfTokenHash string
}
// NewStore returns an empty SessionStore.
//
// tokenLength is the number of random bytes used for every generated token,
// and maxSessions is the number of stored sessions at which New begins to
// return an error. Neither value is validated here.
func NewStore[T any](tokenLength, maxSessions int) *SessionStore[T] {
	store := new(SessionStore[T])
	store.tokenLength = tokenLength
	store.maxSessions = maxSessions
	store.sessions = map[string]session[T]{}
	return store
}
// StartCleanup launches a background goroutine that alternates between
// sleeping for interval and purging expired sessions.
//
// The goroutine loops forever: there is no way to stop it, and every call
// starts an additional one.
func (s *SessionStore[T]) StartCleanup(interval time.Duration) {
	sweepForever := func() {
		for {
			time.Sleep(interval)
			s.cleanupExpiredSessions()
		}
	}
	go sweepForever()
}
// cleanupExpiredSessions removes every session whose expiry lies strictly
// before the moment the sweep began. The write lock is held for the whole
// pass.
func (s *SessionStore[T]) cleanupExpiredSessions() {
	// Sample the clock before waiting on the lock so the cutoff reflects when
	// the sweep was requested.
	cutoff := time.Now()

	s.mu.Lock()
	defer s.mu.Unlock()

	for key := range s.sessions {
		if s.sessions[key].expires.Before(cutoff) {
			delete(s.sessions, key)
		}
	}
}
// New registers a session carrying data that expires sessionMaxAge from now.
//
// It returns the session token followed by the CSRF token. Only hashes of the
// two tokens are kept in the store. An error is returned when the store
// already holds maxSessions entries or when token generation fails.
func (s *SessionStore[T]) New(sessionMaxAge time.Duration, data T) (string, string, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.maxSessions <= len(s.sessions) {
		return "", "", fmt.Errorf("maximum number of sessions reached")
	}

	sessionToken, err := generateToken(s.tokenLength)
	if err != nil {
		return "", "", fmt.Errorf("error generating session token: %w", err)
	}
	csrfToken, err := generateToken(s.tokenLength)
	if err != nil {
		return "", "", fmt.Errorf("error generating CSRF token: %w", err)
	}

	entry := session[T]{
		csrfTokenHash: hash(csrfToken),
		expires:       time.Now().Add(sessionMaxAge),
		data:          data,
	}
	s.sessions[hash(sessionToken)] = entry

	return sessionToken, csrfToken, nil
}
// Validate checks the session token st and the CSRF token ct and, on success,
// rotates both.
//
// On success it returns the new session token, the new CSRF token and the
// session data; the old session token stops being valid. The rotated session
// keeps the original expiry time.
//
// It returns an error when the session token is unknown (including when a
// concurrent Validate or Delete consumed it first), when the session has
// expired (the entry is removed), when the CSRF token does not match, or when
// generating replacement tokens fails.
func (s *SessionStore[T]) Validate(st, ct string) (string, string, T, error) {
	var zero T
	hst := hash(st)

	// Pre-check under the read lock so requests carrying unknown or stale
	// tokens are rejected without contending for the write lock.
	s.mu.RLock()
	sess, ok := s.sessions[hst]
	s.mu.RUnlock()
	if !ok {
		return "", "", zero, fmt.Errorf("invalid session token")
	}
	if time.Now().After(sess.expires) {
		s.mu.Lock()
		delete(s.sessions, hst)
		s.mu.Unlock()
		return "", "", zero, fmt.Errorf("session expired")
	}
	// Compare hashes in constant time so the check does not leak how much of
	// the CSRF token matched.
	if subtle.ConstantTimeCompare([]byte(sess.csrfTokenHash), []byte(hash(ct))) != 1 {
		return "", "", zero, fmt.Errorf("invalid CSRF token")
	}

	// Generate the replacement tokens outside the write lock.
	newSt, err := generateToken(s.tokenLength)
	if err != nil {
		return "", "", zero, fmt.Errorf("error generating new session token: %w", err)
	}
	newCt, err := generateToken(s.tokenLength)
	if err != nil {
		return "", "", zero, fmt.Errorf("error generating new CSRF token: %w", err)
	}
	rotated := session[T]{
		csrfTokenHash: hash(newCt),
		expires:       sess.expires,
		data:          sess.data,
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	// The lock was released after the pre-check, so another goroutine may have
	// rotated or deleted this session in the meantime. Without this re-check a
	// single token could be exchanged for several live sessions, or a deleted
	// session could be re-created under a new token.
	if _, ok := s.sessions[hst]; !ok {
		return "", "", zero, fmt.Errorf("invalid session token")
	}
	delete(s.sessions, hst)
	s.sessions[hash(newSt)] = rotated

	return newSt, newCt, sess.data, nil
}
// Delete removes the session identified by the session token st. It returns
// an error when no session is stored for that token.
func (s *SessionStore[T]) Delete(st string) error {
	key := hash(st)

	s.mu.Lock()
	defer s.mu.Unlock()

	_, found := s.sessions[key]
	if found {
		delete(s.sessions, key)
		return nil
	}
	return fmt.Errorf("invalid session token")
}
// generateToken reads length bytes from crypto/rand and returns them encoded
// with the padded URL-safe base64 alphabet.
//
// length must be positive. A zero length would yield the same empty token on
// every call and a negative length would panic when allocating the buffer, so
// both are rejected with an error. An error from the random source is
// returned unchanged.
func generateToken(length int) (string, error) {
	if length <= 0 {
		return "", fmt.Errorf("invalid token length %d: must be positive", length)
	}
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(buf), nil
}
// hash returns the SHA-256 digest of token as a lowercase hexadecimal string.
func hash(token string) string {
	digest := sha256.Sum256([]byte(token))
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return string(encoded)
}