v0.2.0 restructured and generalized to be more modular, added golang-jwt dependency

This commit is contained in:
2025-11-03 15:11:40 -05:00
parent 3a662862d7
commit 3471030edd
14 changed files with 760 additions and 482 deletions

View File

@ -1,38 +1,45 @@
# Auth Package
Pluggable authentication utilities for Go applications.
Modular authentication utilities for Go applications.
## Features
- **Password Hashing**: Argon2id with PHC format
- **JWT**: HS256/RS256 token generation and validation
- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF
- **HTTP Auth**: Basic/Bearer header parsing
- **Password Hashing**: Standalone Argon2id hashing with PHC format.
- **JWT**: HS256/RS256 token management via a simple facade over `golang-jwt`.
- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF.
- **HTTP Auth**: Helpers for parsing Basic and Bearer authentication headers.
## Usage
```go
// Argon2 Password Hashing
hash, _ := auth.HashPassword("password123")
err := auth.VerifyPassword("password123", hash)
// JWT with HS256
auth, _ := auth.NewAuthenticator([]byte("32-byte-secret-key..."))
token, _ := auth.GenerateToken("user123", map[string]interface{}{"role": "admin"})
userID, claims, _ := auth.ValidateToken(token)
jwtMgr, _ := auth.NewJWT([]byte("a-very-secure-32-byte-secret-key"))
token, _ := jwtMgr.GenerateToken("user123", map[string]any{"role": "admin"})
userID, claims, _ := jwtMgr.ValidateToken(token)
// SCRAM authentication
server := auth.NewScramServer()
cred, _ := auth.DeriveCredential("user", "password", salt, 1, 65536, 4)
phcHash, _ := auth.HashPassword("password123")
cred, _ := auth.MigrateFromPHC("user", "password123", phcHash)
server.AddCredential(cred)
```
## Package Structure
- `interfaces.go` - Core interfaces
- `jwt.go` - JWT token operations
- `argon2.go` - Password hashing
- `scram.go` - SCRAM-SHA256 protocol
- `token.go` - Token validation utilities
- `http.go` - HTTP header parsing
- `errors.go` - Error definitions
- `doc.go` - Overview and package documentation
- `argon2.go` - Standalone Argon2id password hashing
- `jwt.go` - JWT manager (HS256/RS256) wrapping `golang-jwt`
- `scram.go` - SCRAM-SHA256 client/server protocol
- `http.go` - HTTP Basic/Bearer header parsing
- `token.go` - Simple in-memory token validator
- `error.go` - Centralized error definitions
## Testing
```bash
go test -v ./auth
go test -v ./
```

View File

@ -13,40 +13,85 @@ import (
// Default Argon2id parameters
const (
DefaultArgonTime = 3 // iterations (reduce for faster but less secure auth)
DefaultArgonTime = 3 // iterations
DefaultArgonMemory = 64 * 1024 // 64 MB
DefaultArgonThreads = 4
DefaultArgonSaltLen = 16
DefaultArgonKeyLen = 32
)
// HashPassword creates an Argon2id PHC-format hash
func (a *Authenticator) HashPassword(password string) (string, error) {
// argonParams holds configurable Argon2id parameters
type argonParams struct {
time uint32
memory uint32
threads uint8
keyLen uint32
saltLen uint32
}
// Option configures Argon2id hashing parameters
type Option func(*argonParams)
// WithTime sets Argon2 iterations
func WithTime(t uint32) Option {
return func(p *argonParams) {
if t > 0 {
p.time = t
}
}
}
// WithMemory sets Argon2 memory in KiB
func WithMemory(m uint32) Option {
return func(p *argonParams) {
if m > 0 {
p.memory = m
}
}
}
// WithThreads sets Argon2 parallelism
func WithThreads(t uint8) Option {
return func(p *argonParams) {
if t > 0 {
p.threads = t
}
}
}
// HashPassword creates Argon2id PHC-format hash (standalone)
func HashPassword(password string, opts ...Option) (string, error) {
if len(password) < 8 {
return "", ErrWeakPassword
}
// Generate salt
salt := make([]byte, DefaultArgonSaltLen)
params := &argonParams{
time: DefaultArgonTime,
memory: DefaultArgonMemory,
threads: DefaultArgonThreads,
keyLen: DefaultArgonKeyLen,
saltLen: DefaultArgonSaltLen,
}
for _, opt := range opts {
opt(params)
}
salt := make([]byte, params.saltLen)
if _, err := rand.Read(salt); err != nil {
return "", fmt.Errorf("%w: %v", ErrSaltGenerationFailed, err)
}
// Derive key using Argon2id
hash := argon2.IDKey([]byte(password), salt, a.argonTime, a.argonMemory, a.argonThreads, DefaultArgonKeyLen)
hash := argon2.IDKey([]byte(password), salt, params.time, params.memory, params.threads, params.keyLen)
// Construct PHC format
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
hashB64 := base64.RawStdEncoding.EncodeToString(hash)
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
argon2.Version, a.argonMemory, a.argonTime, a.argonThreads, saltB64, hashB64)
return phcHash, nil
return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
argon2.Version, params.memory, params.time, params.threads, saltB64, hashB64), nil
}
// VerifyPassword checks password against PHC-format hash
func (a *Authenticator) VerifyPassword(password, phcHash string) error {
// Parse PHC format
// VerifyPassword checks password against PHC-format hash (standalone)
func VerifyPassword(password, phcHash string) error {
parts := strings.Split(phcHash, "$")
if len(parts) != 6 || parts[1] != "argon2id" {
return ErrPHCInvalidFormat
@ -66,10 +111,8 @@ func (a *Authenticator) VerifyPassword(password, phcHash string) error {
return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err)
}
// Compute hash with same parameters
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
// Constant-time comparison
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
return ErrInvalidCredentials
}
@ -77,9 +120,8 @@ func (a *Authenticator) VerifyPassword(password, phcHash string) error {
return nil
}
// MigrateFromPHC converts existing Argon2 PHC hash to SCRAM credential
// MigrateFromPHC converts PHC hash to SCRAM credential
func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
// Parse PHC format
parts := strings.Split(phcHash, "$")
if len(parts) != 6 || parts[1] != "argon2id" {
return nil, ErrPHCInvalidFormat
@ -94,17 +136,10 @@ func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
return nil, ErrPHCInvalidSalt
}
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
return nil, ErrPHCInvalidHash
// Use standalone function for verification
if err := VerifyPassword(password, phcHash); err != nil {
return nil, err
}
// Verify password against hash
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
return nil, ErrInvalidCredentials
}
// Derive SCRAM credential with same parameters
return DeriveCredential(username, password, salt, time, memory, threads)
}

