v0.5.0 user support with auth added, tests and doc updated
This commit is contained in:
232
internal/http/auth.go
Normal file
232
internal/http/auth.go
Normal file
@ -0,0 +1,232 @@
|
||||
// FILE: internal/http/auth.go
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"chess/internal/core"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
var emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
var usernameRegex = regexp.MustCompile(`^[a-zA-Z0-9_]{1,40}$`)
|
||||
|
||||
// RegisterRequest defines the user registration payload
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" validate:"required,min=1,max=40"`
|
||||
Email string `json:"email" validate:"omitempty,max=255"`
|
||||
Password string `json:"password" validate:"required,min=8,max=128"`
|
||||
}
|
||||
|
||||
// LoginRequest defines the authentication payload
|
||||
type LoginRequest struct {
|
||||
Identifier string `json:"identifier" validate:"required"` // username or email
|
||||
Password string `json:"password" validate:"required"`
|
||||
}
|
||||
|
||||
// AuthResponse contains JWT token and user information
|
||||
type AuthResponse struct {
|
||||
Token string `json:"token"`
|
||||
UserID string `json:"userId"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
}
|
||||
|
||||
// UserResponse contains current user information
|
||||
type UserResponse struct {
|
||||
UserID string `json:"userId"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// RegisterHandler creates a new user account
|
||||
func (h *HTTPHandler) RegisterHandler(c *fiber.Ctx) error {
|
||||
var req RegisterRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(core.ErrorResponse{
|
||||
Error: "invalid request body",
|
||||
Code: core.ErrInvalidRequest,
|
||||
Details: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Validate username format
|
||||
if !usernameRegex.MatchString(req.Username) {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(core.ErrorResponse{
|
||||
Error: "invalid username format",
|
||||
Code: core.ErrInvalidRequest,
|
||||
Details: "username must be 1-40 characters, alphanumeric and underscore only",
|
||||
})
|
||||
}
|
||||
|
||||
// Validate email format if provided
|
||||
if req.Email != "" && !emailRegex.MatchString(req.Email) {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(core.ErrorResponse{
|
||||
Error: "invalid email format",
|
||||
Code: core.ErrInvalidRequest,
|
||||
Details: "email must be a valid email address",
|
||||
})
|
||||
}
|
||||
|
||||
// Validate password strength
|
||||
if err := validatePassword(req.Password); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(core.ErrorResponse{
|
||||
Error: "weak password",
|
||||
Code: core.ErrInvalidRequest,
|
||||
Details: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Normalize for case-insensitive storage
|
||||
req.Username = strings.ToLower(req.Username)
|
||||
if req.Email != "" {
|
||||
req.Email = strings.ToLower(req.Email)
|
||||
}
|
||||
|
||||
// Create user
|
||||
user, err := h.svc.CreateUser(req.Username, req.Email, req.Password)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "already exists") {
|
||||
return c.Status(fiber.StatusConflict).JSON(core.ErrorResponse{
|
||||
Error: "user already exists",
|
||||
Code: core.ErrInvalidRequest,
|
||||
Details: "username or email already taken",
|
||||
})
|
||||
}
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(core.ErrorResponse{
|
||||
Error: "failed to create user",
|
||||
Code: core.ErrInternalError,
|
||||
})
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token, err := h.svc.GenerateUserToken(user.UserID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(core.ErrorResponse{
|
||||
Error: "failed to generate token",
|
||||
Code: core.ErrInternalError,
|
||||
})
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(AuthResponse{
|
||||
Token: token,
|
||||
UserID: user.UserID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
ExpiresAt: time.Now().Add(7 * 24 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
// validatePassword checks password strength requirements
|
||||
func validatePassword(password string) error {
|
||||
const (
|
||||
minPasswordLength = 8
|
||||
maxPasswordLength = 128
|
||||
)
|
||||
if len(password) < minPasswordLength {
|
||||
return fmt.Errorf("password must be at least 8 characters")
|
||||
}
|
||||
if len(password) > maxPasswordLength {
|
||||
return fmt.Errorf("password must not exceed 128 characters")
|
||||
}
|
||||
|
||||
// Check for at least one letter and one number
|
||||
hasLetter := false
|
||||
hasNumber := false
|
||||
for _, r := range password {
|
||||
switch {
|
||||
case unicode.IsLetter(r):
|
||||
hasLetter = true
|
||||
case unicode.IsNumber(r):
|
||||
hasNumber = true
|
||||
}
|
||||
if hasLetter && hasNumber {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasLetter || !hasNumber {
|
||||
return fmt.Errorf("password must contain at least one letter and one number")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoginHandler authenticates user and returns JWT token
|
||||
func (h *HTTPHandler) LoginHandler(c *fiber.Ctx) error {
|
||||
var req LoginRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(core.ErrorResponse{
|
||||
Error: "invalid request body",
|
||||
Code: core.ErrInvalidRequest,
|
||||
Details: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Normalize identifier for case-insensitive lookup
|
||||
req.Identifier = strings.ToLower(req.Identifier)
|
||||
|
||||
// Authenticate user
|
||||
user, err := h.svc.AuthenticateUser(req.Identifier, req.Password)
|
||||
if err != nil {
|
||||
// Always return same error to prevent user enumeration
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(core.ErrorResponse{
|
||||
Error: "invalid credentials",
|
||||
Code: core.ErrInvalidRequest,
|
||||
})
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token, err := h.svc.GenerateUserToken(user.UserID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(core.ErrorResponse{
|
||||
Error: "failed to generate token",
|
||||
Code: core.ErrInternalError,
|
||||
})
|
||||
}
|
||||
|
||||
// Update last login
|
||||
// TODO: for now, non-blocking if login time update fails, log/block in the future
|
||||
_ = h.svc.UpdateLastLogin(user.UserID)
|
||||
|
||||
return c.JSON(AuthResponse{
|
||||
Token: token,
|
||||
UserID: user.UserID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
ExpiresAt: time.Now().Add(7 * 24 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
// GetCurrentUserHandler returns authenticated user information
|
||||
func (h *HTTPHandler) GetCurrentUserHandler(c *fiber.Ctx) error {
|
||||
userID, ok := c.Locals("userID").(string)
|
||||
if !ok || userID == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(core.ErrorResponse{
|
||||
Error: "unauthorized",
|
||||
Code: core.ErrInvalidRequest,
|
||||
})
|
||||
}
|
||||
|
||||
user, err := h.svc.GetUserByID(userID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusNotFound).JSON(core.ErrorResponse{
|
||||
Error: "user not found",
|
||||
Code: core.ErrInvalidRequest,
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(UserResponse{
|
||||
UserID: user.UserID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
CreatedAt: user.CreatedAt,
|
||||
})
|
||||
}
|
||||
@ -2,13 +2,14 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"chess/internal/core"
|
||||
"chess/internal/processor"
|
||||
"chess/internal/service"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"chess/internal/core"
|
||||
"chess/internal/processor"
|
||||
"chess/internal/service"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/limiter"
|
||||
@ -47,27 +48,66 @@ func NewFiberApp(proc *processor.Processor, svc *service.Service, devMode bool)
|
||||
app.Use(cors.New(cors.Config{
|
||||
AllowOrigins: "*",
|
||||
AllowMethods: "GET,POST,PUT,DELETE,OPTIONS",
|
||||
AllowHeaders: "Origin,Content-Type,Accept",
|
||||
AllowHeaders: "Origin,Content-Type,Accept,Authorization",
|
||||
}))
|
||||
|
||||
// Health check (no rate limit)
|
||||
app.Get("/health", h.Health)
|
||||
|
||||
// API v1 routes with rate limiting
|
||||
// API v1 routes
|
||||
api := app.Group("/api/v1")
|
||||
|
||||
// Rate limiter: 10/20 req/sec per IP with expiry
|
||||
// Auth routes with specific rate limiting
|
||||
auth := api.Group("/auth")
|
||||
|
||||
// Register: 5 req/min per IP
|
||||
auth.Post("/register", limiter.New(limiter.Config{
|
||||
Max: 5,
|
||||
Expiration: 1 * time.Minute,
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
},
|
||||
LimitReached: func(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusTooManyRequests).JSON(core.ErrorResponse{
|
||||
Error: "rate limit exceeded",
|
||||
Code: core.ErrRateLimitExceeded,
|
||||
Details: "5 registrations per minute allowed",
|
||||
})
|
||||
},
|
||||
}), h.RegisterHandler)
|
||||
|
||||
// Login: 10 req/min per IP
|
||||
auth.Post("/login", limiter.New(limiter.Config{
|
||||
Max: 10,
|
||||
Expiration: 1 * time.Minute,
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
},
|
||||
LimitReached: func(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusTooManyRequests).JSON(core.ErrorResponse{
|
||||
Error: "rate limit exceeded",
|
||||
Code: core.ErrRateLimitExceeded,
|
||||
Details: "10 login attempts per minute allowed",
|
||||
})
|
||||
},
|
||||
}), h.LoginHandler)
|
||||
|
||||
// Create token validator closure
|
||||
validateToken := svc.ValidateToken
|
||||
|
||||
// Current user (requires auth)
|
||||
auth.Get("/me", AuthRequired(validateToken), h.GetCurrentUserHandler)
|
||||
|
||||
// Game routes with standard rate limiting
|
||||
maxReq := rateLimitRate
|
||||
if devMode {
|
||||
maxReq = rateLimitRate * 2 // Loosen rate limiter for testing
|
||||
maxReq = rateLimitRate * 2
|
||||
}
|
||||
api.Use(limiter.New(limiter.Config{
|
||||
Max: maxReq, // Allow requests per second
|
||||
Expiration: 1 * time.Second, // Per second
|
||||
Max: maxReq,
|
||||
Expiration: 1 * time.Second,
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
// Check X-Forwarded-For first, then X-Real-IP, then RemoteIP
|
||||
if xff := c.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP from X-Forwarded-For chain
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
@ -82,9 +122,6 @@ func NewFiberApp(proc *processor.Processor, svc *service.Service, devMode bool)
|
||||
Details: fmt.Sprintf("%d requests per second allowed", maxReq),
|
||||
})
|
||||
},
|
||||
Storage: nil, // Use in-memory storage (default)
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
}))
|
||||
|
||||
// Content-Type validation for POST and PUT requests
|
||||
@ -93,8 +130,8 @@ func NewFiberApp(proc *processor.Processor, svc *service.Service, devMode bool)
|
||||
// Middleware validation for sanitization
|
||||
api.Use(validationMiddleware)
|
||||
|
||||
// Register game routes
|
||||
api.Post("/games", h.CreateGame)
|
||||
// Register game routes with auth middleware
|
||||
api.Post("/games", OptionalAuth(validateToken), h.CreateGame) // Optional auth for player ID association
|
||||
api.Put("/games/:gameId/players", h.ConfigurePlayers)
|
||||
api.Get("/games/:gameId", h.GetGame)
|
||||
api.Delete("/games/:gameId", h.DeleteGame)
|
||||
@ -179,8 +216,12 @@ func (h *HTTPHandler) CreateGame(c *fiber.Ctx) error {
|
||||
var req core.CreateGameRequest
|
||||
req = *(validatedBody.(*core.CreateGameRequest))
|
||||
|
||||
// Let processor generate game ID via service
|
||||
// Retrieve authenticated user ID if available
|
||||
userID, _ := c.Locals("userID").(string)
|
||||
|
||||
// Generate game ID via service with optional user context
|
||||
cmd := processor.NewCreateGameCommand(req)
|
||||
cmd.UserID = userID // Add user ID to command if authenticated
|
||||
|
||||
resp := h.proc.Execute(cmd)
|
||||
|
||||
|
||||
63
internal/http/middleware.go
Normal file
63
internal/http/middleware.go
Normal file
@ -0,0 +1,63 @@
|
||||
// FILE: internal/http/middleware.go
|
||||
package http
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"chess/internal/core"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// TokenValidator validates JWT tokens
|
||||
type TokenValidator func(token string) (userID string, claims map[string]any, err error)
|
||||
|
||||
// AuthRequired enforces JWT authentication for protected endpoints
|
||||
func AuthRequired(validateToken TokenValidator) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
token := extractBearerToken(c.Get("Authorization"))
|
||||
if token == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(core.ErrorResponse{
|
||||
Error: "missing authorization token",
|
||||
Code: core.ErrInvalidRequest,
|
||||
})
|
||||
}
|
||||
|
||||
userID, _, err := validateToken(token)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(core.ErrorResponse{
|
||||
Error: "invalid or expired token",
|
||||
Code: core.ErrInvalidRequest,
|
||||
})
|
||||
}
|
||||
|
||||
c.Locals("userID", userID)
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalAuth validates JWT if present but allows anonymous access
|
||||
func OptionalAuth(validateToken TokenValidator) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
token := extractBearerToken(c.Get("Authorization"))
|
||||
if token == "" {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
userID, _, err := validateToken(token)
|
||||
if err == nil {
|
||||
c.Locals("userID", userID)
|
||||
}
|
||||
// Continue regardless of token validity
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// extractBearerToken extracts JWT token from Authorization header
|
||||
func extractBearerToken(header string) string {
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(header, prefix) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimPrefix(header, prefix)
|
||||
}
|
||||
@ -21,6 +21,7 @@ const (
|
||||
// Command is a unified structure for all processor operations
|
||||
type Command struct {
|
||||
Type CommandType
|
||||
UserID string
|
||||
GameID string // For game-specific commands
|
||||
Args interface{} // Command-specific arguments
|
||||
}
|
||||
|
||||
@ -159,8 +159,20 @@ func (p *Processor) handleCreateGame(cmd Command) ProcessorResponse {
|
||||
return p.errorResponse(fmt.Sprintf("FEN parse error: %v", err), core.ErrInvalidRequest)
|
||||
}
|
||||
|
||||
// Create game in service with validated FEN and turn
|
||||
if err = p.svc.CreateGame(gameID, args.White, args.Black, validatedFEN, b.Turn()); err != nil {
|
||||
// Create players with appropriate IDs
|
||||
whitePlayer := core.NewPlayer(args.White, core.ColorWhite)
|
||||
blackPlayer := core.NewPlayer(args.Black, core.ColorBlack)
|
||||
|
||||
// Override player IDs for authenticated human players
|
||||
if args.White.Type == core.PlayerHuman && cmd.UserID != "" {
|
||||
whitePlayer.ID = cmd.UserID
|
||||
}
|
||||
if args.Black.Type == core.PlayerHuman && cmd.UserID != "" {
|
||||
blackPlayer.ID = cmd.UserID
|
||||
}
|
||||
|
||||
// Create game in service with fully-formed players
|
||||
if err = p.svc.CreateGame(gameID, whitePlayer, blackPlayer, validatedFEN, b.Turn()); err != nil {
|
||||
return p.errorResponse(fmt.Sprintf("failed to create game: %v", err), core.ErrInternalError)
|
||||
}
|
||||
|
||||
@ -206,8 +218,12 @@ func (p *Processor) handleConfigurePlayers(cmd Command) ProcessorResponse {
|
||||
return p.errorResponse("cannot change players while computer is calculating", core.ErrInvalidRequest)
|
||||
}
|
||||
|
||||
// Create new player instances
|
||||
whitePlayer := core.NewPlayer(args.White, core.ColorWhite)
|
||||
blackPlayer := core.NewPlayer(args.Black, core.ColorBlack)
|
||||
|
||||
// Update players in service
|
||||
if err = p.svc.UpdatePlayers(cmd.GameID, args.White, args.Black); err != nil {
|
||||
if err = p.svc.UpdatePlayers(cmd.GameID, whitePlayer, blackPlayer); err != nil {
|
||||
return p.errorResponse(fmt.Sprintf("failed to update players: %v", err), core.ErrInternalError)
|
||||
}
|
||||
|
||||
@ -389,7 +405,6 @@ func (p *Processor) handleDeleteGame(cmd Command) ProcessorResponse {
|
||||
return p.errorResponse("game not found", core.ErrGameNotFound)
|
||||
}
|
||||
|
||||
// TODO: gracefully handle deleting game even if pending, discard engine response
|
||||
// Only block deletion if actively computing
|
||||
if g.State() == core.StatePending {
|
||||
return p.errorResponse("cannot delete game while computer move is in progress", core.ErrInvalidRequest)
|
||||
|
||||
188
internal/service/game.go
Normal file
188
internal/service/game.go
Normal file
@ -0,0 +1,188 @@
|
||||
// FILE: internal/service/game.go
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"chess/internal/core"
|
||||
"chess/internal/game"
|
||||
"chess/internal/storage"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// CreateGame registers a new game with pre-constructed players
|
||||
func (s *Service) CreateGame(id string, whitePlayer, blackPlayer *core.Player, initialFEN string, startingTurn core.Color) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.games[id]; exists {
|
||||
return fmt.Errorf("game %s already exists", id)
|
||||
}
|
||||
|
||||
// Store game with provided players
|
||||
s.games[id] = game.New(initialFEN, whitePlayer, blackPlayer, startingTurn)
|
||||
|
||||
// Persist if storage enabled
|
||||
if s.store != nil {
|
||||
record := storage.GameRecord{
|
||||
GameID: id,
|
||||
InitialFEN: initialFEN,
|
||||
WhitePlayerID: whitePlayer.ID,
|
||||
WhiteType: int(whitePlayer.Type),
|
||||
WhiteLevel: whitePlayer.Level,
|
||||
WhiteSearchTime: whitePlayer.SearchTime,
|
||||
BlackPlayerID: blackPlayer.ID,
|
||||
BlackType: int(blackPlayer.Type),
|
||||
BlackLevel: blackPlayer.Level,
|
||||
BlackSearchTime: blackPlayer.SearchTime,
|
||||
StartTimeUTC: time.Now().UTC(),
|
||||
}
|
||||
s.store.RecordNewGame(record)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePlayers replaces players in an existing game
|
||||
func (s *Service) UpdatePlayers(gameID string, whitePlayer, blackPlayer *core.Player) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
// Update the game's players
|
||||
g.UpdatePlayers(whitePlayer, blackPlayer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGame retrieves a game by ID
|
||||
func (s *Service) GetGame(gameID string) (*game.Game, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// GenerateGameID creates a new unique game ID
|
||||
func (s *Service) GenerateGameID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Ensure UUID uniqueness (handle potential conflicts)
|
||||
for {
|
||||
id := uuid.New().String()
|
||||
if _, exists := s.games[id]; !exists {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyMove adds a validated move to the game history
|
||||
func (s *Service) ApplyMove(gameID, moveUCI, newFEN string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
// Determine whose turn it was before this move
|
||||
currentTurn := g.NextTurnColor()
|
||||
nextTurn := core.OppositeColor(currentTurn)
|
||||
|
||||
// Add the new position to game history
|
||||
g.AddSnapshot(newFEN, moveUCI, nextTurn)
|
||||
|
||||
// Persist if storage enabled
|
||||
if s.store != nil {
|
||||
moveNumber := len(g.Moves())
|
||||
record := storage.MoveRecord{
|
||||
GameID: gameID,
|
||||
MoveNumber: moveNumber,
|
||||
MoveUCI: moveUCI,
|
||||
FENAfterMove: newFEN,
|
||||
PlayerColor: currentTurn.String(),
|
||||
MoveTimeUTC: time.Now().UTC(),
|
||||
}
|
||||
s.store.RecordMove(record)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateGameState sets the game's end state (checkmate, stalemate, etc)
|
||||
func (s *Service) UpdateGameState(gameID string, state core.State) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
g.SetState(state)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLastMoveResult stores metadata about the last move
|
||||
func (s *Service) SetLastMoveResult(gameID string, result *game.MoveResult) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
g.SetLastResult(result)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UndoMoves removes the specified number of moves from game history
|
||||
func (s *Service) UndoMoves(gameID string, count int) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
originalMoveCount := len(g.Moves())
|
||||
|
||||
if err := g.UndoMoves(count); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete undone moves from storage if enabled
|
||||
if s.store != nil {
|
||||
remainingMoves := originalMoveCount - count
|
||||
s.store.DeleteUndoneMoves(gameID, remainingMoves)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGame removes a game from memory
|
||||
func (s *Service) DeleteGame(gameID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, ok := s.games[gameID]; !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
delete(s.games, gameID)
|
||||
return nil
|
||||
}
|
||||
@ -2,214 +2,29 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"chess/internal/core"
|
||||
"chess/internal/game"
|
||||
"chess/internal/storage"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Service is a pure state manager for chess games with optional persistence
|
||||
type Service struct {
|
||||
games map[string]*game.Game
|
||||
mu sync.RWMutex
|
||||
store *storage.Store // nil if persistence disabled
|
||||
games map[string]*game.Game
|
||||
mu sync.RWMutex
|
||||
store *storage.Store // nil if persistence disabled
|
||||
jwtSecret []byte
|
||||
}
|
||||
|
||||
// New creates a new service instance with optional storage
|
||||
func New(store *storage.Store) (*Service, error) {
|
||||
func New(store *storage.Store, jwtSecret []byte) (*Service, error) {
|
||||
return &Service{
|
||||
games: make(map[string]*game.Game),
|
||||
store: store,
|
||||
games: make(map[string]*game.Game),
|
||||
store: store,
|
||||
jwtSecret: jwtSecret,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateGame creates game with player configuration
|
||||
func (s *Service) CreateGame(id string, whiteConfig, blackConfig core.PlayerConfig, initialFEN string, startingTurn core.Color) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.games[id]; exists {
|
||||
return fmt.Errorf("game %s already exists", id)
|
||||
}
|
||||
|
||||
// Create players with UUIDs and config
|
||||
whitePlayer := core.NewPlayer(whiteConfig, core.ColorWhite)
|
||||
blackPlayer := core.NewPlayer(blackConfig, core.ColorBlack)
|
||||
|
||||
s.games[id] = game.New(initialFEN, whitePlayer, blackPlayer, startingTurn)
|
||||
|
||||
// Persist if storage enabled
|
||||
if s.store != nil {
|
||||
record := storage.GameRecord{
|
||||
GameID: id,
|
||||
InitialFEN: initialFEN,
|
||||
WhitePlayerID: whitePlayer.ID,
|
||||
WhiteType: int(whitePlayer.Type),
|
||||
WhiteLevel: whitePlayer.Level,
|
||||
WhiteSearchTime: whitePlayer.SearchTime,
|
||||
BlackPlayerID: blackPlayer.ID,
|
||||
BlackType: int(blackPlayer.Type),
|
||||
BlackLevel: blackPlayer.Level,
|
||||
BlackSearchTime: blackPlayer.SearchTime,
|
||||
StartTimeUTC: time.Now().UTC(),
|
||||
}
|
||||
s.store.RecordNewGame(record)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePlayers replaces players in an existing game
|
||||
func (s *Service) UpdatePlayers(gameID string, whiteConfig, blackConfig core.PlayerConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
// Create new player instances with new UUIDs
|
||||
whitePlayer := core.NewPlayer(whiteConfig, core.ColorWhite)
|
||||
blackPlayer := core.NewPlayer(blackConfig, core.ColorBlack)
|
||||
|
||||
// Update the game's players
|
||||
g.UpdatePlayers(whitePlayer, blackPlayer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGame retrieves a game by ID
|
||||
func (s *Service) GetGame(gameID string) (*game.Game, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// GenerateGameID creates a new unique game ID
|
||||
func (s *Service) GenerateGameID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Ensure UUID uniqueness (handle potential conflicts)
|
||||
for {
|
||||
id := uuid.New().String()
|
||||
if _, exists := s.games[id]; !exists {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyMove adds a validated move to the game history
|
||||
func (s *Service) ApplyMove(gameID, moveUCI, newFEN string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
// Determine whose turn it was before this move
|
||||
currentTurn := g.NextTurnColor()
|
||||
nextTurn := core.OppositeColor(currentTurn)
|
||||
|
||||
// Add the new position to game history
|
||||
g.AddSnapshot(newFEN, moveUCI, nextTurn)
|
||||
|
||||
// Persist if storage enabled
|
||||
if s.store != nil {
|
||||
moveNumber := len(g.Moves())
|
||||
record := storage.MoveRecord{
|
||||
GameID: gameID,
|
||||
MoveNumber: moveNumber,
|
||||
MoveUCI: moveUCI,
|
||||
FENAfterMove: newFEN,
|
||||
PlayerColor: currentTurn.String(),
|
||||
MoveTimeUTC: time.Now().UTC(),
|
||||
}
|
||||
s.store.RecordMove(record)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateGameState sets the game's end state (checkmate, stalemate, etc)
|
||||
func (s *Service) UpdateGameState(gameID string, state core.State) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
g.SetState(state)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLastMoveResult stores metadata about the last move
|
||||
func (s *Service) SetLastMoveResult(gameID string, result *game.MoveResult) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
g.SetLastResult(result)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UndoMoves removes the specified number of moves from game history
|
||||
func (s *Service) UndoMoves(gameID string, count int) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
g, ok := s.games[gameID]
|
||||
if !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
originalMoveCount := len(g.Moves())
|
||||
|
||||
if err := g.UndoMoves(count); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete undone moves from storage if enabled
|
||||
if s.store != nil {
|
||||
remainingMoves := originalMoveCount - count
|
||||
s.store.DeleteUndoneMoves(gameID, remainingMoves)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGame removes a game from memory
|
||||
func (s *Service) DeleteGame(gameID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, ok := s.games[gameID]; !ok {
|
||||
return fmt.Errorf("game not found: %s", gameID)
|
||||
}
|
||||
|
||||
delete(s.games, gameID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStorageHealth returns the storage component status
|
||||
func (s *Service) GetStorageHealth() string {
|
||||
if s.store == nil {
|
||||
|
||||
175
internal/service/user.go
Normal file
175
internal/service/user.go
Normal file
@ -0,0 +1,175 @@
|
||||
// FILE: internal/service/user.go
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"chess/internal/storage"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lixenwraith/auth"
|
||||
)
|
||||
|
||||
// User represents a registered user account
|
||||
type User struct {
|
||||
UserID string
|
||||
Username string
|
||||
Email string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// CreateUser creates new user with transactional consistency
|
||||
func (s *Service) CreateUser(username, email, password string) (*User, error) {
|
||||
if s.store == nil {
|
||||
return nil, fmt.Errorf("storage disabled")
|
||||
}
|
||||
|
||||
// Hash password
|
||||
passwordHash, err := auth.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Generate guaranteed unique user ID with proper collision handling
|
||||
userID, err := s.generateUniqueUserID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate unique ID: %w", err)
|
||||
}
|
||||
|
||||
// Create user record
|
||||
user := &User{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Email: email,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
// Use transactional storage method
|
||||
record := storage.UserRecord{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Email: email,
|
||||
PasswordHash: passwordHash,
|
||||
CreatedAt: user.CreatedAt,
|
||||
}
|
||||
|
||||
if err = s.store.CreateUser(record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// AuthenticateUser verifies user credentials and returns user information
|
||||
// AuthenticateUser verifies user credentials and returns user information
|
||||
func (s *Service) AuthenticateUser(identifier, password string) (*User, error) {
|
||||
if s.store == nil {
|
||||
return nil, fmt.Errorf("storage disabled")
|
||||
}
|
||||
|
||||
var userRecord *storage.UserRecord
|
||||
var err error
|
||||
|
||||
// Check if identifier looks like email
|
||||
if strings.Contains(identifier, "@") {
|
||||
userRecord, err = s.store.GetUserByEmail(identifier)
|
||||
} else {
|
||||
userRecord, err = s.store.GetUserByUsername(identifier)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Always hash to prevent timing attacks
|
||||
auth.HashPassword(password)
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
// Verify password
|
||||
if err := auth.VerifyPassword(password, userRecord.PasswordHash); err != nil {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
return &User{
|
||||
UserID: userRecord.UserID,
|
||||
Username: userRecord.Username,
|
||||
Email: userRecord.Email,
|
||||
CreatedAt: userRecord.CreatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateLastLogin updates the last login timestamp for a user
|
||||
func (s *Service) UpdateLastLogin(userID string) error {
|
||||
if s.store == nil {
|
||||
return fmt.Errorf("storage disabled")
|
||||
}
|
||||
|
||||
err := s.store.UpdateUserLastLoginSync(userID, time.Now().UTC())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update last login time for user %s: %w\n", userID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves user information by user ID
|
||||
func (s *Service) GetUserByID(userID string) (*User, error) {
|
||||
if s.store == nil {
|
||||
return nil, fmt.Errorf("storage disabled")
|
||||
}
|
||||
|
||||
userRecord, err := s.store.GetUserByID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
return &User{
|
||||
UserID: userRecord.UserID,
|
||||
Username: userRecord.Username,
|
||||
Email: userRecord.Email,
|
||||
CreatedAt: userRecord.CreatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateUserToken creates a JWT token for the specified user
|
||||
func (s *Service) GenerateUserToken(userID string) (string, error) {
|
||||
user, err := s.GetUserByID(userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims := map[string]any{
|
||||
"username": user.Username,
|
||||
"email": user.Email,
|
||||
}
|
||||
|
||||
return auth.GenerateHS256Token(s.jwtSecret, userID, claims, 7*24*time.Hour)
|
||||
}
|
||||
|
||||
// ValidateToken verifies JWT token and returns user ID with claims
|
||||
func (s *Service) ValidateToken(token string) (string, map[string]any, error) {
|
||||
return auth.ValidateHS256Token(s.jwtSecret, token)
|
||||
}
|
||||
|
||||
// generateUniqueUserID creates a unique user ID with collision detection
|
||||
func (s *Service) generateUniqueUserID() (string, error) {
|
||||
const maxAttempts = 10
|
||||
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
id := uuid.New().String()
|
||||
|
||||
// Check for collision
|
||||
if _, err := s.store.GetUserByID(id); err != nil {
|
||||
// Error means not found, ID is unique
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// Collision detected, try again
|
||||
if i == maxAttempts-1 {
|
||||
// After max attempts, fail and don't risk collision
|
||||
return "", fmt.Errorf("failed to generate unique ID after %d attempts", maxAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("failed to generate unique user ID")
|
||||
}
|
||||
138
internal/storage/game.go
Normal file
138
internal/storage/game.go
Normal file
@ -0,0 +1,138 @@
|
||||
// FILE: internal/storage/game.go
|
||||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
// RecordNewGame asynchronously records a new game
|
||||
func (s *Store) RecordNewGame(record GameRecord) error {
|
||||
if !s.healthStatus.Load() {
|
||||
return nil // Silently drop if degraded
|
||||
}
|
||||
|
||||
select {
|
||||
case s.writeChan <- func(tx *sql.Tx) error {
|
||||
query := `INSERT INTO games (
|
||||
game_id, initial_fen,
|
||||
white_player_id, white_type, white_level, white_search_time,
|
||||
black_player_id, black_type, black_level, black_search_time,
|
||||
start_time_utc
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
|
||||
|
||||
_, err := tx.Exec(query,
|
||||
record.GameID, record.InitialFEN,
|
||||
record.WhitePlayerID, record.WhiteType, record.WhiteLevel, record.WhiteSearchTime,
|
||||
record.BlackPlayerID, record.BlackType, record.BlackLevel, record.BlackSearchTime,
|
||||
record.StartTimeUTC,
|
||||
)
|
||||
return err
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
// Channel full, drop write
|
||||
log.Printf("Storage write queue full, dropping game record")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RecordMove asynchronously records a move
|
||||
func (s *Store) RecordMove(record MoveRecord) error {
|
||||
if !s.healthStatus.Load() {
|
||||
return nil // Silently drop if degraded
|
||||
}
|
||||
|
||||
select {
|
||||
case s.writeChan <- func(tx *sql.Tx) error {
|
||||
query := `INSERT INTO moves (
|
||||
game_id, move_number, move_uci, fen_after_move, player_color, move_time_utc
|
||||
) VALUES (?, ?, ?, ?, ?, ?)`
|
||||
|
||||
_, err := tx.Exec(query,
|
||||
record.GameID, record.MoveNumber, record.MoveUCI,
|
||||
record.FENAfterMove, record.PlayerColor, record.MoveTimeUTC,
|
||||
)
|
||||
return err
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
// Channel full, drop write
|
||||
log.Printf("Storage write queue full, dropping move record")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteUndoneMoves asynchronously deletes moves after undo
|
||||
func (s *Store) DeleteUndoneMoves(gameID string, afterMoveNumber int) error {
|
||||
if !s.healthStatus.Load() {
|
||||
return nil // Silently drop if degraded
|
||||
}
|
||||
|
||||
select {
|
||||
case s.writeChan <- func(tx *sql.Tx) error {
|
||||
query := `DELETE FROM moves WHERE game_id = ? AND move_number > ?`
|
||||
_, err := tx.Exec(query, gameID, afterMoveNumber)
|
||||
return err
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
// Channel full, drop write
|
||||
log.Printf("Storage write queue full, dropping undo operation")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// QueryGames retrieves games with optional filtering
|
||||
func (s *Store) QueryGames(gameID, playerID string) ([]GameRecord, error) {
|
||||
query := `SELECT
|
||||
game_id, initial_fen,
|
||||
white_player_id, white_type, white_level, white_search_time,
|
||||
black_player_id, black_type, black_level, black_search_time,
|
||||
start_time_utc
|
||||
FROM games WHERE 1=1`
|
||||
|
||||
var args []interface{}
|
||||
|
||||
// Handle gameID filtering
|
||||
if gameID != "" && gameID != "*" {
|
||||
query += " AND game_id = ?"
|
||||
args = append(args, gameID)
|
||||
}
|
||||
|
||||
// Handle playerID filtering
|
||||
if playerID != "" && playerID != "*" {
|
||||
query += " AND (white_player_id = ? OR black_player_id = ?)"
|
||||
args = append(args, playerID, playerID)
|
||||
}
|
||||
|
||||
query += " ORDER BY start_time_utc DESC"
|
||||
|
||||
rows, err := s.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var games []GameRecord
|
||||
for rows.Next() {
|
||||
var g GameRecord
|
||||
err := rows.Scan(
|
||||
&g.GameID, &g.InitialFEN,
|
||||
&g.WhitePlayerID, &g.WhiteType, &g.WhiteLevel, &g.WhiteSearchTime,
|
||||
&g.BlackPlayerID, &g.BlackType, &g.BlackLevel, &g.BlackSearchTime,
|
||||
&g.StartTimeUTC,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan failed: %w", err)
|
||||
}
|
||||
games = append(games, g)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("rows iteration failed: %w", err)
|
||||
}
|
||||
|
||||
return games, nil
|
||||
}
|
||||
@ -3,6 +3,16 @@ package storage
|
||||
|
||||
import "time"
|
||||
|
||||
// UserRecord represents a user account in the database
|
||||
type UserRecord struct {
|
||||
UserID string `db:"user_id"`
|
||||
Username string `db:"username"`
|
||||
Email string `db:"email"`
|
||||
PasswordHash string `db:"password_hash"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
LastLoginAt *time.Time `db:"last_login_at"`
|
||||
}
|
||||
|
||||
// GameRecord represents a row in the games table
|
||||
type GameRecord struct {
|
||||
GameID string `db:"game_id"`
|
||||
@ -31,6 +41,19 @@ type MoveRecord struct {
|
||||
|
||||
// Schema defines the SQLite database structure
|
||||
const Schema = `
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
user_id TEXT PRIMARY KEY,
|
||||
username TEXT UNIQUE NOT NULL COLLATE NOCASE,
|
||||
email TEXT COLLATE NOCASE,
|
||||
password_hash TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login_at DATETIME
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
|
||||
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email_unique ON users(email) WHERE email IS NOT NULL AND email != '';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS games (
|
||||
game_id TEXT PRIMARY KEY,
|
||||
initial_fen TEXT NOT NULL,
|
||||
|
||||
@ -14,7 +14,7 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// Store handles SQLite database operations with async writes
|
||||
// Store handles SQLite database operations with async writes for games and sync writes for auth
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
path string
|
||||
@ -70,6 +70,11 @@ func NewStore(dataSourceName string, devMode bool) (*Store, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// IsHealthy returns the current health status
|
||||
func (s *Store) IsHealthy() bool {
|
||||
return s.healthStatus.Load()
|
||||
}
|
||||
|
||||
// writerLoop processes async write operations
|
||||
func (s *Store) writerLoop() {
|
||||
defer s.wg.Done()
|
||||
@ -125,88 +130,6 @@ func (s *Store) executeWrite(fn func(*sql.Tx) error) {
|
||||
}
|
||||
}
|
||||
|
||||
// RecordNewGame asynchronously records a new game
|
||||
func (s *Store) RecordNewGame(record GameRecord) error {
|
||||
if !s.healthStatus.Load() {
|
||||
return nil // Silently drop if degraded
|
||||
}
|
||||
|
||||
select {
|
||||
case s.writeChan <- func(tx *sql.Tx) error {
|
||||
query := `INSERT INTO games (
|
||||
game_id, initial_fen,
|
||||
white_player_id, white_type, white_level, white_search_time,
|
||||
black_player_id, black_type, black_level, black_search_time,
|
||||
start_time_utc
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
|
||||
|
||||
_, err := tx.Exec(query,
|
||||
record.GameID, record.InitialFEN,
|
||||
record.WhitePlayerID, record.WhiteType, record.WhiteLevel, record.WhiteSearchTime,
|
||||
record.BlackPlayerID, record.BlackType, record.BlackLevel, record.BlackSearchTime,
|
||||
record.StartTimeUTC,
|
||||
)
|
||||
return err
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
// Channel full, drop write
|
||||
log.Printf("Storage write queue full, dropping game record")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RecordMove asynchronously records a move
|
||||
func (s *Store) RecordMove(record MoveRecord) error {
|
||||
if !s.healthStatus.Load() {
|
||||
return nil // Silently drop if degraded
|
||||
}
|
||||
|
||||
select {
|
||||
case s.writeChan <- func(tx *sql.Tx) error {
|
||||
query := `INSERT INTO moves (
|
||||
game_id, move_number, move_uci, fen_after_move, player_color, move_time_utc
|
||||
) VALUES (?, ?, ?, ?, ?, ?)`
|
||||
|
||||
_, err := tx.Exec(query,
|
||||
record.GameID, record.MoveNumber, record.MoveUCI,
|
||||
record.FENAfterMove, record.PlayerColor, record.MoveTimeUTC,
|
||||
)
|
||||
return err
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
// Channel full, drop write
|
||||
log.Printf("Storage write queue full, dropping move record")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteUndoneMoves asynchronously deletes moves after undo
|
||||
func (s *Store) DeleteUndoneMoves(gameID string, afterMoveNumber int) error {
|
||||
if !s.healthStatus.Load() {
|
||||
return nil // Silently drop if degraded
|
||||
}
|
||||
|
||||
select {
|
||||
case s.writeChan <- func(tx *sql.Tx) error {
|
||||
query := `DELETE FROM moves WHERE game_id = ? AND move_number > ?`
|
||||
_, err := tx.Exec(query, gameID, afterMoveNumber)
|
||||
return err
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
// Channel full, drop write
|
||||
log.Printf("Storage write queue full, dropping undo operation")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsHealthy returns the current health status
|
||||
func (s *Store) IsHealthy() bool {
|
||||
return s.healthStatus.Load()
|
||||
}
|
||||
|
||||
// Close gracefully closes the database connection
|
||||
func (s *Store) Close() error {
|
||||
// Signal writer to stop
|
||||
@ -260,57 +183,4 @@ func (s *Store) DeleteDB() error {
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryGames retrieves games with optional filtering
|
||||
func (s *Store) QueryGames(gameID, playerID string) ([]GameRecord, error) {
|
||||
query := `SELECT
|
||||
game_id, initial_fen,
|
||||
white_player_id, white_type, white_level, white_search_time,
|
||||
black_player_id, black_type, black_level, black_search_time,
|
||||
start_time_utc
|
||||
FROM games WHERE 1=1`
|
||||
|
||||
var args []interface{}
|
||||
|
||||
// Handle gameID filtering
|
||||
if gameID != "" && gameID != "*" {
|
||||
query += " AND game_id = ?"
|
||||
args = append(args, gameID)
|
||||
}
|
||||
|
||||
// Handle playerID filtering
|
||||
if playerID != "" && playerID != "*" {
|
||||
query += " AND (white_player_id = ? OR black_player_id = ?)"
|
||||
args = append(args, playerID, playerID)
|
||||
}
|
||||
|
||||
query += " ORDER BY start_time_utc DESC"
|
||||
|
||||
rows, err := s.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var games []GameRecord
|
||||
for rows.Next() {
|
||||
var g GameRecord
|
||||
err := rows.Scan(
|
||||
&g.GameID, &g.InitialFEN,
|
||||
&g.WhitePlayerID, &g.WhiteType, &g.WhiteLevel, &g.WhiteSearchTime,
|
||||
&g.BlackPlayerID, &g.BlackType, &g.BlackLevel, &g.BlackSearchTime,
|
||||
&g.StartTimeUTC,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan failed: %w", err)
|
||||
}
|
||||
games = append(games, g)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("rows iteration failed: %w", err)
|
||||
}
|
||||
|
||||
return games, nil
|
||||
}
|
||||
187
internal/storage/user.go
Normal file
187
internal/storage/user.go
Normal file
@ -0,0 +1,187 @@
|
||||
// FILE: internal/storage/user.go
|
||||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CreateUser creates user with transaction isolation to prevent race conditions
|
||||
func (s *Store) CreateUser(record UserRecord) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Check uniqueness within transaction
|
||||
exists, err := s.userExists(tx, record.Username, record.Email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return fmt.Errorf("username or email already exists")
|
||||
}
|
||||
|
||||
// Insert user
|
||||
query := `INSERT INTO users (
|
||||
user_id, username, email, password_hash, created_at
|
||||
) VALUES (?, ?, ?, ?, ?)`
|
||||
|
||||
_, err = tx.Exec(query,
|
||||
record.UserID, record.Username, record.Email,
|
||||
record.PasswordHash, record.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// userExists verifies username/email uniqueness within a transaction
|
||||
func (s *Store) userExists(tx *sql.Tx, username, email string) (bool, error) {
|
||||
var count int
|
||||
query := `SELECT COUNT(*) FROM users WHERE username = ? COLLATE NOCASE`
|
||||
args := []interface{}{username}
|
||||
|
||||
if email != "" {
|
||||
query = `SELECT COUNT(*) FROM users WHERE username = ? COLLATE NOCASE OR email = ? COLLATE NOCASE`
|
||||
args = append(args, email)
|
||||
}
|
||||
|
||||
err := tx.QueryRow(query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// UpdateUserPassword updates user password hash
|
||||
func (s *Store) UpdateUserPassword(userID string, passwordHash string) error {
|
||||
query := `UPDATE users SET password_hash = ? WHERE user_id = ?`
|
||||
_, err := s.db.Exec(query, passwordHash, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateUserEmail updates user email
|
||||
func (s *Store) UpdateUserEmail(userID string, email string) error {
|
||||
query := `UPDATE users SET email = ? WHERE user_id = ?`
|
||||
_, err := s.db.Exec(query, email, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateUserUsername updates username
|
||||
func (s *Store) UpdateUserUsername(userID string, username string) error {
|
||||
query := `UPDATE users SET username = ? WHERE user_id = ?`
|
||||
_, err := s.db.Exec(query, username, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAllUsers retrieves all users
|
||||
func (s *Store) GetAllUsers() ([]UserRecord, error) {
|
||||
query := `SELECT user_id, username, email, password_hash, created_at, last_login_at
|
||||
FROM users ORDER BY created_at DESC`
|
||||
|
||||
rows, err := s.db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []UserRecord
|
||||
for rows.Next() {
|
||||
var user UserRecord
|
||||
err := rows.Scan(
|
||||
&user.UserID, &user.Username, &user.Email,
|
||||
&user.PasswordHash, &user.CreatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
return users, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateUserLastLoginSync updates user last login time
|
||||
func (s *Store) UpdateUserLastLoginSync(userID string, loginTime time.Time) error {
|
||||
query := `UPDATE users SET last_login_at = ? WHERE user_id = ?`
|
||||
|
||||
_, err := s.db.Exec(query, loginTime, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update last login for user %s: %w", userID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserByUsername retrieves user by username with case-insensitive matching
|
||||
func (s *Store) GetUserByUsername(username string) (*UserRecord, error) {
|
||||
var user UserRecord
|
||||
query := `SELECT user_id, username, email, password_hash, created_at, last_login_at
|
||||
FROM users WHERE username = ? COLLATE NOCASE`
|
||||
|
||||
err := s.db.QueryRow(query, username).Scan(
|
||||
&user.UserID, &user.Username, &user.Email,
|
||||
&user.PasswordHash, &user.CreatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail retrieves user by email with case-insensitive matching
|
||||
func (s *Store) GetUserByEmail(email string) (*UserRecord, error) {
|
||||
var user UserRecord
|
||||
query := `SELECT user_id, username, email, password_hash, created_at, last_login_at
|
||||
FROM users WHERE email = ? COLLATE NOCASE`
|
||||
|
||||
err := s.db.QueryRow(query, email).Scan(
|
||||
&user.UserID, &user.Username, &user.Email,
|
||||
&user.PasswordHash, &user.CreatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves user by unique user ID
|
||||
func (s *Store) GetUserByID(userID string) (*UserRecord, error) {
|
||||
var user UserRecord
|
||||
query := `SELECT user_id, username, email, password_hash, created_at, last_login_at
|
||||
FROM users WHERE user_id = ?`
|
||||
|
||||
err := s.db.QueryRow(query, userID).Scan(
|
||||
&user.UserID, &user.Username, &user.Email,
|
||||
&user.PasswordHash, &user.CreatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// DeleteUser removes a user from the database
|
||||
func (s *Store) DeleteUser(userID string) error {
|
||||
if !s.healthStatus.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case s.writeChan <- func(tx *sql.Tx) error {
|
||||
query := `DELETE FROM users WHERE user_id = ?`
|
||||
_, err := tx.Exec(query, userID)
|
||||
return err
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
log.Printf("Storage write queue full, dropping user deletion")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user