165 lines
3.4 KiB
Go
165 lines
3.4 KiB
Go
package sessions
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
type SessionStore[T any] struct {
|
|
mu sync.RWMutex
|
|
sessions map[string]session[T]
|
|
maxSessions int
|
|
tokenLength int
|
|
}
|
|
|
|
// session is the value stored per session token hash in a SessionStore.
type session[T any] struct {
	// csrfTokenHash is the hex-encoded SHA-256 hash of the CSRF token that
	// is currently paired with this session.
	csrfTokenHash string
	// expires is the instant after which the session is treated as expired.
	expires time.Time
	// data is the caller-supplied payload attached to the session.
	data T
}
|
|
|
|
func NewStore[T any](tokenLength, maxSessions int) *SessionStore[T] {
|
|
return &SessionStore[T]{
|
|
sessions: make(map[string]session[T]),
|
|
tokenLength: tokenLength,
|
|
maxSessions: maxSessions,
|
|
}
|
|
}
|
|
|
|
// Starts a background goroutine to clean up expired sessions
|
|
func (s *SessionStore[T]) StartCleanup(interval time.Duration) {
|
|
go func() {
|
|
for {
|
|
time.Sleep(interval)
|
|
s.cleanupExpiredSessions()
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (s *SessionStore[T]) cleanupExpiredSessions() {
|
|
now := time.Now()
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
for k, sess := range s.sessions {
|
|
if now.After(sess.expires) {
|
|
delete(s.sessions, k)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Creates a new session and returns the session token and CSRF token.
|
|
func (s *SessionStore[T]) New(sessionMaxAge time.Duration, data T) (string, string, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if len(s.sessions) >= s.maxSessions {
|
|
return "", "", fmt.Errorf("maximum number of sessions reached")
|
|
}
|
|
|
|
st, err := generateToken(s.tokenLength)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("error generating session token: %w", err)
|
|
}
|
|
|
|
ct, err := generateToken(s.tokenLength)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("error generating CSRF token: %w", err)
|
|
}
|
|
|
|
hst := hash(st)
|
|
hct := hash(ct)
|
|
|
|
s.sessions[hst] = session[T]{
|
|
csrfTokenHash: hct,
|
|
expires: time.Now().Add(sessionMaxAge),
|
|
data: data,
|
|
}
|
|
|
|
return st, ct, nil
|
|
}
|
|
|
|
// Validates the session and CSRF tokens, and returns new rotated tokens and session data.
|
|
func (s *SessionStore[T]) Validate(st, ct string) (string, string, T, error) {
|
|
var zero T
|
|
|
|
hst := hash(st)
|
|
|
|
s.mu.RLock()
|
|
sess, ok := s.sessions[hst]
|
|
s.mu.RUnlock()
|
|
|
|
if !ok {
|
|
return "", "", zero, fmt.Errorf("invalid session token")
|
|
}
|
|
|
|
if time.Now().After(sess.expires) {
|
|
s.mu.Lock()
|
|
delete(s.sessions, hst)
|
|
s.mu.Unlock()
|
|
return "", "", zero, fmt.Errorf("session expired")
|
|
}
|
|
|
|
if subtle.ConstantTimeCompare([]byte(sess.csrfTokenHash), []byte(hash(ct))) != 1 {
|
|
return "", "", zero, fmt.Errorf("invalid CSRF token")
|
|
}
|
|
|
|
// Rotate session tokens
|
|
newSt, err := generateToken(s.tokenLength)
|
|
if err != nil {
|
|
return "", "", zero, fmt.Errorf("error generating new session token: %w", err)
|
|
}
|
|
newCt, err := generateToken(s.tokenLength)
|
|
if err != nil {
|
|
return "", "", zero, fmt.Errorf("error generating new CSRF token: %w", err)
|
|
}
|
|
|
|
newHst := hash(newSt)
|
|
newHct := hash(newCt)
|
|
|
|
s.mu.Lock()
|
|
delete(s.sessions, hst)
|
|
s.sessions[newHst] = session[T]{
|
|
csrfTokenHash: newHct,
|
|
expires: sess.expires,
|
|
data: sess.data,
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
return newSt, newCt, sess.data, nil
|
|
}
|
|
|
|
// Deletes a session using the session token
|
|
func (s *SessionStore[T]) Delete(st string) error {
|
|
hst := hash(st)
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if _, ok := s.sessions[hst]; !ok {
|
|
return fmt.Errorf("invalid session token")
|
|
}
|
|
|
|
delete(s.sessions, hst)
|
|
return nil
|
|
}
|
|
|
|
// generateToken returns a token made of length bytes read from crypto/rand,
// encoded with the padded URL-safe base64 alphabet.
//
// length must be positive: a negative length would panic when allocating the
// buffer, and a zero length would yield an empty token that is identical for
// every caller. Both cases are reported as an error instead. An error from
// the random source is returned unchanged.
func generateToken(length int) (string, error) {
	if length <= 0 {
		return "", fmt.Errorf("token length must be positive, got %d", length)
	}

	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(buf), nil
}
|
|
|
|
// hash returns the lowercase hex encoding of the SHA-256 digest of token.
func hash(token string) string {
	digest := sha256.Sum256([]byte(token))
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return string(encoded)
}
|