View File

@ -11,13 +11,10 @@ import (
)
func TestPasswordHashing(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err, "Failed to create authenticator")
password := "testPassword123"
// Test hashing
hash, err := auth.HashPassword(password)
// Test hashing with default parameters
hash, err := HashPassword(password)
require.NoError(t, err, "Failed to hash password")
// Verify PHC format
@ -25,20 +22,30 @@ func TestPasswordHashing(t *testing.T) {
"Hash should have argon2id prefix, got: %s", hash)
// Test verification with correct password
err = auth.VerifyPassword(password, hash)
err = VerifyPassword(password, hash)
assert.NoError(t, err, "Failed to verify correct password")
// Test verification with incorrect password
err = auth.VerifyPassword("wrongPassword", hash)
err = VerifyPassword("wrongPassword", hash)
assert.Error(t, err, "Verification should fail for incorrect password")
assert.Equal(t, ErrInvalidCredentials, err)
// Test weak password
_, err = auth.HashPassword("weak")
_, err = HashPassword("weak")
assert.Equal(t, ErrWeakPassword, err, "Should reject weak password")
// Test with custom options
hash, err = HashPassword(password,
WithTime(5),
WithMemory(128*1024),
WithThreads(8))
require.NoError(t, err)
err = VerifyPassword(password, hash)
assert.NoError(t, err)
// Test malformed PHC hash
err = auth.VerifyPassword(password, "$invalid$format")
err = VerifyPassword(password, "$invalid$format")
assert.Error(t, err, "Should reject malformed hash")
// Test corrupted salt
@ -47,33 +54,27 @@ func TestPasswordHashing(t *testing.T) {
if len(parts) == 6 {
parts[4] = "invalid!base64"
corruptedHash = strings.Join(parts, "$")
err = auth.VerifyPassword(password, corruptedHash)
err = VerifyPassword(password, corruptedHash)
assert.Error(t, err, "Should reject corrupted salt")
}
}
func TestEmptyPasswordAfterValidation(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err)
// Empty password should be rejected by length check
_, err = auth.HashPassword("")
_, err := HashPassword("")
assert.Equal(t, ErrWeakPassword, err)
// 8-character password should pass
hash, err := auth.HashPassword("12345678")
hash, err := HashPassword("12345678")
require.NoError(t, err)
err = auth.VerifyPassword("12345678", hash)
err = VerifyPassword("12345678", hash)
assert.NoError(t, err)
}
func TestConcurrentPasswordOperations(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err)
password := "testPassword123"
hash, err := auth.HashPassword(password)
hash, err := HashPassword(password)
require.NoError(t, err)
// Test concurrent verification
@ -82,7 +83,7 @@ func TestConcurrentPasswordOperations(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
err := auth.VerifyPassword(password, hash)
err := VerifyPassword(password, hash)
assert.NoError(t, err)
}()
}
@ -90,14 +91,11 @@ func TestConcurrentPasswordOperations(t *testing.T) {
}
func TestPHCMigration(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err)
password := "testPassword123"
username := "migrationUser"
// Generate PHC hash
phcHash, err := auth.HashPassword(password)
phcHash, err := HashPassword(password)
require.NoError(t, err)
// Migrate to SCRAM credential

71
auth.go
View File

@ -1,71 +0,0 @@
// FILE: auth/auth.go
package auth
import (
"crypto/rsa"
"fmt"
)
// Authenticator provides password hashing and JWT operations
type Authenticator struct {
algorithm string
jwtSecret []byte // For HS256
privateKey *rsa.PrivateKey // For RS256
publicKey *rsa.PublicKey // For RS256
argonTime uint32
argonMemory uint32
argonThreads uint8
}
// NewAuthenticator creates a new authenticator with specified algorithm
func NewAuthenticator(key any, algorithm ...string) (*Authenticator, error) {
alg := "HS256"
if len(algorithm) > 0 && algorithm[0] != "" {
alg = algorithm[0]
}
auth := &Authenticator{
algorithm: alg,
argonTime: DefaultArgonTime,
argonMemory: DefaultArgonMemory,
argonThreads: DefaultArgonThreads,
}
switch alg {
case "HS256":
secret, ok := key.([]byte)
if !ok {
return nil, ErrInvalidKeyType
}
if len(secret) < 32 {
return nil, ErrSecretTooShort
}
auth.jwtSecret = secret
case "RS256":
switch k := key.(type) {
case *rsa.PrivateKey:
auth.privateKey = k
auth.publicKey = &k.PublicKey
case *rsa.PublicKey:
auth.publicKey = k
case []byte:
// Try parsing as PEM
if privKey, err := parseRSAPrivateKey(k); err == nil {
auth.privateKey = privKey
auth.publicKey = &privKey.PublicKey
} else if pubKey, err := parseRSAPublicKey(k); err == nil {
auth.publicKey = pubKey
} else {
return nil, fmt.Errorf("failed to parse RSA key: %w", err)
}
default:
return nil, ErrInvalidKeyType
}
default:
return nil, ErrInvalidAlgorithm
}
return auth, nil
}

