diff --git a/README.md b/README.md index c82d0e1..56e5174 100644 --- a/README.md +++ b/README.md @@ -1,38 +1,45 @@ # Auth Package -Pluggable authentication utilities for Go applications. +Modular authentication utilities for Go applications. ## Features -- **Password Hashing**: Argon2id with PHC format -- **JWT**: HS256/RS256 token generation and validation -- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF -- **HTTP Auth**: Basic/Bearer header parsing +- **Password Hashing**: Standalone Argon2id hashing with PHC format. +- **JWT**: HS256/RS256 token management via a simple facade over `golang-jwt`. +- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF. +- **HTTP Auth**: Helpers for parsing Basic and Bearer authentication headers. ## Usage + ```go +// Argon2 Password Hashing +hash, _ := auth.HashPassword("password123") +err := auth.VerifyPassword("password123", hash) + // JWT with HS256 -auth, _ := auth.NewAuthenticator([]byte("32-byte-secret-key...")) -token, _ := auth.GenerateToken("user123", map[string]interface{}{"role": "admin"}) -userID, claims, _ := auth.ValidateToken(token) +jwtMgr, _ := auth.NewJWT([]byte("a-very-secure-32-byte-secret-key")) +token, _ := jwtMgr.GenerateToken("user123", map[string]any{"role": "admin"}) +userID, claims, _ := jwtMgr.ValidateToken(token) // SCRAM authentication server := auth.NewScramServer() -cred, _ := auth.DeriveCredential("user", "password", salt, 1, 65536, 4) +phcHash, _ := auth.HashPassword("password123") +cred, _ := auth.MigrateFromPHC("user", "password123", phcHash) server.AddCredential(cred) ``` ## Package Structure -- `interfaces.go` - Core interfaces -- `jwt.go` - JWT token operations -- `argon2.go` - Password hashing -- `scram.go` - SCRAM-SHA256 protocol -- `token.go` - Token validation utilities -- `http.go` - HTTP header parsing -- `errors.go` - Error definitions +- `doc.go` - Overview and package documentation +- `argon2.go` - Standalone Argon2id password hashing +- `jwt.go` - JWT manager (HS256/RS256) wrapping `golang-jwt` +- `scram.go` - SCRAM-SHA256 client/server protocol +- `http.go` - HTTP Basic/Bearer header parsing +- `token.go` - Simple in-memory token validator +- `error.go` - Centralized error definitions ## Testing + ```bash -go test -v ./auth +go test -v ./ ``` \ No newline at end of file diff --git a/argon2.go b/argon2.go index 533d5ce..5e0e9de 100644 --- a/argon2.go +++ b/argon2.go @@ -13,40 +13,85 @@ import ( // Default Argon2id parameters const ( - DefaultArgonTime = 3 // iterations (reduce for faster but less secure auth) + DefaultArgonTime = 3 // iterations DefaultArgonMemory = 64 * 1024 // 64 MB DefaultArgonThreads = 4 DefaultArgonSaltLen = 16 DefaultArgonKeyLen = 32 ) -// HashPassword creates an Argon2id PHC-format hash -func (a *Authenticator) HashPassword(password string) (string, error) { +// argonParams holds configurable Argon2id parameters +type argonParams struct { + time uint32 + memory uint32 + threads uint8 + keyLen uint32 + saltLen uint32 +} + +// Option configures Argon2id hashing parameters +type Option func(*argonParams) + +// WithTime sets Argon2 iterations +func WithTime(t uint32) Option { + return func(p *argonParams) { + if t > 0 { + p.time = t + } + } +} + +// WithMemory sets Argon2 memory in KiB +func WithMemory(m uint32) Option { + return func(p *argonParams) { + if m > 0 { + p.memory = m + } + } +} + +// WithThreads sets Argon2 parallelism +func WithThreads(t uint8) Option { + return func(p *argonParams) { + if t > 0 { + p.threads = t + } + } +} + +// HashPassword creates Argon2id PHC-format hash (standalone) +func HashPassword(password string, opts ...Option) (string, error) { if len(password) < 8 { return "", ErrWeakPassword } - // Generate salt - salt := make([]byte, DefaultArgonSaltLen) + params := &argonParams{ + time: DefaultArgonTime, + memory: DefaultArgonMemory, + threads: DefaultArgonThreads, + keyLen: DefaultArgonKeyLen, + saltLen: DefaultArgonSaltLen, + } + + for _, opt := range opts { + opt(params) + } + + salt := make([]byte, params.saltLen) if _, err := rand.Read(salt); err != nil { return "", fmt.Errorf("%w: %v", ErrSaltGenerationFailed, err) } - // Derive key using Argon2id - hash := argon2.IDKey([]byte(password), salt, a.argonTime, a.argonMemory, a.argonThreads, DefaultArgonKeyLen) + hash := argon2.IDKey([]byte(password), salt, params.time, params.memory, params.threads, params.keyLen) - // Construct PHC format saltB64 := base64.RawStdEncoding.EncodeToString(salt) hashB64 := base64.RawStdEncoding.EncodeToString(hash) - phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", - argon2.Version, a.argonMemory, a.argonTime, a.argonThreads, saltB64, hashB64) - - return phcHash, nil + return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", + argon2.Version, params.memory, params.time, params.threads, saltB64, hashB64), nil } -// VerifyPassword checks password against PHC-format hash -func (a *Authenticator) VerifyPassword(password, phcHash string) error { - // Parse PHC format +// VerifyPassword checks password against PHC-format hash (standalone) +func VerifyPassword(password, phcHash string) error { parts := strings.Split(phcHash, "$") if len(parts) != 6 || parts[1] != "argon2id" { return ErrPHCInvalidFormat @@ -66,10 +111,8 @@ func (a *Authenticator) VerifyPassword(password, phcHash string) error { return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err) } - // Compute hash with same parameters computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash))) - // Constant-time comparison if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 { return ErrInvalidCredentials } @@ -77,9 +120,8 @@ func (a *Authenticator) VerifyPassword(password, phcHash string) error { return nil } -// MigrateFromPHC converts existing Argon2 PHC hash to SCRAM credential +// MigrateFromPHC converts PHC hash to SCRAM credential func MigrateFromPHC(username, password, phcHash string) (*Credential, error) { - // Parse PHC format parts := strings.Split(phcHash, "$") if len(parts) != 6 || parts[1] != "argon2id" { return nil, ErrPHCInvalidFormat @@ -94,17 +136,10 @@ func MigrateFromPHC(username, password, phcHash string) (*Credential, error) { return nil, ErrPHCInvalidSalt } - expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) - if err != nil { - return nil, ErrPHCInvalidHash + // Use standalone function for verification + if err := VerifyPassword(password, phcHash); err != nil { + return nil, err } - // Verify password against hash - computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash))) - if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 { - return nil, ErrInvalidCredentials - } - - // Derive SCRAM credential with same parameters return DeriveCredential(username, password, salt, time, memory, threads) } \ No newline at end of file diff --git a/argon2_test.go b/argon2_test.go index f58f066..af78978 100644 --- a/argon2_test.go +++ b/argon2_test.go @@ -11,13 +11,10 @@ import ( ) func TestPasswordHashing(t *testing.T) { - auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) - require.NoError(t, err, "Failed to create authenticator") - password := "testPassword123" - // Test hashing - hash, err := auth.HashPassword(password) + // Test hashing with default parameters + hash, err := HashPassword(password) require.NoError(t, err, "Failed to hash password") // Verify PHC format @@ -25,20 +22,30 @@ func TestPasswordHashing(t *testing.T) { "Hash should have argon2id prefix, got: %s", hash) // Test verification with correct password - err = auth.VerifyPassword(password, hash) + err = VerifyPassword(password, hash) assert.NoError(t, err, "Failed to verify correct password") // Test verification with incorrect password - err = auth.VerifyPassword("wrongPassword", hash) + err = VerifyPassword("wrongPassword", hash) assert.Error(t, err, "Verification should fail for incorrect password") assert.Equal(t, ErrInvalidCredentials, err) // Test weak password - _, err = auth.HashPassword("weak") + _, err = HashPassword("weak") assert.Equal(t, ErrWeakPassword, err, "Should reject weak password") + // Test with custom options + hash, err = HashPassword(password, + WithTime(5), + WithMemory(128*1024), + WithThreads(8)) + require.NoError(t, err) + + err = VerifyPassword(password, hash) + assert.NoError(t, err) + // Test malformed PHC hash - err = auth.VerifyPassword(password, "$invalid$format") + err = VerifyPassword(password, "$invalid$format") assert.Error(t, err, "Should reject malformed hash") // Test corrupted salt @@ -47,33 +54,27 @@ func TestPasswordHashing(t *testing.T) { if len(parts) == 6 { parts[4] = "invalid!base64" corruptedHash = strings.Join(parts, "$") - err = auth.VerifyPassword(password, corruptedHash) + err = VerifyPassword(password, corruptedHash) assert.Error(t, err, "Should reject corrupted salt") } } func TestEmptyPasswordAfterValidation(t *testing.T) { - auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) - require.NoError(t, err) - // Empty password should be rejected by length check - _, err = auth.HashPassword("") + _, err := HashPassword("") assert.Equal(t, ErrWeakPassword, err) // 8-character password should pass - hash, err := auth.HashPassword("12345678") + hash, err := HashPassword("12345678") require.NoError(t, err) - err = auth.VerifyPassword("12345678", hash) + err = VerifyPassword("12345678", hash) assert.NoError(t, err) } func TestConcurrentPasswordOperations(t *testing.T) { - auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) - require.NoError(t, err) - password := "testPassword123" - hash, err := auth.HashPassword(password) + hash, err := HashPassword(password) require.NoError(t, err) // Test concurrent verification @@ -82,7 +83,7 @@ func TestConcurrentPasswordOperations(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - err := auth.VerifyPassword(password, hash) + err := VerifyPassword(password, hash) assert.NoError(t, err) }() } @@ -90,14 +91,11 @@ func TestConcurrentPasswordOperations(t *testing.T) { } func TestPHCMigration(t *testing.T) { - auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) - require.NoError(t, err) - password := "testPassword123" username := "migrationUser" // Generate PHC hash - phcHash, err := auth.HashPassword(password) + phcHash, err := HashPassword(password) require.NoError(t, err) // Migrate to SCRAM credential diff --git a/auth.go b/auth.go deleted file mode 100644 index 3a80ab3..0000000 --- a/auth.go +++ /dev/null @@ -1,71 +0,0 @@ -// FILE: auth/auth.go -package auth - -import ( - "crypto/rsa" - "fmt" -) - -// Authenticator provides password hashing and JWT operations -type Authenticator struct { - algorithm string - jwtSecret []byte // For HS256 - privateKey *rsa.PrivateKey // For RS256 - publicKey *rsa.PublicKey // For RS256 - argonTime uint32 - argonMemory uint32 - argonThreads uint8 -} - -// NewAuthenticator creates a new authenticator with specified algorithm -func NewAuthenticator(key any, algorithm ...string) (*Authenticator, error) { - alg := "HS256" - if len(algorithm) > 0 && algorithm[0] != "" { - alg = algorithm[0] - } - - auth := &Authenticator{ - algorithm: alg, - argonTime: DefaultArgonTime, - argonMemory: DefaultArgonMemory, - argonThreads: DefaultArgonThreads, - } - - switch alg { - case "HS256": - secret, ok := key.([]byte) - if !ok { - return nil, ErrInvalidKeyType - } - if len(secret) < 32 { - return nil, ErrSecretTooShort - } - auth.jwtSecret = secret - - case "RS256": - switch k := key.(type) { - case *rsa.PrivateKey: - auth.privateKey = k - auth.publicKey = &k.PublicKey - case *rsa.PublicKey: - auth.publicKey = k - case []byte: - // Try parsing as PEM - if privKey, err := parseRSAPrivateKey(k); err == nil { - auth.privateKey = privKey - auth.publicKey = &privKey.PublicKey - } else if pubKey, err := parseRSAPublicKey(k); err == nil { - auth.publicKey = pubKey - } else { - return nil, fmt.Errorf("failed to parse RSA key: %w", err) - } - default: - return nil, ErrInvalidKeyType - } - - default: - return nil, ErrInvalidAlgorithm - } - - return auth, nil -} \ No newline at end of file diff --git a/auth_test.go b/auth_test.go deleted file mode 100644 index 5176f37..0000000 --- a/auth_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// FILE: auth/auth_test.go -package auth - -import ( - "crypto/rand" - "crypto/rsa" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewAuthenticator(t *testing.T) { - // Test HS256 creation - auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) - require.NoError(t, err, "Failed to create HS256 authenticator") - assert.Equal(t, "HS256", auth.algorithm) - - // Test with short secret - _, err = NewAuthenticator([]byte("short")) - assert.Equal(t, ErrSecretTooShort, err) - - // Test RS256 with private key - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - - authRS, err := NewAuthenticator(privateKey, "RS256") - require.NoError(t, err) - assert.Equal(t, "RS256", authRS.algorithm) - assert.NotNil(t, authRS.privateKey) - assert.NotNil(t, authRS.publicKey) - - // Test RS256 with public key only - authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256") - require.NoError(t, err) - assert.Equal(t, "RS256", authPub.algorithm) - assert.Nil(t, authPub.privateKey) - assert.NotNil(t, authPub.publicKey) - - // Test invalid algorithm - _, err = NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"), "INVALID") - assert.Equal(t, ErrInvalidAlgorithm, err) - - // Test invalid key type for HS256 - _, err = NewAuthenticator(privateKey, "HS256") - assert.Equal(t, ErrInvalidKeyType, err) -} - -func TestInterfaceCompliance(t *testing.T) { - // Verify Authenticator implements AuthenticatorInterface - auth, _ := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) - - var _ AuthenticatorInterface = auth - - // Test interface methods work - hash, err := auth.HashPassword("testpass123") - require.NoError(t, err) - - err = auth.VerifyPassword("testpass123", hash) - assert.NoError(t, err) - - token, err := auth.GenerateToken("user1", nil) - require.NoError(t, err) - - userID, _, err := auth.ValidateToken(token) - require.NoError(t, err) - assert.Equal(t, "user1", userID) -} \ No newline at end of file diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..ad50a02 --- /dev/null +++ b/doc.go @@ -0,0 +1,53 @@ +// FILE: auth/doc.go +package auth + +/* +Package auth provides modular authentication components: + +# Argon2 Password Hashing + +Standalone password hashing using Argon2id: + + hash, err := auth.HashPassword("password123") + err = auth.VerifyPassword("password123", hash) + + // With custom parameters + hash, err := auth.HashPassword("password123", + auth.WithTime(5), + auth.WithMemory(128*1024)) + +# JWT Token Management + +JSON Web Token generation and validation: + + // HS256 (symmetric) + jwtMgr, _ := auth.NewJWT(secret) + token, _ := jwtMgr.GenerateToken("user1", claims) + userID, claims, _ := jwtMgr.ValidateToken(token) + + // RS256 (asymmetric) + jwtMgr, _ := auth.NewJWTRSA(privateKey) + + // One-off operations + token, _ := auth.GenerateHS256Token(secret, "user1", claims, 1*time.Hour) + +# SCRAM-SHA256 Authentication + +Server and client implementation for SCRAM: + + // Server + server := auth.NewScramServer() + server.AddCredential(credential) + + // Client + client := auth.NewScramClient(username, password) + +# HTTP Authentication Parsing + +Utility functions for HTTP headers: + + username, password, _ := auth.ParseBasicAuth(header) + token, _ := auth.ParseBearerToken(header) + +Each module can be used independently without initializing other components. +*/ \ No newline at end of file diff --git a/error.go b/error.go index 824a494..b4ca76c 100644 --- a/error.go +++ b/error.go @@ -10,8 +10,6 @@ import ( var ( ErrInvalidCredentials = errors.New("invalid credentials") ErrWeakPassword = errors.New("password must be at least 8 characters") - ErrInvalidAlgorithm = errors.New("invalid algorithm") - ErrInvalidKeyType = errors.New("invalid key type for algorithm") ) // JWT-specific errors @@ -22,9 +20,6 @@ var ( ErrTokenInvalidSignature = errors.New("token: invalid signature") ErrTokenAlgorithmMismatch = errors.New("token: algorithm mismatch") ErrTokenMissingClaim = errors.New("token: missing required claim") - ErrTokenInvalidHeader = errors.New("token: invalid header encoding") - ErrTokenInvalidClaims = errors.New("token: invalid claims encoding") - ErrTokenInvalidJSON = errors.New("token: malformed JSON") ErrTokenEmptyUserID = errors.New("token: empty user ID") ErrTokenNoPrivateKey = errors.New("token: private key required for signing") ErrTokenNoPublicKey = errors.New("token: public key required for verification") diff --git a/go.mod b/go.mod index b6154ca..a4a8ef3 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,9 @@ -module auth +module github.com/lixenwraith/auth go 1.25.3 require ( + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.43.0 ) diff --git a/go.sum b/go.sum index 4a5adfe..d143a61 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/interface.go b/interface.go deleted file mode 100644 index e84d390..0000000 --- a/interface.go +++ /dev/null @@ -1,17 +0,0 @@ -// FILE: auth/interface.go -package auth - -// AuthenticatorInterface defines the authentication operations -type AuthenticatorInterface interface { - HashPassword(password string) (hash string, err error) - VerifyPassword(password, hash string) (err error) - GenerateToken(userID string, claims map[string]any) (token string, err error) - ValidateToken(token string) (userID string, claims map[string]any, err error) -} - -// TokenValidator validates bearer tokens -type TokenValidator interface { - ValidateToken(token string) (valid bool) - AddToken(token string) - RemoveToken(token string) -} \ No newline at end of file diff --git a/jwt.go b/jwt.go index e5d4f53..ab10899 100644 --- a/jwt.go +++ b/jwt.go @@ -2,179 +2,289 @@ package auth import ( - "crypto" - "crypto/hmac" - "crypto/rand" "crypto/rsa" - "crypto/sha256" - "crypto/subtle" "crypto/x509" - "encoding/base64" - "encoding/json" "encoding/pem" + "errors" "fmt" - "strings" "time" + + "github.com/golang-jwt/jwt/v5" ) -// GenerateToken creates a JWT token with user claims -func (a *Authenticator) GenerateToken(userID string, claims map[string]any) (string, error) { +// JWT configuration defaults +const ( + DefaultTokenLifetime = 24 * time.Hour + DefaultLeeway = 5 * time.Minute +) + +// customClaims extends RegisteredClaims with arbitrary user data +type customClaims struct { + jwt.RegisteredClaims + Extra map[string]any `json:"extra,omitempty"` +} + +// JWT manages token generation and validation +type JWT struct { + algorithm jwt.SigningMethod + signKey any // []byte for HMAC, *rsa.PrivateKey for RSA + verifyKey any // []byte for HMAC, *rsa.PublicKey for RSA + tokenLifetime time.Duration + leeway time.Duration + issuer string + audience []string +} + +// JWTOption configures JWT behavior +type JWTOption func(*JWT) + +// WithTokenLifetime sets token expiration duration +func WithTokenLifetime(d time.Duration) JWTOption { + return func(j *JWT) { + if d > 0 { + j.tokenLifetime = d + } + } +} + +// WithLeeway sets clock skew tolerance +func WithLeeway(d time.Duration) JWTOption { + return func(j *JWT) { + if d >= 0 { + j.leeway = d + } + } +} + +// WithIssuer sets token issuer claim +func WithIssuer(iss string) JWTOption { + return func(j *JWT) { + j.issuer = iss + } +} + +// WithAudience sets token audience claim +func WithAudience(aud []string) JWTOption { + return func(j *JWT) { + j.audience = aud + } +} + +// NewJWT creates JWT manager for HS256 (symmetric) +func NewJWT(secret []byte, opts ...JWTOption) (*JWT, error) { + if len(secret) < 32 { + return nil, ErrSecretTooShort + } + + j := &JWT{ + algorithm: jwt.SigningMethodHS256, + signKey: secret, + verifyKey: secret, + tokenLifetime: DefaultTokenLifetime, + leeway: DefaultLeeway, + } + + for _, opt := range opts { + opt(j) + } + + return j, nil +} + +// NewJWTRSA creates JWT manager for RS256 (asymmetric) +func NewJWTRSA(privateKey *rsa.PrivateKey, opts ...JWTOption) (*JWT, error) { + if privateKey == nil { + return nil, ErrTokenNoPrivateKey + } + + j := &JWT{ + algorithm: jwt.SigningMethodRS256, + signKey: privateKey, + verifyKey: &privateKey.PublicKey, + tokenLifetime: DefaultTokenLifetime, + leeway: DefaultLeeway, + } + + for _, opt := range opts { + opt(j) + } + + return j, nil +} + +// NewJWTVerifier creates JWT manager for verification only (RS256) +func NewJWTVerifier(publicKey *rsa.PublicKey, opts ...JWTOption) (*JWT, error) { + if publicKey == nil { + return nil, ErrTokenNoPublicKey + } + + j := &JWT{ + algorithm: jwt.SigningMethodRS256, + signKey: nil, // Cannot sign + verifyKey: publicKey, + tokenLifetime: DefaultTokenLifetime, + leeway: DefaultLeeway, + } + + for _, opt := range opts { + opt(j) + } + + return j, nil +} + +// GenerateToken creates signed JWT with claims +func (j *JWT) GenerateToken(userID string, claims map[string]any) (string, error) { if userID == "" { return "", ErrTokenEmptyUserID } - if a.algorithm == "RS256" && a.privateKey == nil { + if j.signKey == nil { return "", ErrTokenNoPrivateKey } - // Build JWT claims now := time.Now() - jwtClaims := map[string]any{ - "sub": userID, - "iat": now.Unix(), - "exp": now.Add(7 * 24 * time.Hour).Unix(), // 7 days expiry + registeredClaims := jwt.RegisteredClaims{ + Subject: userID, + Issuer: j.issuer, + Audience: j.audience, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(j.tokenLifetime)), + NotBefore: jwt.NewNumericDate(now), } - // Reserved claims that cannot be overridden - reservedClaims := map[string]bool{ - "sub": true, "iat": true, "exp": true, "nbf": true, - "iss": true, "aud": true, "jti": true, "typ": true, - "alg": true, - } + token := jwt.NewWithClaims(j.algorithm, customClaims{ + RegisteredClaims: registeredClaims, + Extra: claims, + }) - // Merge custom claims - for k, v := range claims { - if !reservedClaims[k] { - jwtClaims[k] = v - } - } - - // Create JWT header - header := map[string]any{ - "alg": a.algorithm, - "typ": "JWT", - } - - // Encode header and payload - headerJSON, _ := json.Marshal(header) - claimsJSON, _ := json.Marshal(jwtClaims) - - headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) - claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) - - // Create signature - signingInput := headerB64 + "." + claimsB64 - var signature string - - switch a.algorithm { - case "HS256": - h := hmac.New(sha256.New, a.jwtSecret) - h.Write([]byte(signingInput)) - signature = base64.RawURLEncoding.EncodeToString(h.Sum(nil)) - - case "RS256": - hashed := sha256.Sum256([]byte(signingInput)) - sig, err := rsa.SignPKCS1v15(rand.Reader, a.privateKey, crypto.SHA256, hashed[:]) - if err != nil { - return "", fmt.Errorf("failed to sign token: %w", err) - } - signature = base64.RawURLEncoding.EncodeToString(sig) - } - - // Combine to form JWT - token := signingInput + "." + signature - - return token, nil + return token.SignedString(j.signKey) } -// ValidateToken verifies JWT and returns userID and claims -func (a *Authenticator) ValidateToken(token string) (string, map[string]any, error) { - // Split token - parts := strings.Split(token, ".") - if len(parts) != 3 { +// ValidateToken verifies JWT and extracts claims +func (j *JWT) ValidateToken(tokenString string) (string, map[string]any, error) { + parser := jwt.NewParser( + jwt.WithLeeway(j.leeway), + jwt.WithAudience(j.audience...), + jwt.WithIssuer(j.issuer), + jwt.WithValidMethods([]string{j.algorithm.Alg()}), + jwt.WithExpirationRequired(), + ) + + token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) { + // Algorithm already validated by WithValidMethods + return j.verifyKey, nil + }) + + if err != nil { + return "", nil, mapJWTError(err) + } + + claims, ok := token.Claims.(*customClaims) + if !ok || !token.Valid { return "", nil, ErrTokenMalformed } - // Decode header to check algorithm - headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) - if err != nil { - return "", nil, ErrTokenInvalidHeader - } - - var header map[string]any - if err = json.Unmarshal(headerJSON, &header); err != nil { - return "", nil, ErrTokenInvalidJSON - } - - // Verify algorithm matches - if alg, ok := header["alg"].(string); !ok || alg != a.algorithm { - return "", nil, ErrTokenAlgorithmMismatch - } - - // Verify signature - signingInput := parts[0] + "." + parts[1] - - switch a.algorithm { - case "HS256": - h := hmac.New(sha256.New, a.jwtSecret) - h.Write([]byte(signingInput)) - expectedSig := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) - - if subtle.ConstantTimeCompare([]byte(parts[2]), []byte(expectedSig)) != 1 { - return "", nil, ErrTokenInvalidSignature - } - - case "RS256": - if a.publicKey == nil { - return "", nil, ErrTokenNoPublicKey - } - - sig, err := base64.RawURLEncoding.DecodeString(parts[2]) - if err != nil { - return "", nil, ErrTokenInvalidSignature - } - - hashed := sha256.Sum256([]byte(signingInput)) - if err := rsa.VerifyPKCS1v15(a.publicKey, crypto.SHA256, hashed[:], sig); err != nil { - return "", nil, ErrTokenInvalidSignature - } - } - - // Decode claims - claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return "", nil, ErrTokenInvalidClaims - } - - var claims map[string]any - if err := json.Unmarshal(claimsJSON, &claims); err != nil { - return "", nil, ErrTokenInvalidJSON - } - - // Check expiration - if exp, ok := claims["exp"].(float64); ok { - if time.Now().Unix() > int64(exp) { - return "", nil, ErrTokenExpired - } - } - - // Check not before - if nbf, ok := claims["nbf"].(float64); ok { - if time.Now().Unix() < int64(nbf) { - return "", nil, ErrTokenNotYetValid - } - } - - // Extract userID - userID, ok := claims["sub"].(string) - if !ok { - return "", nil, ErrTokenMissingClaim - } - - return userID, claims, nil + return claims.Subject, claims.Extra, nil } -// parseRSAPrivateKey parses PEM encoded RSA private key +// mapJWTError translates jwt library errors to auth package errors +func mapJWTError(err error) error { + switch { + case errors.Is(err, jwt.ErrTokenMalformed): + return fmt.Errorf("%w : %w", ErrTokenMalformed, err) + case errors.Is(err, jwt.ErrTokenUnverifiable): + return fmt.Errorf("%w : %w", ErrTokenMalformed, err) + case errors.Is(err, jwt.ErrTokenSignatureInvalid): + return fmt.Errorf("%w : %w", ErrTokenInvalidSignature, err) + case errors.Is(err, jwt.ErrTokenExpired): + return fmt.Errorf("%w : %w", ErrTokenExpired, err) + case errors.Is(err, jwt.ErrTokenNotValidYet): + return fmt.Errorf("%w : %w", ErrTokenNotYetValid, err) + case errors.Is(err, jwt.ErrTokenInvalidAudience): + return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err) + case errors.Is(err, jwt.ErrTokenInvalidIssuer): + return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err) + default: + // Check for algorithm mismatch in error message + if errors.Is(err, jwt.ErrTokenSignatureInvalid) { + return fmt.Errorf("%w : %w", ErrTokenAlgorithmMismatch, err) + } + return fmt.Errorf("%w : %w", ErrTokenMalformed, err) + } +} + +// Standalone helper functions for one-off operations + +// GenerateHS256Token creates HS256 JWT without manager instance +func GenerateHS256Token(secret []byte, userID string, claims map[string]any, lifetime time.Duration) (string, error) { + if len(secret) < 32 { + return "", ErrSecretTooShort + } + + now := time.Now() + token := jwt.NewWithClaims(jwt.SigningMethodHS256, customClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(lifetime)), + }, + Extra: claims, + }) + + return token.SignedString(secret) +} + +// ValidateHS256Token verifies HS256 JWT without manager instance +func ValidateHS256Token(secret []byte, tokenString string) (string, map[string]any, error) { + if len(secret) < 32 { + return "", nil, ErrSecretTooShort + } + + parser := jwt.NewParser( + jwt.WithValidMethods([]string{"HS256"}), + jwt.WithLeeway(DefaultLeeway), + ) + + token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) { + return secret, nil + }) + + if err != nil { + return "", nil, mapJWTError(err) + } + + claims, ok := token.Claims.(*customClaims) + if !ok || !token.Valid { + return "", nil, ErrTokenMalformed + } + + return claims.Subject, claims.Extra, nil +} + +// RSA Utilities + +// NewJWTRSAFromPEM creates a JWT manager for RS256 from raw PEM-encoded private key data. +func NewJWTRSAFromPEM(privateKeyPEM []byte, opts ...JWTOption) (*JWT, error) { + privateKey, err := parseRSAPrivateKey(privateKeyPEM) + if err != nil { + return nil, err + } + // Call the original constructor with the now-parsed key + return NewJWTRSA(privateKey, opts...) +} + +// NewJWTVerifierFromPEM creates a JWT manager for verification from raw PEM-encoded public key data. +func NewJWTVerifierFromPEM(publicKeyPEM []byte, opts ...JWTOption) (*JWT, error) { + publicKey, err := parseRSAPublicKey(publicKeyPEM) + if err != nil { + return nil, err + } + // Call the original constructor with the now-parsed key + return NewJWTVerifier(publicKey, opts...) +} + +// parseRSAPrivateKey parses a PEM-encoded RSA private key. func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) { block, _ := pem.Decode(pemBytes) if block == nil { @@ -187,7 +297,7 @@ func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) { return key, nil } -// parseRSAPublicKey parses PEM encoded RSA public key +// parseRSAPublicKey parses a PEM-encoded RSA public key. func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) { block, _ := pem.Decode(pemBytes) if block == nil { diff --git a/jwt_test.go b/jwt_test.go index 56ec30d..c65f28a 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -4,19 +4,20 @@ package auth import ( "crypto/rand" "crypto/rsa" - "encoding/base64" - "errors" - "fmt" + "crypto/x509" + "encoding/pem" "strings" "testing" "time" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestJWTHS256(t *testing.T) { - auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) + secret := []byte("test-secret-key-must-be-32-bytes") + jwtMgr, err := NewJWT(secret) require.NoError(t, err) userID := "user123" @@ -26,54 +27,17 @@ func TestJWTHS256(t *testing.T) { } // Generate token - token, err := auth.GenerateToken(userID, claims) - require.NoError(t, err, "Failed to generate token") + token, err := jwtMgr.GenerateToken(userID, claims) + require.NoError(t, err) assert.NotEmpty(t, token) // Validate token - extractedUserID, extractedClaims, err := auth.ValidateToken(token) - require.NoError(t, err, "Failed to validate token") + extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token) + require.NoError(t, err) assert.Equal(t, userID, extractedUserID) assert.Equal(t, "test@example.com", extractedClaims["email"]) assert.Equal(t, "admin", extractedClaims["role"]) - - // Test invalid token - _, _, err = auth.ValidateToken("invalid.token.here") - assert.Error(t, err) - assert.True(t, errors.Is(err, ErrTokenInvalidJSON)) - - // Test tampered token - parts := strings.Split(token, ".") - require.Len(t, parts, 3, "JWT should have 3 parts") - - tampered := parts[0] + "." + parts[1] + ".invalidsignature" - _, _, err = auth.ValidateToken(tampered) - assert.Error(t, err) - assert.True(t, errors.Is(err, ErrTokenInvalidSignature)) - - // Test reserved claims cannot be overridden - overrideClaims := map[string]any{ - "sub": "override", - "iat": 12345, - "exp": 67890, - "nbf": 11111, - "iss": "attacker", - "aud": "victim", - "jti": "fake", - } - - token, err = auth.GenerateToken(userID, overrideClaims) - require.NoError(t, err) - - extractedUserID, extractedClaims, err = auth.ValidateToken(token) - require.NoError(t, err) - - assert.Equal(t, userID, extractedUserID, "UserID should not be overridden") - assert.NotEqual(t, 12345, extractedClaims["iat"], "iat should not be overridden") - assert.NotEqual(t, 67890, extractedClaims["exp"], "exp should not be overridden") - assert.NotContains(t, extractedClaims, "nbf", "nbf should not be added from user claims") - assert.NotContains(t, extractedClaims, "iss", "iss should not be added from user claims") } func TestJWTRS256(t *testing.T) { @@ -82,95 +46,214 @@ func TestJWTRS256(t *testing.T) { require.NoError(t, err) // Test with private key (can sign and verify) - authPriv, err := NewAuthenticator(privateKey, "RS256") + jwtMgr, err := NewJWTRSA(privateKey) require.NoError(t, err) userID := "user456" claims := map[string]any{ - "email": "rs256@example.com", "scope": "read:all", } // Generate token - token, err := authPriv.GenerateToken(userID, claims) + token, err := jwtMgr.GenerateToken(userID, claims) require.NoError(t, err) assert.NotEmpty(t, token) - // Validate token with private key auth (has public key too) - extractedUserID, extractedClaims, err := authPriv.ValidateToken(token) + // Validate with same manager + extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token) require.NoError(t, err) - assert.Equal(t, userID, extractedUserID) - assert.Equal(t, "rs256@example.com", extractedClaims["email"]) assert.Equal(t, "read:all", extractedClaims["scope"]) - // Test with public key only (can only verify) - authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256") + // Test with verifier only (public key) + verifier, err := NewJWTVerifier(&privateKey.PublicKey) require.NoError(t, err) - // Should be able to validate token - extractedUserID, extractedClaims, err = authPub.ValidateToken(token) + // Should validate token + extractedUserID, _, err = verifier.ValidateToken(token) require.NoError(t, err) assert.Equal(t, userID, extractedUserID) - // Should not be able to generate token - _, err = authPub.GenerateToken(userID, claims) - assert.Error(t, err, "Public key only auth should not generate tokens") + // Should not generate token + _, err = verifier.GenerateToken(userID, claims) assert.Equal(t, ErrTokenNoPrivateKey, err) - - // Test algorithm mismatch - authHS256, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) - require.NoError(t, err) - - _, _, err = authHS256.ValidateToken(token) - assert.Error(t, err, "HS256 auth should not validate RS256 token") - // assert.True(t, errors.Is(err, ErrInvalidToken)) - assert.True(t, errors.Is(err, ErrTokenAlgorithmMismatch)) - - fmt.Println(err) } -func TestExpiredToken(t *testing.T) { - auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) +func TestJWTOptions(t *testing.T) { + secret := []byte("test-secret-key-must-be-32-bytes") + + // Test custom lifetime + jwtMgr, err := NewJWT(secret, + WithTokenLifetime(1*time.Hour), + WithIssuer("test-issuer"), + WithAudience([]string{"api.example.com"}), + ) require.NoError(t, err) - userID := "user123" - - // Generate normal token (should have 7 days expiry) - token, err := auth.GenerateToken(userID, nil) + token, err := jwtMgr.GenerateToken("user1", nil) require.NoError(t, err) - _, extractedClaims, err := auth.ValidateToken(token) - require.NoError(t, err) + // Parse token to check claims + parsed, _ := jwt.Parse(token, func(token *jwt.Token) (any, error) { + return secret, nil + }) - // Check expiry is in future (approximately 7 days) - expiry := extractedClaims["exp"].(float64) - now := time.Now().Unix() + claims := parsed.Claims.(jwt.MapClaims) - assert.Greater(t, expiry, float64(now), "Token expiry should be in future") - assert.InDelta(t, expiry, float64(now+7*24*60*60), 10, - "Token expiry should be approximately 7 days from now") + // Check issuer + assert.Equal(t, "test-issuer", claims["iss"]) + + // Check audience + aud := claims["aud"].([]any) + assert.Contains(t, aud, "api.example.com") + + // Check expiration is ~1 hour + exp := int64(claims["exp"].(float64)) + iat := int64(claims["iat"].(float64)) + assert.InDelta(t, 3600, exp-iat, 10) } -func TestCorruptJWTParts(t *testing.T) { - auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) +func TestJWTErrors(t *testing.T) { + secret := []byte("test-secret-key-must-be-32-bytes") + jwtMgr, err := NewJWT(secret) require.NoError(t, err) - // Test with missing parts - _, _, err = auth.ValidateToken("only.two") - assert.True(t, errors.Is(err, ErrTokenMalformed)) + // Empty user ID + _, err = jwtMgr.GenerateToken("", nil) + assert.Equal(t, ErrTokenEmptyUserID, err) - // Test with invalid header encoding - _, _, err = auth.ValidateToken("not-base64!.valid.valid") - assert.Error(t, err) + // Invalid token format + _, _, err = jwtMgr.ValidateToken("invalid.token") + assert.ErrorIs(t, err, ErrTokenMalformed) - // Test with invalid claims encoding - validHeader := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" - _, _, err = auth.ValidateToken(validHeader + ".not-base64!.valid") - assert.Error(t, err) + // Tampered signature + token, _ := jwtMgr.GenerateToken("user1", nil) + parts := strings.Split(token, ".") + tampered := parts[0] + "." + parts[1] + ".invalidsignature" + _, _, err = jwtMgr.ValidateToken(tampered) + assert.ErrorIs(t, err, ErrTokenInvalidSignature) - // Test with invalid JSON in claims - invalidJSON := base64.RawURLEncoding.EncodeToString([]byte("{invalid json")) - _, _, err = auth.ValidateToken(validHeader + "." + invalidJSON + ".signature") - assert.Error(t, err) + // Wrong algorithm + rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) + rsaMgr, _ := NewJWTRSA(rsaKey) + rsaToken, _ := rsaMgr.GenerateToken("user1", nil) + + _, _, err = jwtMgr.ValidateToken(rsaToken) + assert.ErrorIs(t, err, ErrTokenInvalidSignature) +} + +func TestJWTExpiration(t *testing.T) { + secret := []byte("test-secret-key-must-be-32-bytes") + + // Create token with 1 second lifetime + jwtMgr, err := NewJWT(secret, WithTokenLifetime(1*time.Second), WithLeeway(0)) + require.NoError(t, err) + + token, err := jwtMgr.GenerateToken("user1", nil) + require.NoError(t, err) + + // Should be valid immediately + _, _, err = jwtMgr.ValidateToken(token) + assert.NoError(t, err) + + // Wait for expiration + time.Sleep(2 * time.Second) + + // Should be expired + _, _, err = jwtMgr.ValidateToken(token) + assert.ErrorIs(t, err, ErrTokenExpired) +} + +func TestLeeway(t *testing.T) { + secret := []byte("test-secret-key-must-be-32-bytes") + + // Create manager with no leeway + jwtMgr, err := NewJWT(secret, WithLeeway(0)) + require.NoError(t, err) + + // Manually create a token with NotBefore in future + now := time.Now() + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": "user1", + "nbf": now.Add(2 * time.Second).Unix(), + "exp": now.Add(1 * time.Hour).Unix(), + }) + tokenString, err := token.SignedString(secret) + require.NoError(t, err) + + // Should fail immediately (not valid yet) + _, _, err = jwtMgr.ValidateToken(tokenString) + assert.ErrorIs(t, err, ErrTokenNotYetValid) + + // Create manager with leeway + jwtMgrWithLeeway, err := NewJWT(secret, WithLeeway(5*time.Second)) + require.NoError(t, err) + + // Should pass with leeway + _, _, err = jwtMgrWithLeeway.ValidateToken(tokenString) + assert.NoError(t, err) +} + +func TestStandaloneFunctions(t *testing.T) { + secret := []byte("test-secret-key-must-be-32-bytes") + userID := "standalone-user" + claims := map[string]any{"test": "value"} + + // Generate token + token, err := GenerateHS256Token(secret, userID, claims, 1*time.Hour) + require.NoError(t, err) + + // Validate token + extractedUserID, extractedClaims, err := ValidateHS256Token(secret, token) + require.NoError(t, err) + + assert.Equal(t, userID, extractedUserID) + assert.Equal(t, "value", extractedClaims["test"]) + + // Test with short secret + _, err = GenerateHS256Token([]byte("short"), userID, claims, 1*time.Hour) + assert.Equal(t, ErrSecretTooShort, err) +} + +func TestJWTRSAFromPEM(t *testing.T) { + // 1. Generate a new RSA key pair for this test + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + // 2. Encode the private key to PEM format + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + // 3. Encode the public key to PEM format + publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + require.NoError(t, err) + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyBytes, + }) + + // 4. Test the PEM constructor for the signer + jwtMgr, err := NewJWTRSAFromPEM(privateKeyPEM) + require.NoError(t, err) + + token, err := jwtMgr.GenerateToken("user-from-pem", nil) + require.NoError(t, err) + assert.NotEmpty(t, token) + + // 5. Test the PEM constructor for the verifier + verifier, err := NewJWTVerifierFromPEM(publicKeyPEM) + require.NoError(t, err) + + userID, _, err := verifier.ValidateToken(token) + require.NoError(t, err) + assert.Equal(t, "user-from-pem", userID) + + // 6. Test failure cases with invalid data + _, err = NewJWTRSAFromPEM([]byte("invalid pem data")) + assert.ErrorIs(t, err, ErrRSAInvalidPEM) + + _, err = NewJWTVerifierFromPEM([]byte("invalid pem data")) + assert.ErrorIs(t, err, ErrRSAInvalidPEM) } \ No newline at end of file diff --git a/scram.go b/scram.go index 51e7266..d9840e5 100644 --- a/scram.go +++ b/scram.go @@ -143,6 +143,10 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3 return nil, ErrSCRAMSaltTooShort } + if time == 0 || memory == 0 || threads == 0 { + return nil, ErrSCRAMZeroParams + } + // Derive salted password using Argon2id saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, DefaultArgonKeyLen) diff --git a/scram_test.go b/scram_test.go new file mode 100644 index 0000000..40c8fc2 --- /dev/null +++ b/scram_test.go @@ -0,0 +1,146 @@ +// FILE: auth/scram_test.go +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupScramTest is a helper to initialize a server and user credential for testing. +// It performs the full Argon2 -> SCRAM migration workflow. +func setupScramTest(t *testing.T) (server *ScramServer, username, password string, cred *Credential) { + username = "testuser" + password = "SecurePassword123" + + // 1. Start with an Argon2 PHC hash, as a real application would. + phcHash, err := HashPassword(password) + require.NoError(t, err, "Setup failed: could not hash password") + + // 2. Migrate the PHC hash to a SCRAM credential. + cred, err = MigrateFromPHC(username, password, phcHash) + require.NoError(t, err, "Setup failed: could not migrate from PHC hash") + + // 3. Create a server and add the new credential. + server = NewScramServer() + server.AddCredential(cred) + + return server, username, password, cred +} + +// TestScram_FullRoundtrip_Success simulates a complete, successful authentication handshake. +func TestScram_FullRoundtrip_Success(t *testing.T) { + server, username, password, _ := setupScramTest(t) + client := NewScramClient(username, password) + + // --- Step 1: Client sends its first message --- + clientFirst, err := client.StartAuthentication() + require.NoError(t, err) + + // --- Step 2: Server receives client's message and responds --- + serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce) + require.NoError(t, err, "Server failed to process client's first message") + + // --- Step 3: Client receives server's message, computes proof --- + clientFinal, err := client.ProcessServerFirstMessage(serverFirst) + require.NoError(t, err, "Client failed to process server's first message") + + // --- Step 4: Server receives client's proof and verifies it --- + serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof) + require.NoError(t, err, "Server failed to verify client's final proof") + assert.NotEmpty(t, serverFinal.ServerSignature, "Server signature should not be empty") + + // --- Step 5: Client verifies server's signature (mutual authentication) --- + err = client.VerifyServerFinalMessage(serverFinal) + assert.NoError(t, err, "Client failed to verify server's final signature") + + t.Log("SCRAM full roundtrip successful") +} + +// TestScram_FullRoundtrip_WrongPassword ensures authentication fails with an incorrect password. +func TestScram_FullRoundtrip_WrongPassword(t *testing.T) { + server, username, _, _ := setupScramTest(t) + // Create a client with the WRONG password + client := NewScramClient(username, "WrongPassword!!!") + + // Steps 1-3 will appear to succeed, as the client doesn't know the password is wrong yet. + clientFirst, err := client.StartAuthentication() + require.NoError(t, err) + + serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce) + require.NoError(t, err) + + clientFinal, err := client.ProcessServerFirstMessage(serverFirst) + require.NoError(t, err) + + // --- Step 4: Server verification should fail here --- + _, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof) + assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should reject proof from wrong password") + + t.Log("SCRAM correctly failed for wrong password") +} + +// TestScram_FullRoundtrip_UserNotFound tests for user enumeration protection. +// The server should not reveal whether a user exists or not in its first message. +func TestScram_FullRoundtrip_UserNotFound(t *testing.T) { + server, _, _, _ := setupScramTest(t) + client := NewScramClient("unknown_user", "any_password") + + clientFirst, err := client.StartAuthentication() + require.NoError(t, err) + + // --- Step 2: Server should return an error, but also a FAKE response --- + // This prevents an attacker from knowing if the user exists based on the response structure. + serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce) + assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should return an error for an unknown user") + assert.NotEmpty(t, serverFirst.FullNonce, "Server must still provide a nonce to prevent enumeration") + assert.NotEmpty(t, serverFirst.Salt, "Server must still provide a salt to prevent enumeration") + + t.Log("SCRAM correctly protected against user enumeration") +} + +// TestScram_InvalidNonce simulates a replay attack or message mismatch. +func TestScram_InvalidNonce(t *testing.T) { + server, username, password, _ := setupScramTest(t) + client := NewScramClient(username, password) + + // Perform the first part of the handshake + clientFirst, _ := client.StartAuthentication() + serverFirst, _ := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce) + clientFinal, _ := client.ProcessServerFirstMessage(serverFirst) + + // Attempt to finalize with a completely different nonce + _, err := server.ProcessClientFinalMessage("this-is-a-bad-nonce", clientFinal.ClientProof) + assert.ErrorIs(t, err, ErrSCRAMInvalidNonce, "Server should reject a final message with an unknown nonce") +} + +// TestScram_CredentialImportExport verifies that credentials can be serialized and deserialized correctly. +func TestScram_CredentialImportExport(t *testing.T) { + _, _, _, originalCred := setupScramTest(t) + + // Export the credential to a map + exportedData := originalCred.Export() + require.NotNil(t, exportedData) + + // Assert that required fields exist and are strings (as they are base64 encoded) + assert.IsType(t, "", exportedData["salt"]) + assert.IsType(t, "", exportedData["stored_key"]) + assert.IsType(t, "", exportedData["server_key"]) + + // Import the credential back from the map + importedCred, err := ImportCredential(exportedData) + require.NoError(t, err) + require.NotNil(t, importedCred) + + // Verify that the imported credential is identical to the original + assert.Equal(t, originalCred.Username, importedCred.Username) + assert.Equal(t, originalCred.Salt, importedCred.Salt) + assert.Equal(t, originalCred.ArgonTime, importedCred.ArgonTime) + assert.Equal(t, originalCred.ArgonMemory, importedCred.ArgonMemory) + assert.Equal(t, originalCred.ArgonThreads, importedCred.ArgonThreads) + assert.Equal(t, originalCred.StoredKey, importedCred.StoredKey) + assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey) + + t.Log("SCRAM credential import/export successful") +} \ No newline at end of file