92 lines
2.8 KiB
Go
92 lines
2.8 KiB
Go
package storage
|
|
|
|
import (
|
|
"fmt"
|
|
"time"
|
|
)
|
|
|
|
// CreateSession creates or replaces the session for a user (single session per user)
|
|
func (s *Store) CreateSession(record SessionRecord) error {
|
|
tx, err := s.db.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Delete any existing session for this user
|
|
deleteQuery := `DELETE FROM sessions WHERE user_id = ?`
|
|
if _, err := tx.Exec(deleteQuery, record.UserID); err != nil {
|
|
return fmt.Errorf("failed to delete existing session: %w", err)
|
|
}
|
|
|
|
// Insert new session
|
|
insertQuery := `INSERT INTO sessions (session_id, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)`
|
|
if _, err := tx.Exec(insertQuery, record.SessionID, record.UserID, record.CreatedAt, record.ExpiresAt); err != nil {
|
|
return fmt.Errorf("failed to create session: %w", err)
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// GetSession retrieves a session by ID
|
|
func (s *Store) GetSession(sessionID string) (*SessionRecord, error) {
|
|
var session SessionRecord
|
|
query := `SELECT session_id, user_id, created_at, expires_at FROM sessions WHERE session_id = ?`
|
|
|
|
err := s.db.QueryRow(query, sessionID).Scan(
|
|
&session.SessionID, &session.UserID, &session.CreatedAt, &session.ExpiresAt,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &session, nil
|
|
}
|
|
|
|
// GetSessionByUserID retrieves the active session for a user
|
|
func (s *Store) GetSessionByUserID(userID string) (*SessionRecord, error) {
|
|
var session SessionRecord
|
|
query := `SELECT session_id, user_id, created_at, expires_at FROM sessions WHERE user_id = ?`
|
|
|
|
err := s.db.QueryRow(query, userID).Scan(
|
|
&session.SessionID, &session.UserID, &session.CreatedAt, &session.ExpiresAt,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &session, nil
|
|
}
|
|
|
|
// DeleteSession removes a session
|
|
func (s *Store) DeleteSession(sessionID string) error {
|
|
query := `DELETE FROM sessions WHERE session_id = ?`
|
|
_, err := s.db.Exec(query, sessionID)
|
|
return err
|
|
}
|
|
|
|
// DeleteSessionByUserID removes all sessions for a user
|
|
func (s *Store) DeleteSessionByUserID(userID string) error {
|
|
query := `DELETE FROM sessions WHERE user_id = ?`
|
|
_, err := s.db.Exec(query, userID)
|
|
return err
|
|
}
|
|
|
|
// DeleteExpiredSessions removes expired sessions
|
|
func (s *Store) DeleteExpiredSessions() (int64, error) {
|
|
query := `DELETE FROM sessions WHERE expires_at < ?`
|
|
result, err := s.db.Exec(query, time.Now().UTC())
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
// IsSessionValid checks if a session exists and is not expired
|
|
func (s *Store) IsSessionValid(sessionID string) (bool, error) {
|
|
var count int
|
|
query := `SELECT COUNT(*) FROM sessions WHERE session_id = ? AND expires_at > ?`
|
|
err := s.db.QueryRow(query, sessionID, time.Now().UTC()).Scan(&count)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return count > 0, nil
|
|
} |