View File

@ -1,68 +0,0 @@
// FILE: auth/auth_test.go
package auth
import (
"crypto/rand"
"crypto/rsa"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewAuthenticator(t *testing.T) {
// Test HS256 creation
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err, "Failed to create HS256 authenticator")
assert.Equal(t, "HS256", auth.algorithm)
// Test with short secret
_, err = NewAuthenticator([]byte("short"))
assert.Equal(t, ErrSecretTooShort, err)
// Test RS256 with private key
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
authRS, err := NewAuthenticator(privateKey, "RS256")
require.NoError(t, err)
assert.Equal(t, "RS256", authRS.algorithm)
assert.NotNil(t, authRS.privateKey)
assert.NotNil(t, authRS.publicKey)
// Test RS256 with public key only
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
require.NoError(t, err)
assert.Equal(t, "RS256", authPub.algorithm)
assert.Nil(t, authPub.privateKey)
assert.NotNil(t, authPub.publicKey)
// Test invalid algorithm
_, err = NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"), "INVALID")
assert.Equal(t, ErrInvalidAlgorithm, err)
// Test invalid key type for HS256
_, err = NewAuthenticator(privateKey, "HS256")
assert.Equal(t, ErrInvalidKeyType, err)
}
func TestInterfaceCompliance(t *testing.T) {
// Verify Authenticator implements AuthenticatorInterface
auth, _ := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
var _ AuthenticatorInterface = auth
// Test interface methods work
hash, err := auth.HashPassword("testpass123")
require.NoError(t, err)
err = auth.VerifyPassword("testpass123", hash)
assert.NoError(t, err)
token, err := auth.GenerateToken("user1", nil)
require.NoError(t, err)
userID, _, err := auth.ValidateToken(token)
require.NoError(t, err)
assert.Equal(t, "user1", userID)
}

53
doc.go Normal file
View File

@ -0,0 +1,53 @@
// FILE: auth/doc.go
package auth
/*
Package auth provides modular authentication components:
# Argon2 Password Hashing
Standalone password hashing using Argon2id:
hash, err := auth.HashPassword("password123")
err = auth.VerifyPassword("password123", hash)
// With custom parameters
hash, err := auth.HashPassword("password123",
auth.WithTime(5),
auth.WithMemory(128*1024))
# JWT Token Management
JSON Web Token generation and validation:
// HS256 (symmetric)
jwtMgr, _ := auth.NewJWT(secret)
token, _ := jwtMgr.GenerateToken("user1", claims)
userID, claims, _ := jwtMgr.ValidateToken(token)
// RS256 (asymmetric)
jwtMgr, _ := auth.NewJWTRSA(privateKey)
// One-off operations
token, _ := auth.GenerateHS256Token(secret, "user1", claims, 1*time.Hour)
# SCRAM-SHA256 Authentication
Server and client implementation for SCRAM:
// Server
server := auth.NewScramServer()
server.AddCredential(credential)
// Client
client := auth.NewScramClient(username, password)
# HTTP Authentication Parsing
Utility functions for HTTP headers:
username, password, _ := auth.ParseBasicAuth(header)
token, _ := auth.ParseBearerToken(header)
Each module can be used independently without initializing other components.
*/

View File

@ -10,8 +10,6 @@ import (
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrWeakPassword = errors.New("password must be at least 8 characters")
ErrInvalidAlgorithm = errors.New("invalid algorithm")
ErrInvalidKeyType = errors.New("invalid key type for algorithm")
)
// JWT-specific errors
@ -22,9 +20,6 @@ var (
ErrTokenInvalidSignature = errors.New("token: invalid signature")
ErrTokenAlgorithmMismatch = errors.New("token: algorithm mismatch")
ErrTokenMissingClaim = errors.New("token: missing required claim")
ErrTokenInvalidHeader = errors.New("token: invalid header encoding")
ErrTokenInvalidClaims = errors.New("token: invalid claims encoding")
ErrTokenInvalidJSON = errors.New("token: malformed JSON")
ErrTokenEmptyUserID = errors.New("token: empty user ID")
ErrTokenNoPrivateKey = errors.New("token: private key required for signing")
ErrTokenNoPublicKey = errors.New("token: public key required for verification")

3
go.mod
View File

@ -1,8 +1,9 @@
module auth
module github.com/lixenwraith/auth
go 1.25.3
require (
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.43.0
)

2
go.sum
View File

@ -1,5 +1,7 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=

View File

@ -1,17 +0,0 @@
// FILE: auth/interface.go
package auth
// AuthenticatorInterface defines the authentication operations
type AuthenticatorInterface interface {
HashPassword(password string) (hash string, err error)
VerifyPassword(password, hash string) (err error)
GenerateToken(userID string, claims map[string]any) (token string, err error)
ValidateToken(token string) (userID string, claims map[string]any, err error)
}
// TokenValidator validates bearer tokens
type TokenValidator interface {
ValidateToken(token string) (valid bool)
AddToken(token string)
RemoveToken(token string)
}

376
jwt.go
View File

@ -2,179 +2,289 @@
package auth
import (
"crypto"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/subtle"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
// GenerateToken creates a JWT token with user claims
func (a *Authenticator) GenerateToken(userID string, claims map[string]any) (string, error) {
// JWT configuration defaults
const (
DefaultTokenLifetime = 24 * time.Hour
DefaultLeeway = 5 * time.Minute
)
// customClaims extends RegisteredClaims with arbitrary user data
type customClaims struct {
jwt.RegisteredClaims
Extra map[string]any `json:"extra,omitempty"`
}
// JWT manages token generation and validation
type JWT struct {
algorithm jwt.SigningMethod
signKey any // []byte for HMAC, *rsa.PrivateKey for RSA
verifyKey any // []byte for HMAC, *rsa.PublicKey for RSA
tokenLifetime time.Duration
leeway time.Duration
issuer string
audience []string
}
// JWTOption configures JWT behavior
type JWTOption func(*JWT)
// WithTokenLifetime sets token expiration duration
func WithTokenLifetime(d time.Duration) JWTOption {
return func(j *JWT) {
if d > 0 {
j.tokenLifetime = d
}
}
}
// WithLeeway sets clock skew tolerance
func WithLeeway(d time.Duration) JWTOption {
return func(j *JWT) {
if d >= 0 {
j.leeway = d
}
}
}
// WithIssuer sets token issuer claim
func WithIssuer(iss string) JWTOption {
return func(j *JWT) {
j.issuer = iss
}
}
// WithAudience sets token audience claim
func WithAudience(aud []string) JWTOption {
return func(j *JWT) {
j.audience = aud
}
}
// NewJWT creates JWT manager for HS256 (symmetric)
func NewJWT(secret []byte, opts ...JWTOption) (*JWT, error) {
if len(secret) < 32 {
return nil, ErrSecretTooShort
}
j := &JWT{
algorithm: jwt.SigningMethodHS256,
signKey: secret,
verifyKey: secret,
tokenLifetime: DefaultTokenLifetime,
leeway: DefaultLeeway,
}
for _, opt := range opts {
opt(j)
}
return j, nil
}
// NewJWTRSA creates JWT manager for RS256 (asymmetric)
func NewJWTRSA(privateKey *rsa.PrivateKey, opts ...JWTOption) (*JWT, error) {
if privateKey == nil {
return nil, ErrTokenNoPrivateKey
}
j := &JWT{
algorithm: jwt.SigningMethodRS256,
signKey: privateKey,
verifyKey: &privateKey.PublicKey,
tokenLifetime: DefaultTokenLifetime,
leeway: DefaultLeeway,
}
for _, opt := range opts {
opt(j)
}
return j, nil
}
// NewJWTVerifier creates JWT manager for verification only (RS256)
func NewJWTVerifier(publicKey *rsa.PublicKey, opts ...JWTOption) (*JWT, error) {
if publicKey == nil {
return nil, ErrTokenNoPublicKey
}
j := &JWT{
algorithm: jwt.SigningMethodRS256,
signKey: nil, // Cannot sign
verifyKey: publicKey,
tokenLifetime: DefaultTokenLifetime,
leeway: DefaultLeeway,
}
for _, opt := range opts {
opt(j)
}
return j, nil
}
// GenerateToken creates signed JWT with claims
func (j *JWT) GenerateToken(userID string, claims map[string]any) (string, error) {
if userID == "" {
return "", ErrTokenEmptyUserID
}
if a.algorithm == "RS256" && a.privateKey == nil {
if j.signKey == nil {
return "", ErrTokenNoPrivateKey
}
// Build JWT claims
now := time.Now()
jwtClaims := map[string]any{
"sub": userID,
"iat": now.Unix(),
"exp": now.Add(7 * 24 * time.Hour).Unix(), // 7 days expiry
registeredClaims := jwt.RegisteredClaims{
Subject: userID,
Issuer: j.issuer,
Audience: j.audience,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(j.tokenLifetime)),
NotBefore: jwt.NewNumericDate(now),
}
// Reserved claims that cannot be overridden
reservedClaims := map[string]bool{
"sub": true, "iat": true, "exp": true, "nbf": true,
"iss": true, "aud": true, "jti": true, "typ": true,
"alg": true,
token := jwt.NewWithClaims(j.algorithm, customClaims{
RegisteredClaims: registeredClaims,
Extra: claims,
})
return token.SignedString(j.signKey)
}
// Merge custom claims
for k, v := range claims {
if !reservedClaims[k] {
jwtClaims[k] = v
}
}
// ValidateToken verifies JWT and extracts claims
func (j *JWT) ValidateToken(tokenString string) (string, map[string]any, error) {
parser := jwt.NewParser(
jwt.WithLeeway(j.leeway),
jwt.WithAudience(j.audience...),
jwt.WithIssuer(j.issuer),
jwt.WithValidMethods([]string{j.algorithm.Alg()}),
jwt.WithExpirationRequired(),
)
// Create JWT header
header := map[string]any{
"alg": a.algorithm,
"typ": "JWT",
}
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
// Algorithm already validated by WithValidMethods
return j.verifyKey, nil
})
// Encode header and payload
headerJSON, _ := json.Marshal(header)
claimsJSON, _ := json.Marshal(jwtClaims)
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
// Create signature
signingInput := headerB64 + "." + claimsB64
var signature string
switch a.algorithm {
case "HS256":
h := hmac.New(sha256.New, a.jwtSecret)
h.Write([]byte(signingInput))
signature = base64.RawURLEncoding.EncodeToString(h.Sum(nil))
case "RS256":
hashed := sha256.Sum256([]byte(signingInput))
sig, err := rsa.SignPKCS1v15(rand.Reader, a.privateKey, crypto.SHA256, hashed[:])
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
signature = base64.RawURLEncoding.EncodeToString(sig)
return "", nil, mapJWTError(err)
}
// Combine to form JWT
token := signingInput + "." + signature
return token, nil
}
// ValidateToken verifies JWT and returns userID and claims
func (a *Authenticator) ValidateToken(token string) (string, map[string]any, error) {
// Split token
parts := strings.Split(token, ".")
if len(parts) != 3 {
claims, ok := token.Claims.(*customClaims)
if !ok || !token.Valid {
return "", nil, ErrTokenMalformed
}
// Decode header to check algorithm
headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
return claims.Subject, claims.Extra, nil
}
// mapJWTError translates jwt library errors to auth package errors
func mapJWTError(err error) error {
switch {
case errors.Is(err, jwt.ErrTokenMalformed):
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
case errors.Is(err, jwt.ErrTokenUnverifiable):
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
return fmt.Errorf("%w : %w", ErrTokenInvalidSignature, err)
case errors.Is(err, jwt.ErrTokenExpired):
return fmt.Errorf("%w : %w", ErrTokenExpired, err)
case errors.Is(err, jwt.ErrTokenNotValidYet):
return fmt.Errorf("%w : %w", ErrTokenNotYetValid, err)
case errors.Is(err, jwt.ErrTokenInvalidAudience):
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
case errors.Is(err, jwt.ErrTokenInvalidIssuer):
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
default:
// Check for algorithm mismatch in error message
if errors.Is(err, jwt.ErrTokenSignatureInvalid) {
return fmt.Errorf("%w : %w", ErrTokenAlgorithmMismatch, err)
}
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
}
}
// Standalone helper functions for one-off operations
// GenerateHS256Token creates HS256 JWT without manager instance
func GenerateHS256Token(secret []byte, userID string, claims map[string]any, lifetime time.Duration) (string, error) {
if len(secret) < 32 {
return "", ErrSecretTooShort
}
now := time.Now()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, customClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(lifetime)),
},
Extra: claims,
})
return token.SignedString(secret)
}
// ValidateHS256Token verifies HS256 JWT without manager instance
func ValidateHS256Token(secret []byte, tokenString string) (string, map[string]any, error) {
if len(secret) < 32 {
return "", nil, ErrSecretTooShort
}
parser := jwt.NewParser(
jwt.WithValidMethods([]string{"HS256"}),
jwt.WithLeeway(DefaultLeeway),
)
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
return secret, nil
})
if err != nil {
return "", nil, ErrTokenInvalidHeader
return "", nil, mapJWTError(err)
}
var header map[string]any
if err = json.Unmarshal(headerJSON, &header); err != nil {
return "", nil, ErrTokenInvalidJSON
claims, ok := token.Claims.(*customClaims)
if !ok || !token.Valid {
return "", nil, ErrTokenMalformed
}
// Verify algorithm matches
if alg, ok := header["alg"].(string); !ok || alg != a.algorithm {
return "", nil, ErrTokenAlgorithmMismatch
return claims.Subject, claims.Extra, nil
}
// Verify signature
signingInput := parts[0] + "." + parts[1]
// RSA Utilities
switch a.algorithm {
case "HS256":
h := hmac.New(sha256.New, a.jwtSecret)
h.Write([]byte(signingInput))
expectedSig := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
if subtle.ConstantTimeCompare([]byte(parts[2]), []byte(expectedSig)) != 1 {
return "", nil, ErrTokenInvalidSignature
}
case "RS256":
if a.publicKey == nil {
return "", nil, ErrTokenNoPublicKey
}
sig, err := base64.RawURLEncoding.DecodeString(parts[2])
// NewJWTRSAFromPEM creates a JWT manager for RS256 from raw PEM-encoded private key data.
func NewJWTRSAFromPEM(privateKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
privateKey, err := parseRSAPrivateKey(privateKeyPEM)
if err != nil {
return "", nil, ErrTokenInvalidSignature
return nil, err
}
// Call the original constructor with the now-parsed key
return NewJWTRSA(privateKey, opts...)
}
hashed := sha256.Sum256([]byte(signingInput))
if err := rsa.VerifyPKCS1v15(a.publicKey, crypto.SHA256, hashed[:], sig); err != nil {
return "", nil, ErrTokenInvalidSignature
}
}
// Decode claims
claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
// NewJWTVerifierFromPEM creates a JWT manager for verification from raw PEM-encoded public key data.
func NewJWTVerifierFromPEM(publicKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
publicKey, err := parseRSAPublicKey(publicKeyPEM)
if err != nil {
return "", nil, ErrTokenInvalidClaims
return nil, err
}
// Call the original constructor with the now-parsed key
return NewJWTVerifier(publicKey, opts...)
}
var claims map[string]any
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
return "", nil, ErrTokenInvalidJSON
}
// Check expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
return "", nil, ErrTokenExpired
}
}
// Check not before
if nbf, ok := claims["nbf"].(float64); ok {
if time.Now().Unix() < int64(nbf) {
return "", nil, ErrTokenNotYetValid
}
}
// Extract userID
userID, ok := claims["sub"].(string)
if !ok {
return "", nil, ErrTokenMissingClaim
}
return userID, claims, nil
}
// parseRSAPrivateKey parses PEM encoded RSA private key
// parseRSAPrivateKey parses a PEM-encoded RSA private key.
func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(pemBytes)
if block == nil {
@ -187,7 +297,7 @@ func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
return key, nil
}
// parseRSAPublicKey parses PEM encoded RSA public key
// parseRSAPublicKey parses a PEM-encoded RSA public key.
func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) {
block, _ := pem.Decode(pemBytes)
if block == nil {

View File

@ -4,19 +4,20 @@ package auth
import (
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"errors"
"fmt"
"crypto/x509"
"encoding/pem"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestJWTHS256(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
secret := []byte("test-secret-key-must-be-32-bytes")
jwtMgr, err := NewJWT(secret)
require.NoError(t, err)
userID := "user123"
@ -26,54 +27,17 @@ func TestJWTHS256(t *testing.T) {
}
// Generate token
token, err := auth.GenerateToken(userID, claims)
require.NoError(t, err, "Failed to generate token")
token, err := jwtMgr.GenerateToken(userID, claims)
require.NoError(t, err)
assert.NotEmpty(t, token)
// Validate token
extractedUserID, extractedClaims, err := auth.ValidateToken(token)
require.NoError(t, err, "Failed to validate token")
extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token)
require.NoError(t, err)
assert.Equal(t, userID, extractedUserID)
assert.Equal(t, "test@example.com", extractedClaims["email"])
assert.Equal(t, "admin", extractedClaims["role"])
// Test invalid token
_, _, err = auth.ValidateToken("invalid.token.here")
assert.Error(t, err)
assert.True(t, errors.Is(err, ErrTokenInvalidJSON))
// Test tampered token
parts := strings.Split(token, ".")
require.Len(t, parts, 3, "JWT should have 3 parts")
tampered := parts[0] + "." + parts[1] + ".invalidsignature"
_, _, err = auth.ValidateToken(tampered)
assert.Error(t, err)
assert.True(t, errors.Is(err, ErrTokenInvalidSignature))
// Test reserved claims cannot be overridden
overrideClaims := map[string]any{
"sub": "override",
"iat": 12345,
"exp": 67890,
"nbf": 11111,
"iss": "attacker",
"aud": "victim",
"jti": "fake",
}
token, err = auth.GenerateToken(userID, overrideClaims)
require.NoError(t, err)
extractedUserID, extractedClaims, err = auth.ValidateToken(token)
require.NoError(t, err)
assert.Equal(t, userID, extractedUserID, "UserID should not be overridden")
assert.NotEqual(t, 12345, extractedClaims["iat"], "iat should not be overridden")
assert.NotEqual(t, 67890, extractedClaims["exp"], "exp should not be overridden")
assert.NotContains(t, extractedClaims, "nbf", "nbf should not be added from user claims")
assert.NotContains(t, extractedClaims, "iss", "iss should not be added from user claims")
}
func TestJWTRS256(t *testing.T) {
@ -82,95 +46,214 @@ func TestJWTRS256(t *testing.T) {
require.NoError(t, err)
// Test with private key (can sign and verify)
authPriv, err := NewAuthenticator(privateKey, "RS256")
jwtMgr, err := NewJWTRSA(privateKey)
require.NoError(t, err)
userID := "user456"
claims := map[string]any{
"email": "rs256@example.com",
"scope": "read:all",
}
// Generate token
token, err := authPriv.GenerateToken(userID, claims)
token, err := jwtMgr.GenerateToken(userID, claims)
require.NoError(t, err)
assert.NotEmpty(t, token)
// Validate token with private key auth (has public key too)
extractedUserID, extractedClaims, err := authPriv.ValidateToken(token)
// Validate with same manager
extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token)
require.NoError(t, err)
assert.Equal(t, userID, extractedUserID)
assert.Equal(t, "rs256@example.com", extractedClaims["email"])
assert.Equal(t, "read:all", extractedClaims["scope"])
// Test with public key only (can only verify)
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
// Test with verifier only (public key)
verifier, err := NewJWTVerifier(&privateKey.PublicKey)
require.NoError(t, err)
// Should be able to validate token
extractedUserID, extractedClaims, err = authPub.ValidateToken(token)
// Should validate token
extractedUserID, _, err = verifier.ValidateToken(token)
require.NoError(t, err)
assert.Equal(t, userID, extractedUserID)
// Should not be able to generate token
_, err = authPub.GenerateToken(userID, claims)
assert.Error(t, err, "Public key only auth should not generate tokens")
// Should not generate token
_, err = verifier.GenerateToken(userID, claims)
assert.Equal(t, ErrTokenNoPrivateKey, err)
// Test algorithm mismatch
authHS256, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err)
_, _, err = authHS256.ValidateToken(token)
assert.Error(t, err, "HS256 auth should not validate RS256 token")
// assert.True(t, errors.Is(err, ErrInvalidToken))
assert.True(t, errors.Is(err, ErrTokenAlgorithmMismatch))
fmt.Println(err)
}
func TestExpiredToken(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
func TestJWTOptions(t *testing.T) {
secret := []byte("test-secret-key-must-be-32-bytes")
// Test custom lifetime
jwtMgr, err := NewJWT(secret,
WithTokenLifetime(1*time.Hour),
WithIssuer("test-issuer"),
WithAudience([]string{"api.example.com"}),
)
require.NoError(t, err)
userID := "user123"
// Generate normal token (should have 7 days expiry)
token, err := auth.GenerateToken(userID, nil)
token, err := jwtMgr.GenerateToken("user1", nil)
require.NoError(t, err)
_, extractedClaims, err := auth.ValidateToken(token)
require.NoError(t, err)
// Parse token to check claims
parsed, _ := jwt.Parse(token, func(token *jwt.Token) (any, error) {
return secret, nil
})
// Check expiry is in future (approximately 7 days)
expiry := extractedClaims["exp"].(float64)
now := time.Now().Unix()
claims := parsed.Claims.(jwt.MapClaims)
assert.Greater(t, expiry, float64(now), "Token expiry should be in future")
assert.InDelta(t, expiry, float64(now+7*24*60*60), 10,
"Token expiry should be approximately 7 days from now")
// Check issuer
assert.Equal(t, "test-issuer", claims["iss"])
// Check audience
aud := claims["aud"].([]any)
assert.Contains(t, aud, "api.example.com")
// Check expiration is ~1 hour
exp := int64(claims["exp"].(float64))
iat := int64(claims["iat"].(float64))
assert.InDelta(t, 3600, exp-iat, 10)
}
func TestCorruptJWTParts(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
func TestJWTErrors(t *testing.T) {
secret := []byte("test-secret-key-must-be-32-bytes")
jwtMgr, err := NewJWT(secret)
require.NoError(t, err)
// Test with missing parts
_, _, err = auth.ValidateToken("only.two")
assert.True(t, errors.Is(err, ErrTokenMalformed))
// Empty user ID
_, err = jwtMgr.GenerateToken("", nil)
assert.Equal(t, ErrTokenEmptyUserID, err)
// Test with invalid header encoding
_, _, err = auth.ValidateToken("not-base64!.valid.valid")
assert.Error(t, err)
// Invalid token format
_, _, err = jwtMgr.ValidateToken("invalid.token")
assert.ErrorIs(t, err, ErrTokenMalformed)
// Test with invalid claims encoding
validHeader := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
_, _, err = auth.ValidateToken(validHeader + ".not-base64!.valid")
assert.Error(t, err)
// Tampered signature
token, _ := jwtMgr.GenerateToken("user1", nil)
parts := strings.Split(token, ".")
tampered := parts[0] + "." + parts[1] + ".invalidsignature"
_, _, err = jwtMgr.ValidateToken(tampered)
assert.ErrorIs(t, err, ErrTokenInvalidSignature)
// Test with invalid JSON in claims
invalidJSON := base64.RawURLEncoding.EncodeToString([]byte("{invalid json"))
_, _, err = auth.ValidateToken(validHeader + "." + invalidJSON + ".signature")
assert.Error(t, err)
// Wrong algorithm
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
rsaMgr, _ := NewJWTRSA(rsaKey)
rsaToken, _ := rsaMgr.GenerateToken("user1", nil)
_, _, err = jwtMgr.ValidateToken(rsaToken)
assert.ErrorIs(t, err, ErrTokenInvalidSignature)
}
func TestJWTExpiration(t *testing.T) {
secret := []byte("test-secret-key-must-be-32-bytes")
// Create token with 1 second lifetime
jwtMgr, err := NewJWT(secret, WithTokenLifetime(1*time.Second), WithLeeway(0))
require.NoError(t, err)
token, err := jwtMgr.GenerateToken("user1", nil)
require.NoError(t, err)
// Should be valid immediately
_, _, err = jwtMgr.ValidateToken(token)
assert.NoError(t, err)
// Wait for expiration
time.Sleep(2 * time.Second)
// Should be expired
_, _, err = jwtMgr.ValidateToken(token)
assert.ErrorIs(t, err, ErrTokenExpired)
}
func TestLeeway(t *testing.T) {
secret := []byte("test-secret-key-must-be-32-bytes")
// Create manager with no leeway
jwtMgr, err := NewJWT(secret, WithLeeway(0))
require.NoError(t, err)
// Manually create a token with NotBefore in future
now := time.Now()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": "user1",
"nbf": now.Add(2 * time.Second).Unix(),
"exp": now.Add(1 * time.Hour).Unix(),
})
tokenString, err := token.SignedString(secret)
require.NoError(t, err)
// Should fail immediately (not valid yet)
_, _, err = jwtMgr.ValidateToken(tokenString)
assert.ErrorIs(t, err, ErrTokenNotYetValid)
// Create manager with leeway
jwtMgrWithLeeway, err := NewJWT(secret, WithLeeway(5*time.Second))
require.NoError(t, err)
// Should pass with leeway
_, _, err = jwtMgrWithLeeway.ValidateToken(tokenString)
assert.NoError(t, err)
}
func TestStandaloneFunctions(t *testing.T) {
secret := []byte("test-secret-key-must-be-32-bytes")
userID := "standalone-user"
claims := map[string]any{"test": "value"}
// Generate token
token, err := GenerateHS256Token(secret, userID, claims, 1*time.Hour)
require.NoError(t, err)
// Validate token
extractedUserID, extractedClaims, err := ValidateHS256Token(secret, token)
require.NoError(t, err)
assert.Equal(t, userID, extractedUserID)
assert.Equal(t, "value", extractedClaims["test"])
// Test with short secret
_, err = GenerateHS256Token([]byte("short"), userID, claims, 1*time.Hour)
assert.Equal(t, ErrSecretTooShort, err)
}
func TestJWTRSAFromPEM(t *testing.T) {
// 1. Generate a new RSA key pair for this test
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
// 2. Encode the private key to PEM format
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
// 3. Encode the public key to PEM format
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
require.NoError(t, err)
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
})
// 4. Test the PEM constructor for the signer
jwtMgr, err := NewJWTRSAFromPEM(privateKeyPEM)
require.NoError(t, err)
token, err := jwtMgr.GenerateToken("user-from-pem", nil)
require.NoError(t, err)
assert.NotEmpty(t, token)
// 5. Test the PEM constructor for the verifier
verifier, err := NewJWTVerifierFromPEM(publicKeyPEM)
require.NoError(t, err)
userID, _, err := verifier.ValidateToken(token)
require.NoError(t, err)
assert.Equal(t, "user-from-pem", userID)
// 6. Test failure cases with invalid data
_, err = NewJWTRSAFromPEM([]byte("invalid pem data"))
assert.ErrorIs(t, err, ErrRSAInvalidPEM)
_, err = NewJWTVerifierFromPEM([]byte("invalid pem data"))
assert.ErrorIs(t, err, ErrRSAInvalidPEM)
}

View File

@ -143,6 +143,10 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
return nil, ErrSCRAMSaltTooShort
}
if time == 0 || memory == 0 || threads == 0 {
return nil, ErrSCRAMZeroParams
}
// Derive salted password using Argon2id
saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, DefaultArgonKeyLen)

146
scram_test.go Normal file
View File

@ -0,0 +1,146 @@
// FILE: auth/scram_test.go
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupScramTest is a helper to initialize a server and user credential for testing.
// It performs the full Argon2 -> SCRAM migration workflow.
func setupScramTest(t *testing.T) (server *ScramServer, username, password string, cred *Credential) {
username = "testuser"
password = "SecurePassword123"
// 1. Start with an Argon2 PHC hash, as a real application would.
phcHash, err := HashPassword(password)
require.NoError(t, err, "Setup failed: could not hash password")
// 2. Migrate the PHC hash to a SCRAM credential.
cred, err = MigrateFromPHC(username, password, phcHash)
require.NoError(t, err, "Setup failed: could not migrate from PHC hash")
// 3. Create a server and add the new credential.
server = NewScramServer()
server.AddCredential(cred)
return server, username, password, cred
}
// TestScram_FullRoundtrip_Success simulates a complete, successful authentication handshake.
func TestScram_FullRoundtrip_Success(t *testing.T) {
server, username, password, _ := setupScramTest(t)
client := NewScramClient(username, password)
// --- Step 1: Client sends its first message ---
clientFirst, err := client.StartAuthentication()
require.NoError(t, err)
// --- Step 2: Server receives client's message and responds ---
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
require.NoError(t, err, "Server failed to process client's first message")
// --- Step 3: Client receives server's message, computes proof ---
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
require.NoError(t, err, "Client failed to process server's first message")
// --- Step 4: Server receives client's proof and verifies it ---
serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
require.NoError(t, err, "Server failed to verify client's final proof")
assert.NotEmpty(t, serverFinal.ServerSignature, "Server signature should not be empty")
// --- Step 5: Client verifies server's signature (mutual authentication) ---
err = client.VerifyServerFinalMessage(serverFinal)
assert.NoError(t, err, "Client failed to verify server's final signature")
t.Log("SCRAM full roundtrip successful")
}
// TestScram_FullRoundtrip_WrongPassword ensures authentication fails with an incorrect password.
func TestScram_FullRoundtrip_WrongPassword(t *testing.T) {
server, username, _, _ := setupScramTest(t)
// Create a client with the WRONG password
client := NewScramClient(username, "WrongPassword!!!")
// Steps 1-3 will appear to succeed, as the client doesn't know the password is wrong yet.
clientFirst, err := client.StartAuthentication()
require.NoError(t, err)
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
require.NoError(t, err)
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
require.NoError(t, err)
// --- Step 4: Server verification should fail here ---
_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should reject proof from wrong password")
t.Log("SCRAM correctly failed for wrong password")
}
// TestScram_FullRoundtrip_UserNotFound tests for user enumeration protection.
// The server should not reveal whether a user exists or not in its first message.
func TestScram_FullRoundtrip_UserNotFound(t *testing.T) {
server, _, _, _ := setupScramTest(t)
client := NewScramClient("unknown_user", "any_password")
clientFirst, err := client.StartAuthentication()
require.NoError(t, err)
// --- Step 2: Server should return an error, but also a FAKE response ---
// This prevents an attacker from knowing if the user exists based on the response structure.
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should return an error for an unknown user")
assert.NotEmpty(t, serverFirst.FullNonce, "Server must still provide a nonce to prevent enumeration")
assert.NotEmpty(t, serverFirst.Salt, "Server must still provide a salt to prevent enumeration")
t.Log("SCRAM correctly protected against user enumeration")
}
// TestScram_InvalidNonce simulates a replay attack or message mismatch.
func TestScram_InvalidNonce(t *testing.T) {
server, username, password, _ := setupScramTest(t)
client := NewScramClient(username, password)
// Perform the first part of the handshake
clientFirst, _ := client.StartAuthentication()
serverFirst, _ := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
clientFinal, _ := client.ProcessServerFirstMessage(serverFirst)
// Attempt to finalize with a completely different nonce
_, err := server.ProcessClientFinalMessage("this-is-a-bad-nonce", clientFinal.ClientProof)
assert.ErrorIs(t, err, ErrSCRAMInvalidNonce, "Server should reject a final message with an unknown nonce")
}
// TestScram_CredentialImportExport verifies that credentials can be serialized and deserialized correctly.
func TestScram_CredentialImportExport(t *testing.T) {
_, _, _, originalCred := setupScramTest(t)
// Export the credential to a map
exportedData := originalCred.Export()
require.NotNil(t, exportedData)
// Assert that required fields exist and are strings (as they are base64 encoded)
assert.IsType(t, "", exportedData["salt"])
assert.IsType(t, "", exportedData["stored_key"])
assert.IsType(t, "", exportedData["server_key"])
// Import the credential back from the map
importedCred, err := ImportCredential(exportedData)
require.NoError(t, err)
require.NotNil(t, importedCred)
// Verify that the imported credential is identical to the original
assert.Equal(t, originalCred.Username, importedCred.Username)
assert.Equal(t, originalCred.Salt, importedCred.Salt)
assert.Equal(t, originalCred.ArgonTime, importedCred.ArgonTime)
assert.Equal(t, originalCred.ArgonMemory, importedCred.ArgonMemory)
assert.Equal(t, originalCred.ArgonThreads, importedCred.ArgonThreads)
assert.Equal(t, originalCred.StoredKey, importedCred.StoredKey)
assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey)
t.Log("SCRAM credential import/export successful")
}