v0.2.0 restructured and generalized to be more modular, added golang-jwt dependency

This commit is contained in:
2025-11-03 15:11:40 -05:00
parent 3a662862d7
commit 3471030edd
14 changed files with 760 additions and 482 deletions

View File

@ -1,38 +1,45 @@
# Auth Package # Auth Package
Pluggable authentication utilities for Go applications. Modular authentication utilities for Go applications.
## Features ## Features
- **Password Hashing**: Argon2id with PHC format - **Password Hashing**: Standalone Argon2id hashing with PHC format.
- **JWT**: HS256/RS256 token generation and validation - **JWT**: HS256/RS256 token management via a simple facade over `golang-jwt`.
- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF - **SCRAM-SHA256**: Client/server implementation with Argon2id KDF.
- **HTTP Auth**: Basic/Bearer header parsing - **HTTP Auth**: Helpers for parsing Basic and Bearer authentication headers.
## Usage ## Usage
```go ```go
// Argon2 Password Hashing
hash, _ := auth.HashPassword("password123")
err := auth.VerifyPassword("password123", hash)
// JWT with HS256 // JWT with HS256
auth, _ := auth.NewAuthenticator([]byte("32-byte-secret-key...")) jwtMgr, _ := auth.NewJWT([]byte("a-very-secure-32-byte-secret-key"))
token, _ := auth.GenerateToken("user123", map[string]interface{}{"role": "admin"}) token, _ := jwtMgr.GenerateToken("user123", map[string]any{"role": "admin"})
userID, claims, _ := auth.ValidateToken(token) userID, claims, _ := jwtMgr.ValidateToken(token)
// SCRAM authentication // SCRAM authentication
server := auth.NewScramServer() server := auth.NewScramServer()
cred, _ := auth.DeriveCredential("user", "password", salt, 1, 65536, 4) phcHash, _ := auth.HashPassword("password123")
cred, _ := auth.MigrateFromPHC("user", "password123", phcHash)
server.AddCredential(cred) server.AddCredential(cred)
``` ```
## Package Structure ## Package Structure
- `interfaces.go` - Core interfaces - `doc.go` - Overview and package documentation
- `jwt.go` - JWT token operations - `argon2.go` - Standalone Argon2id password hashing
- `argon2.go` - Password hashing - `jwt.go` - JWT manager (HS256/RS256) wrapping `golang-jwt`
- `scram.go` - SCRAM-SHA256 protocol - `scram.go` - SCRAM-SHA256 client/server protocol
- `token.go` - Token validation utilities - `http.go` - HTTP Basic/Bearer header parsing
- `http.go` - HTTP header parsing - `token.go` - Simple in-memory token validator
- `errors.go` - Error definitions - `error.go` - Centralized error definitions
## Testing ## Testing
```bash ```bash
go test -v ./auth go test -v ./
``` ```

View File

@ -13,40 +13,85 @@ import (
// Default Argon2id parameters // Default Argon2id parameters
const ( const (
DefaultArgonTime = 3 // iterations (reduce for faster but less secure auth) DefaultArgonTime = 3 // iterations
DefaultArgonMemory = 64 * 1024 // 64 MB DefaultArgonMemory = 64 * 1024 // 64 MB
DefaultArgonThreads = 4 DefaultArgonThreads = 4
DefaultArgonSaltLen = 16 DefaultArgonSaltLen = 16
DefaultArgonKeyLen = 32 DefaultArgonKeyLen = 32
) )
// HashPassword creates an Argon2id PHC-format hash // argonParams holds configurable Argon2id parameters
func (a *Authenticator) HashPassword(password string) (string, error) { type argonParams struct {
time uint32
memory uint32
threads uint8
keyLen uint32
saltLen uint32
}
// Option configures Argon2id hashing parameters
type Option func(*argonParams)
// WithTime sets Argon2 iterations
func WithTime(t uint32) Option {
return func(p *argonParams) {
if t > 0 {
p.time = t
}
}
}
// WithMemory sets Argon2 memory in KiB
func WithMemory(m uint32) Option {
return func(p *argonParams) {
if m > 0 {
p.memory = m
}
}
}
// WithThreads sets Argon2 parallelism
func WithThreads(t uint8) Option {
return func(p *argonParams) {
if t > 0 {
p.threads = t
}
}
}
// HashPassword creates Argon2id PHC-format hash (standalone)
func HashPassword(password string, opts ...Option) (string, error) {
if len(password) < 8 { if len(password) < 8 {
return "", ErrWeakPassword return "", ErrWeakPassword
} }
// Generate salt params := &argonParams{
salt := make([]byte, DefaultArgonSaltLen) time: DefaultArgonTime,
memory: DefaultArgonMemory,
threads: DefaultArgonThreads,
keyLen: DefaultArgonKeyLen,
saltLen: DefaultArgonSaltLen,
}
for _, opt := range opts {
opt(params)
}
salt := make([]byte, params.saltLen)
if _, err := rand.Read(salt); err != nil { if _, err := rand.Read(salt); err != nil {
return "", fmt.Errorf("%w: %v", ErrSaltGenerationFailed, err) return "", fmt.Errorf("%w: %v", ErrSaltGenerationFailed, err)
} }
// Derive key using Argon2id hash := argon2.IDKey([]byte(password), salt, params.time, params.memory, params.threads, params.keyLen)
hash := argon2.IDKey([]byte(password), salt, a.argonTime, a.argonMemory, a.argonThreads, DefaultArgonKeyLen)
// Construct PHC format
saltB64 := base64.RawStdEncoding.EncodeToString(salt) saltB64 := base64.RawStdEncoding.EncodeToString(salt)
hashB64 := base64.RawStdEncoding.EncodeToString(hash) hashB64 := base64.RawStdEncoding.EncodeToString(hash)
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
argon2.Version, a.argonMemory, a.argonTime, a.argonThreads, saltB64, hashB64) argon2.Version, params.memory, params.time, params.threads, saltB64, hashB64), nil
return phcHash, nil
} }
// VerifyPassword checks password against PHC-format hash // VerifyPassword checks password against PHC-format hash (standalone)
func (a *Authenticator) VerifyPassword(password, phcHash string) error { func VerifyPassword(password, phcHash string) error {
// Parse PHC format
parts := strings.Split(phcHash, "$") parts := strings.Split(phcHash, "$")
if len(parts) != 6 || parts[1] != "argon2id" { if len(parts) != 6 || parts[1] != "argon2id" {
return ErrPHCInvalidFormat return ErrPHCInvalidFormat
@ -66,10 +111,8 @@ func (a *Authenticator) VerifyPassword(password, phcHash string) error {
return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err) return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err)
} }
// Compute hash with same parameters
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash))) computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
// Constant-time comparison
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 { if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
return ErrInvalidCredentials return ErrInvalidCredentials
} }
@ -77,9 +120,8 @@ func (a *Authenticator) VerifyPassword(password, phcHash string) error {
return nil return nil
} }
// MigrateFromPHC converts existing Argon2 PHC hash to SCRAM credential // MigrateFromPHC converts PHC hash to SCRAM credential
func MigrateFromPHC(username, password, phcHash string) (*Credential, error) { func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
// Parse PHC format
parts := strings.Split(phcHash, "$") parts := strings.Split(phcHash, "$")
if len(parts) != 6 || parts[1] != "argon2id" { if len(parts) != 6 || parts[1] != "argon2id" {
return nil, ErrPHCInvalidFormat return nil, ErrPHCInvalidFormat
@ -94,17 +136,10 @@ func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
return nil, ErrPHCInvalidSalt return nil, ErrPHCInvalidSalt
} }
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) // Use standalone function for verification
if err != nil { if err := VerifyPassword(password, phcHash); err != nil {
return nil, ErrPHCInvalidHash return nil, err
} }
// Verify password against hash
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
return nil, ErrInvalidCredentials
}
// Derive SCRAM credential with same parameters
return DeriveCredential(username, password, salt, time, memory, threads) return DeriveCredential(username, password, salt, time, memory, threads)
} }

View File

@ -11,13 +11,10 @@ import (
) )
func TestPasswordHashing(t *testing.T) { func TestPasswordHashing(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err, "Failed to create authenticator")
password := "testPassword123" password := "testPassword123"
// Test hashing // Test hashing with default parameters
hash, err := auth.HashPassword(password) hash, err := HashPassword(password)
require.NoError(t, err, "Failed to hash password") require.NoError(t, err, "Failed to hash password")
// Verify PHC format // Verify PHC format
@ -25,20 +22,30 @@ func TestPasswordHashing(t *testing.T) {
"Hash should have argon2id prefix, got: %s", hash) "Hash should have argon2id prefix, got: %s", hash)
// Test verification with correct password // Test verification with correct password
err = auth.VerifyPassword(password, hash) err = VerifyPassword(password, hash)
assert.NoError(t, err, "Failed to verify correct password") assert.NoError(t, err, "Failed to verify correct password")
// Test verification with incorrect password // Test verification with incorrect password
err = auth.VerifyPassword("wrongPassword", hash) err = VerifyPassword("wrongPassword", hash)
assert.Error(t, err, "Verification should fail for incorrect password") assert.Error(t, err, "Verification should fail for incorrect password")
assert.Equal(t, ErrInvalidCredentials, err) assert.Equal(t, ErrInvalidCredentials, err)
// Test weak password // Test weak password
_, err = auth.HashPassword("weak") _, err = HashPassword("weak")
assert.Equal(t, ErrWeakPassword, err, "Should reject weak password") assert.Equal(t, ErrWeakPassword, err, "Should reject weak password")
// Test with custom options
hash, err = HashPassword(password,
WithTime(5),
WithMemory(128*1024),
WithThreads(8))
require.NoError(t, err)
err = VerifyPassword(password, hash)
assert.NoError(t, err)
// Test malformed PHC hash // Test malformed PHC hash
err = auth.VerifyPassword(password, "$invalid$format") err = VerifyPassword(password, "$invalid$format")
assert.Error(t, err, "Should reject malformed hash") assert.Error(t, err, "Should reject malformed hash")
// Test corrupted salt // Test corrupted salt
@ -47,33 +54,27 @@ func TestPasswordHashing(t *testing.T) {
if len(parts) == 6 { if len(parts) == 6 {
parts[4] = "invalid!base64" parts[4] = "invalid!base64"
corruptedHash = strings.Join(parts, "$") corruptedHash = strings.Join(parts, "$")
err = auth.VerifyPassword(password, corruptedHash) err = VerifyPassword(password, corruptedHash)
assert.Error(t, err, "Should reject corrupted salt") assert.Error(t, err, "Should reject corrupted salt")
} }
} }
func TestEmptyPasswordAfterValidation(t *testing.T) { func TestEmptyPasswordAfterValidation(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err)
// Empty password should be rejected by length check // Empty password should be rejected by length check
_, err = auth.HashPassword("") _, err := HashPassword("")
assert.Equal(t, ErrWeakPassword, err) assert.Equal(t, ErrWeakPassword, err)
// 8-character password should pass // 8-character password should pass
hash, err := auth.HashPassword("12345678") hash, err := HashPassword("12345678")
require.NoError(t, err) require.NoError(t, err)
err = auth.VerifyPassword("12345678", hash) err = VerifyPassword("12345678", hash)
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestConcurrentPasswordOperations(t *testing.T) { func TestConcurrentPasswordOperations(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err)
password := "testPassword123" password := "testPassword123"
hash, err := auth.HashPassword(password) hash, err := HashPassword(password)
require.NoError(t, err) require.NoError(t, err)
// Test concurrent verification // Test concurrent verification
@ -82,7 +83,7 @@ func TestConcurrentPasswordOperations(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
err := auth.VerifyPassword(password, hash) err := VerifyPassword(password, hash)
assert.NoError(t, err) assert.NoError(t, err)
}() }()
} }
@ -90,14 +91,11 @@ func TestConcurrentPasswordOperations(t *testing.T) {
} }
func TestPHCMigration(t *testing.T) { func TestPHCMigration(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err)
password := "testPassword123" password := "testPassword123"
username := "migrationUser" username := "migrationUser"
// Generate PHC hash // Generate PHC hash
phcHash, err := auth.HashPassword(password) phcHash, err := HashPassword(password)
require.NoError(t, err) require.NoError(t, err)
// Migrate to SCRAM credential // Migrate to SCRAM credential

71
auth.go
View File

@ -1,71 +0,0 @@
// FILE: auth/auth.go
package auth
import (
"crypto/rsa"
"fmt"
)
// Authenticator provides password hashing and JWT operations
type Authenticator struct {
algorithm string
jwtSecret []byte // For HS256
privateKey *rsa.PrivateKey // For RS256
publicKey *rsa.PublicKey // For RS256
argonTime uint32
argonMemory uint32
argonThreads uint8
}
// NewAuthenticator creates a new authenticator with specified algorithm
func NewAuthenticator(key any, algorithm ...string) (*Authenticator, error) {
alg := "HS256"
if len(algorithm) > 0 && algorithm[0] != "" {
alg = algorithm[0]
}
auth := &Authenticator{
algorithm: alg,
argonTime: DefaultArgonTime,
argonMemory: DefaultArgonMemory,
argonThreads: DefaultArgonThreads,
}
switch alg {
case "HS256":
secret, ok := key.([]byte)
if !ok {
return nil, ErrInvalidKeyType
}
if len(secret) < 32 {
return nil, ErrSecretTooShort
}
auth.jwtSecret = secret
case "RS256":
switch k := key.(type) {
case *rsa.PrivateKey:
auth.privateKey = k
auth.publicKey = &k.PublicKey
case *rsa.PublicKey:
auth.publicKey = k
case []byte:
// Try parsing as PEM
if privKey, err := parseRSAPrivateKey(k); err == nil {
auth.privateKey = privKey
auth.publicKey = &privKey.PublicKey
} else if pubKey, err := parseRSAPublicKey(k); err == nil {
auth.publicKey = pubKey
} else {
return nil, fmt.Errorf("failed to parse RSA key: %w", err)
}
default:
return nil, ErrInvalidKeyType
}
default:
return nil, ErrInvalidAlgorithm
}
return auth, nil
}

View File

@ -1,68 +0,0 @@
// FILE: auth/auth_test.go
package auth
import (
"crypto/rand"
"crypto/rsa"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewAuthenticator(t *testing.T) {
// Test HS256 creation
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err, "Failed to create HS256 authenticator")
assert.Equal(t, "HS256", auth.algorithm)
// Test with short secret
_, err = NewAuthenticator([]byte("short"))
assert.Equal(t, ErrSecretTooShort, err)
// Test RS256 with private key
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
authRS, err := NewAuthenticator(privateKey, "RS256")
require.NoError(t, err)
assert.Equal(t, "RS256", authRS.algorithm)
assert.NotNil(t, authRS.privateKey)
assert.NotNil(t, authRS.publicKey)
// Test RS256 with public key only
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
require.NoError(t, err)
assert.Equal(t, "RS256", authPub.algorithm)
assert.Nil(t, authPub.privateKey)
assert.NotNil(t, authPub.publicKey)
// Test invalid algorithm
_, err = NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"), "INVALID")
assert.Equal(t, ErrInvalidAlgorithm, err)
// Test invalid key type for HS256
_, err = NewAuthenticator(privateKey, "HS256")
assert.Equal(t, ErrInvalidKeyType, err)
}
func TestInterfaceCompliance(t *testing.T) {
// Verify Authenticator implements AuthenticatorInterface
auth, _ := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
var _ AuthenticatorInterface = auth
// Test interface methods work
hash, err := auth.HashPassword("testpass123")
require.NoError(t, err)
err = auth.VerifyPassword("testpass123", hash)
assert.NoError(t, err)
token, err := auth.GenerateToken("user1", nil)
require.NoError(t, err)
userID, _, err := auth.ValidateToken(token)
require.NoError(t, err)
assert.Equal(t, "user1", userID)
}

53
doc.go Normal file
View File

@ -0,0 +1,53 @@
// FILE: auth/doc.go
package auth
/*
Package auth provides modular authentication components:
# Argon2 Password Hashing
Standalone password hashing using Argon2id:
hash, err := auth.HashPassword("password123")
err = auth.VerifyPassword("password123", hash)
// With custom parameters
hash, err := auth.HashPassword("password123",
auth.WithTime(5),
auth.WithMemory(128*1024))
# JWT Token Management
JSON Web Token generation and validation:
// HS256 (symmetric)
jwtMgr, _ := auth.NewJWT(secret)
token, _ := jwtMgr.GenerateToken("user1", claims)
userID, claims, _ := jwtMgr.ValidateToken(token)
// RS256 (asymmetric)
jwtMgr, _ := auth.NewJWTRSA(privateKey)
// One-off operations
token, _ := auth.GenerateHS256Token(secret, "user1", claims, 1*time.Hour)
# SCRAM-SHA256 Authentication
Server and client implementation for SCRAM:
// Server
server := auth.NewScramServer()
server.AddCredential(credential)
// Client
client := auth.NewScramClient(username, password)
# HTTP Authentication Parsing
Utility functions for HTTP headers:
username, password, _ := auth.ParseBasicAuth(header)
token, _ := auth.ParseBearerToken(header)
Each module can be used independently without initializing other components.
*/

View File

@ -10,8 +10,6 @@ import (
var ( var (
ErrInvalidCredentials = errors.New("invalid credentials") ErrInvalidCredentials = errors.New("invalid credentials")
ErrWeakPassword = errors.New("password must be at least 8 characters") ErrWeakPassword = errors.New("password must be at least 8 characters")
ErrInvalidAlgorithm = errors.New("invalid algorithm")
ErrInvalidKeyType = errors.New("invalid key type for algorithm")
) )
// JWT-specific errors // JWT-specific errors
@ -22,9 +20,6 @@ var (
ErrTokenInvalidSignature = errors.New("token: invalid signature") ErrTokenInvalidSignature = errors.New("token: invalid signature")
ErrTokenAlgorithmMismatch = errors.New("token: algorithm mismatch") ErrTokenAlgorithmMismatch = errors.New("token: algorithm mismatch")
ErrTokenMissingClaim = errors.New("token: missing required claim") ErrTokenMissingClaim = errors.New("token: missing required claim")
ErrTokenInvalidHeader = errors.New("token: invalid header encoding")
ErrTokenInvalidClaims = errors.New("token: invalid claims encoding")
ErrTokenInvalidJSON = errors.New("token: malformed JSON")
ErrTokenEmptyUserID = errors.New("token: empty user ID") ErrTokenEmptyUserID = errors.New("token: empty user ID")
ErrTokenNoPrivateKey = errors.New("token: private key required for signing") ErrTokenNoPrivateKey = errors.New("token: private key required for signing")
ErrTokenNoPublicKey = errors.New("token: public key required for verification") ErrTokenNoPublicKey = errors.New("token: public key required for verification")

3
go.mod
View File

@ -1,8 +1,9 @@
module auth module github.com/lixenwraith/auth
go 1.25.3 go 1.25.3
require ( require (
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.43.0 golang.org/x/crypto v0.43.0
) )

2
go.sum
View File

@ -1,5 +1,7 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=

View File

@ -1,17 +0,0 @@
// FILE: auth/interface.go
package auth
// AuthenticatorInterface defines the authentication operations
type AuthenticatorInterface interface {
HashPassword(password string) (hash string, err error)
VerifyPassword(password, hash string) (err error)
GenerateToken(userID string, claims map[string]any) (token string, err error)
ValidateToken(token string) (userID string, claims map[string]any, err error)
}
// TokenValidator validates bearer tokens
type TokenValidator interface {
ValidateToken(token string) (valid bool)
AddToken(token string)
RemoveToken(token string)
}

408
jwt.go
View File

@ -2,179 +2,289 @@
package auth package auth
import ( import (
"crypto"
"crypto/hmac"
"crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"crypto/subtle"
"crypto/x509" "crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/golang-jwt/jwt/v5"
) )
// GenerateToken creates a JWT token with user claims // JWT configuration defaults
func (a *Authenticator) GenerateToken(userID string, claims map[string]any) (string, error) { const (
DefaultTokenLifetime = 24 * time.Hour
DefaultLeeway = 5 * time.Minute
)
// customClaims extends RegisteredClaims with arbitrary user data
type customClaims struct {
jwt.RegisteredClaims
Extra map[string]any `json:"extra,omitempty"`
}
// JWT manages token generation and validation
type JWT struct {
algorithm jwt.SigningMethod
signKey any // []byte for HMAC, *rsa.PrivateKey for RSA
verifyKey any // []byte for HMAC, *rsa.PublicKey for RSA
tokenLifetime time.Duration
leeway time.Duration
issuer string
audience []string
}
// JWTOption configures JWT behavior
type JWTOption func(*JWT)
// WithTokenLifetime sets token expiration duration
func WithTokenLifetime(d time.Duration) JWTOption {
return func(j *JWT) {
if d > 0 {
j.tokenLifetime = d
}
}
}
// WithLeeway sets clock skew tolerance
func WithLeeway(d time.Duration) JWTOption {
return func(j *JWT) {
if d >= 0 {
j.leeway = d
}
}
}
// WithIssuer sets token issuer claim
func WithIssuer(iss string) JWTOption {
return func(j *JWT) {
j.issuer = iss
}
}
// WithAudience sets token audience claim
func WithAudience(aud []string) JWTOption {
return func(j *JWT) {
j.audience = aud
}
}
// NewJWT creates JWT manager for HS256 (symmetric)
func NewJWT(secret []byte, opts ...JWTOption) (*JWT, error) {
if len(secret) < 32 {
return nil, ErrSecretTooShort
}
j := &JWT{
algorithm: jwt.SigningMethodHS256,
signKey: secret,
verifyKey: secret,
tokenLifetime: DefaultTokenLifetime,
leeway: DefaultLeeway,
}
for _, opt := range opts {
opt(j)
}
return j, nil
}
// NewJWTRSA creates JWT manager for RS256 (asymmetric)
func NewJWTRSA(privateKey *rsa.PrivateKey, opts ...JWTOption) (*JWT, error) {
if privateKey == nil {
return nil, ErrTokenNoPrivateKey
}
j := &JWT{
algorithm: jwt.SigningMethodRS256,
signKey: privateKey,
verifyKey: &privateKey.PublicKey,
tokenLifetime: DefaultTokenLifetime,
leeway: DefaultLeeway,
}
for _, opt := range opts {
opt(j)
}
return j, nil
}
// NewJWTVerifier creates JWT manager for verification only (RS256)
func NewJWTVerifier(publicKey *rsa.PublicKey, opts ...JWTOption) (*JWT, error) {
if publicKey == nil {
return nil, ErrTokenNoPublicKey
}
j := &JWT{
algorithm: jwt.SigningMethodRS256,
signKey: nil, // Cannot sign
verifyKey: publicKey,
tokenLifetime: DefaultTokenLifetime,
leeway: DefaultLeeway,
}
for _, opt := range opts {
opt(j)
}
return j, nil
}
// GenerateToken creates signed JWT with claims
func (j *JWT) GenerateToken(userID string, claims map[string]any) (string, error) {
if userID == "" { if userID == "" {
return "", ErrTokenEmptyUserID return "", ErrTokenEmptyUserID
} }
if a.algorithm == "RS256" && a.privateKey == nil { if j.signKey == nil {
return "", ErrTokenNoPrivateKey return "", ErrTokenNoPrivateKey
} }
// Build JWT claims
now := time.Now() now := time.Now()
jwtClaims := map[string]any{ registeredClaims := jwt.RegisteredClaims{
"sub": userID, Subject: userID,
"iat": now.Unix(), Issuer: j.issuer,
"exp": now.Add(7 * 24 * time.Hour).Unix(), // 7 days expiry Audience: j.audience,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(j.tokenLifetime)),
NotBefore: jwt.NewNumericDate(now),
} }
// Reserved claims that cannot be overridden token := jwt.NewWithClaims(j.algorithm, customClaims{
reservedClaims := map[string]bool{ RegisteredClaims: registeredClaims,
"sub": true, "iat": true, "exp": true, "nbf": true, Extra: claims,
"iss": true, "aud": true, "jti": true, "typ": true, })
"alg": true,
}
// Merge custom claims return token.SignedString(j.signKey)
for k, v := range claims {
if !reservedClaims[k] {
jwtClaims[k] = v
}
}
// Create JWT header
header := map[string]any{
"alg": a.algorithm,
"typ": "JWT",
}
// Encode header and payload
headerJSON, _ := json.Marshal(header)
claimsJSON, _ := json.Marshal(jwtClaims)
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
// Create signature
signingInput := headerB64 + "." + claimsB64
var signature string
switch a.algorithm {
case "HS256":
h := hmac.New(sha256.New, a.jwtSecret)
h.Write([]byte(signingInput))
signature = base64.RawURLEncoding.EncodeToString(h.Sum(nil))
case "RS256":
hashed := sha256.Sum256([]byte(signingInput))
sig, err := rsa.SignPKCS1v15(rand.Reader, a.privateKey, crypto.SHA256, hashed[:])
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
signature = base64.RawURLEncoding.EncodeToString(sig)
}
// Combine to form JWT
token := signingInput + "." + signature
return token, nil
} }
// ValidateToken verifies JWT and returns userID and claims // ValidateToken verifies JWT and extracts claims
func (a *Authenticator) ValidateToken(token string) (string, map[string]any, error) { func (j *JWT) ValidateToken(tokenString string) (string, map[string]any, error) {
// Split token parser := jwt.NewParser(
parts := strings.Split(token, ".") jwt.WithLeeway(j.leeway),
if len(parts) != 3 { jwt.WithAudience(j.audience...),
jwt.WithIssuer(j.issuer),
jwt.WithValidMethods([]string{j.algorithm.Alg()}),
jwt.WithExpirationRequired(),
)
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
// Algorithm already validated by WithValidMethods
return j.verifyKey, nil
})
if err != nil {
return "", nil, mapJWTError(err)
}
claims, ok := token.Claims.(*customClaims)
if !ok || !token.Valid {
return "", nil, ErrTokenMalformed return "", nil, ErrTokenMalformed
} }
// Decode header to check algorithm return claims.Subject, claims.Extra, nil
headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return "", nil, ErrTokenInvalidHeader
}
var header map[string]any
if err = json.Unmarshal(headerJSON, &header); err != nil {
return "", nil, ErrTokenInvalidJSON
}
// Verify algorithm matches
if alg, ok := header["alg"].(string); !ok || alg != a.algorithm {
return "", nil, ErrTokenAlgorithmMismatch
}
// Verify signature
signingInput := parts[0] + "." + parts[1]
switch a.algorithm {
case "HS256":
h := hmac.New(sha256.New, a.jwtSecret)
h.Write([]byte(signingInput))
expectedSig := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
if subtle.ConstantTimeCompare([]byte(parts[2]), []byte(expectedSig)) != 1 {
return "", nil, ErrTokenInvalidSignature
}
case "RS256":
if a.publicKey == nil {
return "", nil, ErrTokenNoPublicKey
}
sig, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return "", nil, ErrTokenInvalidSignature
}
hashed := sha256.Sum256([]byte(signingInput))
if err := rsa.VerifyPKCS1v15(a.publicKey, crypto.SHA256, hashed[:], sig); err != nil {
return "", nil, ErrTokenInvalidSignature
}
}
// Decode claims
claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", nil, ErrTokenInvalidClaims
}
var claims map[string]any
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
return "", nil, ErrTokenInvalidJSON
}
// Check expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
return "", nil, ErrTokenExpired
}
}
// Check not before
if nbf, ok := claims["nbf"].(float64); ok {
if time.Now().Unix() < int64(nbf) {
return "", nil, ErrTokenNotYetValid
}
}
// Extract userID
userID, ok := claims["sub"].(string)
if !ok {
return "", nil, ErrTokenMissingClaim
}
return userID, claims, nil
} }
// parseRSAPrivateKey parses PEM encoded RSA private key // mapJWTError translates jwt library errors to auth package errors
func mapJWTError(err error) error {
switch {
case errors.Is(err, jwt.ErrTokenMalformed):
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
case errors.Is(err, jwt.ErrTokenUnverifiable):
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
return fmt.Errorf("%w : %w", ErrTokenInvalidSignature, err)
case errors.Is(err, jwt.ErrTokenExpired):
return fmt.Errorf("%w : %w", ErrTokenExpired, err)
case errors.Is(err, jwt.ErrTokenNotValidYet):
return fmt.Errorf("%w : %w", ErrTokenNotYetValid, err)
case errors.Is(err, jwt.ErrTokenInvalidAudience):
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
case errors.Is(err, jwt.ErrTokenInvalidIssuer):
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
default:
// Check for algorithm mismatch in error message
if errors.Is(err, jwt.ErrTokenSignatureInvalid) {
return fmt.Errorf("%w : %w", ErrTokenAlgorithmMismatch, err)
}
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
}
}
// Standalone helper functions for one-off operations
// GenerateHS256Token creates HS256 JWT without manager instance
func GenerateHS256Token(secret []byte, userID string, claims map[string]any, lifetime time.Duration) (string, error) {
if len(secret) < 32 {
return "", ErrSecretTooShort
}
now := time.Now()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, customClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(lifetime)),
},
Extra: claims,
})
return token.SignedString(secret)
}
// ValidateHS256Token verifies HS256 JWT without manager instance
func ValidateHS256Token(secret []byte, tokenString string) (string, map[string]any, error) {
if len(secret) < 32 {
return "", nil, ErrSecretTooShort
}
parser := jwt.NewParser(
jwt.WithValidMethods([]string{"HS256"}),
jwt.WithLeeway(DefaultLeeway),
)
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
return secret, nil
})
if err != nil {
return "", nil, mapJWTError(err)
}
claims, ok := token.Claims.(*customClaims)
if !ok || !token.Valid {
return "", nil, ErrTokenMalformed
}
return claims.Subject, claims.Extra, nil
}
// RSA Utilities
// NewJWTRSAFromPEM creates a JWT manager for RS256 from raw PEM-encoded private key data.
func NewJWTRSAFromPEM(privateKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
privateKey, err := parseRSAPrivateKey(privateKeyPEM)
if err != nil {
return nil, err
}
// Call the original constructor with the now-parsed key
return NewJWTRSA(privateKey, opts...)
}
// NewJWTVerifierFromPEM creates a JWT manager for verification from raw PEM-encoded public key data.
func NewJWTVerifierFromPEM(publicKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
publicKey, err := parseRSAPublicKey(publicKeyPEM)
if err != nil {
return nil, err
}
// Call the original constructor with the now-parsed key
return NewJWTVerifier(publicKey, opts...)
}
// parseRSAPrivateKey parses a PEM-encoded RSA private key.
func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) { func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(pemBytes) block, _ := pem.Decode(pemBytes)
if block == nil { if block == nil {
@ -187,7 +297,7 @@ func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
return key, nil return key, nil
} }
// parseRSAPublicKey parses PEM encoded RSA public key // parseRSAPublicKey parses a PEM-encoded RSA public key.
func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) { func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) {
block, _ := pem.Decode(pemBytes) block, _ := pem.Decode(pemBytes)
if block == nil { if block == nil {

View File

@ -4,19 +4,20 @@ package auth
import ( import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"encoding/base64" "crypto/x509"
"errors" "encoding/pem"
"fmt"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestJWTHS256(t *testing.T) { func TestJWTHS256(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) secret := []byte("test-secret-key-must-be-32-bytes")
jwtMgr, err := NewJWT(secret)
require.NoError(t, err) require.NoError(t, err)
userID := "user123" userID := "user123"
@ -26,54 +27,17 @@ func TestJWTHS256(t *testing.T) {
} }
// Generate token // Generate token
token, err := auth.GenerateToken(userID, claims) token, err := jwtMgr.GenerateToken(userID, claims)
require.NoError(t, err, "Failed to generate token") require.NoError(t, err)
assert.NotEmpty(t, token) assert.NotEmpty(t, token)
// Validate token // Validate token
extractedUserID, extractedClaims, err := auth.ValidateToken(token) extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token)
require.NoError(t, err, "Failed to validate token") require.NoError(t, err)
assert.Equal(t, userID, extractedUserID) assert.Equal(t, userID, extractedUserID)
assert.Equal(t, "test@example.com", extractedClaims["email"]) assert.Equal(t, "test@example.com", extractedClaims["email"])
assert.Equal(t, "admin", extractedClaims["role"]) assert.Equal(t, "admin", extractedClaims["role"])
// Test invalid token
_, _, err = auth.ValidateToken("invalid.token.here")
assert.Error(t, err)
assert.True(t, errors.Is(err, ErrTokenInvalidJSON))
// Test tampered token
parts := strings.Split(token, ".")
require.Len(t, parts, 3, "JWT should have 3 parts")
tampered := parts[0] + "." + parts[1] + ".invalidsignature"
_, _, err = auth.ValidateToken(tampered)
assert.Error(t, err)
assert.True(t, errors.Is(err, ErrTokenInvalidSignature))
// Test reserved claims cannot be overridden
overrideClaims := map[string]any{
"sub": "override",
"iat": 12345,
"exp": 67890,
"nbf": 11111,
"iss": "attacker",
"aud": "victim",
"jti": "fake",
}
token, err = auth.GenerateToken(userID, overrideClaims)
require.NoError(t, err)
extractedUserID, extractedClaims, err = auth.ValidateToken(token)
require.NoError(t, err)
assert.Equal(t, userID, extractedUserID, "UserID should not be overridden")
assert.NotEqual(t, 12345, extractedClaims["iat"], "iat should not be overridden")
assert.NotEqual(t, 67890, extractedClaims["exp"], "exp should not be overridden")
assert.NotContains(t, extractedClaims, "nbf", "nbf should not be added from user claims")
assert.NotContains(t, extractedClaims, "iss", "iss should not be added from user claims")
} }
func TestJWTRS256(t *testing.T) { func TestJWTRS256(t *testing.T) {
@ -82,95 +46,214 @@ func TestJWTRS256(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Test with private key (can sign and verify) // Test with private key (can sign and verify)
authPriv, err := NewAuthenticator(privateKey, "RS256") jwtMgr, err := NewJWTRSA(privateKey)
require.NoError(t, err) require.NoError(t, err)
userID := "user456" userID := "user456"
claims := map[string]any{ claims := map[string]any{
"email": "rs256@example.com",
"scope": "read:all", "scope": "read:all",
} }
// Generate token // Generate token
token, err := authPriv.GenerateToken(userID, claims) token, err := jwtMgr.GenerateToken(userID, claims)
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, token) assert.NotEmpty(t, token)
// Validate token with private key auth (has public key too) // Validate with same manager
extractedUserID, extractedClaims, err := authPriv.ValidateToken(token) extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, userID, extractedUserID) assert.Equal(t, userID, extractedUserID)
assert.Equal(t, "rs256@example.com", extractedClaims["email"])
assert.Equal(t, "read:all", extractedClaims["scope"]) assert.Equal(t, "read:all", extractedClaims["scope"])
// Test with public key only (can only verify) // Test with verifier only (public key)
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256") verifier, err := NewJWTVerifier(&privateKey.PublicKey)
require.NoError(t, err) require.NoError(t, err)
// Should be able to validate token // Should validate token
extractedUserID, extractedClaims, err = authPub.ValidateToken(token) extractedUserID, _, err = verifier.ValidateToken(token)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, userID, extractedUserID) assert.Equal(t, userID, extractedUserID)
// Should not be able to generate token // Should not generate token
_, err = authPub.GenerateToken(userID, claims) _, err = verifier.GenerateToken(userID, claims)
assert.Error(t, err, "Public key only auth should not generate tokens")
assert.Equal(t, ErrTokenNoPrivateKey, err) assert.Equal(t, ErrTokenNoPrivateKey, err)
// Test algorithm mismatch
authHS256, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
require.NoError(t, err)
_, _, err = authHS256.ValidateToken(token)
assert.Error(t, err, "HS256 auth should not validate RS256 token")
// assert.True(t, errors.Is(err, ErrInvalidToken))
assert.True(t, errors.Is(err, ErrTokenAlgorithmMismatch))
fmt.Println(err)
} }
func TestExpiredToken(t *testing.T) { func TestJWTOptions(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) secret := []byte("test-secret-key-must-be-32-bytes")
// Test custom lifetime
jwtMgr, err := NewJWT(secret,
WithTokenLifetime(1*time.Hour),
WithIssuer("test-issuer"),
WithAudience([]string{"api.example.com"}),
)
require.NoError(t, err) require.NoError(t, err)
userID := "user123" token, err := jwtMgr.GenerateToken("user1", nil)
// Generate normal token (should have 7 days expiry)
token, err := auth.GenerateToken(userID, nil)
require.NoError(t, err) require.NoError(t, err)
_, extractedClaims, err := auth.ValidateToken(token) // Parse token to check claims
require.NoError(t, err) parsed, _ := jwt.Parse(token, func(token *jwt.Token) (any, error) {
return secret, nil
})
// Check expiry is in future (approximately 7 days) claims := parsed.Claims.(jwt.MapClaims)
expiry := extractedClaims["exp"].(float64)
now := time.Now().Unix()
assert.Greater(t, expiry, float64(now), "Token expiry should be in future") // Check issuer
assert.InDelta(t, expiry, float64(now+7*24*60*60), 10, assert.Equal(t, "test-issuer", claims["iss"])
"Token expiry should be approximately 7 days from now")
// Check audience
aud := claims["aud"].([]any)
assert.Contains(t, aud, "api.example.com")
// Check expiration is ~1 hour
exp := int64(claims["exp"].(float64))
iat := int64(claims["iat"].(float64))
assert.InDelta(t, 3600, exp-iat, 10)
} }
func TestCorruptJWTParts(t *testing.T) { func TestJWTErrors(t *testing.T) {
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) secret := []byte("test-secret-key-must-be-32-bytes")
jwtMgr, err := NewJWT(secret)
require.NoError(t, err) require.NoError(t, err)
// Test with missing parts // Empty user ID
_, _, err = auth.ValidateToken("only.two") _, err = jwtMgr.GenerateToken("", nil)
assert.True(t, errors.Is(err, ErrTokenMalformed)) assert.Equal(t, ErrTokenEmptyUserID, err)
// Test with invalid header encoding // Invalid token format
_, _, err = auth.ValidateToken("not-base64!.valid.valid") _, _, err = jwtMgr.ValidateToken("invalid.token")
assert.Error(t, err) assert.ErrorIs(t, err, ErrTokenMalformed)
// Test with invalid claims encoding // Tampered signature
validHeader := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" token, _ := jwtMgr.GenerateToken("user1", nil)
_, _, err = auth.ValidateToken(validHeader + ".not-base64!.valid") parts := strings.Split(token, ".")
assert.Error(t, err) tampered := parts[0] + "." + parts[1] + ".invalidsignature"
_, _, err = jwtMgr.ValidateToken(tampered)
assert.ErrorIs(t, err, ErrTokenInvalidSignature)
// Test with invalid JSON in claims // Wrong algorithm
invalidJSON := base64.RawURLEncoding.EncodeToString([]byte("{invalid json")) rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
_, _, err = auth.ValidateToken(validHeader + "." + invalidJSON + ".signature") rsaMgr, _ := NewJWTRSA(rsaKey)
assert.Error(t, err) rsaToken, _ := rsaMgr.GenerateToken("user1", nil)
_, _, err = jwtMgr.ValidateToken(rsaToken)
assert.ErrorIs(t, err, ErrTokenInvalidSignature)
}
func TestJWTExpiration(t *testing.T) {
secret := []byte("test-secret-key-must-be-32-bytes")
// Create token with 1 second lifetime
jwtMgr, err := NewJWT(secret, WithTokenLifetime(1*time.Second), WithLeeway(0))
require.NoError(t, err)
token, err := jwtMgr.GenerateToken("user1", nil)
require.NoError(t, err)
// Should be valid immediately
_, _, err = jwtMgr.ValidateToken(token)
assert.NoError(t, err)
// Wait for expiration
time.Sleep(2 * time.Second)
// Should be expired
_, _, err = jwtMgr.ValidateToken(token)
assert.ErrorIs(t, err, ErrTokenExpired)
}
func TestLeeway(t *testing.T) {
secret := []byte("test-secret-key-must-be-32-bytes")
// Create manager with no leeway
jwtMgr, err := NewJWT(secret, WithLeeway(0))
require.NoError(t, err)
// Manually create a token with NotBefore in future
now := time.Now()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": "user1",
"nbf": now.Add(2 * time.Second).Unix(),
"exp": now.Add(1 * time.Hour).Unix(),
})
tokenString, err := token.SignedString(secret)
require.NoError(t, err)
// Should fail immediately (not valid yet)
_, _, err = jwtMgr.ValidateToken(tokenString)
assert.ErrorIs(t, err, ErrTokenNotYetValid)
// Create manager with leeway
jwtMgrWithLeeway, err := NewJWT(secret, WithLeeway(5*time.Second))
require.NoError(t, err)
// Should pass with leeway
_, _, err = jwtMgrWithLeeway.ValidateToken(tokenString)
assert.NoError(t, err)
}
func TestStandaloneFunctions(t *testing.T) {
secret := []byte("test-secret-key-must-be-32-bytes")
userID := "standalone-user"
claims := map[string]any{"test": "value"}
// Generate token
token, err := GenerateHS256Token(secret, userID, claims, 1*time.Hour)
require.NoError(t, err)
// Validate token
extractedUserID, extractedClaims, err := ValidateHS256Token(secret, token)
require.NoError(t, err)
assert.Equal(t, userID, extractedUserID)
assert.Equal(t, "value", extractedClaims["test"])
// Test with short secret
_, err = GenerateHS256Token([]byte("short"), userID, claims, 1*time.Hour)
assert.Equal(t, ErrSecretTooShort, err)
}
func TestJWTRSAFromPEM(t *testing.T) {
// 1. Generate a new RSA key pair for this test
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
// 2. Encode the private key to PEM format
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
// 3. Encode the public key to PEM format
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
require.NoError(t, err)
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
})
// 4. Test the PEM constructor for the signer
jwtMgr, err := NewJWTRSAFromPEM(privateKeyPEM)
require.NoError(t, err)
token, err := jwtMgr.GenerateToken("user-from-pem", nil)
require.NoError(t, err)
assert.NotEmpty(t, token)
// 5. Test the PEM constructor for the verifier
verifier, err := NewJWTVerifierFromPEM(publicKeyPEM)
require.NoError(t, err)
userID, _, err := verifier.ValidateToken(token)
require.NoError(t, err)
assert.Equal(t, "user-from-pem", userID)
// 6. Test failure cases with invalid data
_, err = NewJWTRSAFromPEM([]byte("invalid pem data"))
assert.ErrorIs(t, err, ErrRSAInvalidPEM)
_, err = NewJWTVerifierFromPEM([]byte("invalid pem data"))
assert.ErrorIs(t, err, ErrRSAInvalidPEM)
} }

View File

@ -143,6 +143,10 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
return nil, ErrSCRAMSaltTooShort return nil, ErrSCRAMSaltTooShort
} }
if time == 0 || memory == 0 || threads == 0 {
return nil, ErrSCRAMZeroParams
}
// Derive salted password using Argon2id // Derive salted password using Argon2id
saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, DefaultArgonKeyLen) saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, DefaultArgonKeyLen)

146
scram_test.go Normal file
View File

@ -0,0 +1,146 @@
// FILE: auth/scram_test.go
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupScramTest is a helper to initialize a server and user credential for testing.
// It performs the full Argon2 -> SCRAM migration workflow.
func setupScramTest(t *testing.T) (server *ScramServer, username, password string, cred *Credential) {
username = "testuser"
password = "SecurePassword123"
// 1. Start with an Argon2 PHC hash, as a real application would.
phcHash, err := HashPassword(password)
require.NoError(t, err, "Setup failed: could not hash password")
// 2. Migrate the PHC hash to a SCRAM credential.
cred, err = MigrateFromPHC(username, password, phcHash)
require.NoError(t, err, "Setup failed: could not migrate from PHC hash")
// 3. Create a server and add the new credential.
server = NewScramServer()
server.AddCredential(cred)
return server, username, password, cred
}
// TestScram_FullRoundtrip_Success simulates a complete, successful authentication handshake.
func TestScram_FullRoundtrip_Success(t *testing.T) {
server, username, password, _ := setupScramTest(t)
client := NewScramClient(username, password)
// --- Step 1: Client sends its first message ---
clientFirst, err := client.StartAuthentication()
require.NoError(t, err)
// --- Step 2: Server receives client's message and responds ---
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
require.NoError(t, err, "Server failed to process client's first message")
// --- Step 3: Client receives server's message, computes proof ---
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
require.NoError(t, err, "Client failed to process server's first message")
// --- Step 4: Server receives client's proof and verifies it ---
serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
require.NoError(t, err, "Server failed to verify client's final proof")
assert.NotEmpty(t, serverFinal.ServerSignature, "Server signature should not be empty")
// --- Step 5: Client verifies server's signature (mutual authentication) ---
err = client.VerifyServerFinalMessage(serverFinal)
assert.NoError(t, err, "Client failed to verify server's final signature")
t.Log("SCRAM full roundtrip successful")
}
// TestScram_FullRoundtrip_WrongPassword ensures authentication fails with an incorrect password.
func TestScram_FullRoundtrip_WrongPassword(t *testing.T) {
server, username, _, _ := setupScramTest(t)
// Create a client with the WRONG password
client := NewScramClient(username, "WrongPassword!!!")
// Steps 1-3 will appear to succeed, as the client doesn't know the password is wrong yet.
clientFirst, err := client.StartAuthentication()
require.NoError(t, err)
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
require.NoError(t, err)
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
require.NoError(t, err)
// --- Step 4: Server verification should fail here ---
_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should reject proof from wrong password")
t.Log("SCRAM correctly failed for wrong password")
}
// TestScram_FullRoundtrip_UserNotFound tests for user enumeration protection.
// The server should not reveal whether a user exists or not in its first message.
func TestScram_FullRoundtrip_UserNotFound(t *testing.T) {
server, _, _, _ := setupScramTest(t)
client := NewScramClient("unknown_user", "any_password")
clientFirst, err := client.StartAuthentication()
require.NoError(t, err)
// --- Step 2: Server should return an error, but also a FAKE response ---
// This prevents an attacker from knowing if the user exists based on the response structure.
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should return an error for an unknown user")
assert.NotEmpty(t, serverFirst.FullNonce, "Server must still provide a nonce to prevent enumeration")
assert.NotEmpty(t, serverFirst.Salt, "Server must still provide a salt to prevent enumeration")
t.Log("SCRAM correctly protected against user enumeration")
}
// TestScram_InvalidNonce simulates a replay attack or message mismatch.
func TestScram_InvalidNonce(t *testing.T) {
server, username, password, _ := setupScramTest(t)
client := NewScramClient(username, password)
// Perform the first part of the handshake
clientFirst, _ := client.StartAuthentication()
serverFirst, _ := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
clientFinal, _ := client.ProcessServerFirstMessage(serverFirst)
// Attempt to finalize with a completely different nonce
_, err := server.ProcessClientFinalMessage("this-is-a-bad-nonce", clientFinal.ClientProof)
assert.ErrorIs(t, err, ErrSCRAMInvalidNonce, "Server should reject a final message with an unknown nonce")
}
// TestScram_CredentialImportExport verifies that credentials can be serialized and deserialized correctly.
func TestScram_CredentialImportExport(t *testing.T) {
_, _, _, originalCred := setupScramTest(t)
// Export the credential to a map
exportedData := originalCred.Export()
require.NotNil(t, exportedData)
// Assert that required fields exist and are strings (as they are base64 encoded)
assert.IsType(t, "", exportedData["salt"])
assert.IsType(t, "", exportedData["stored_key"])
assert.IsType(t, "", exportedData["server_key"])
// Import the credential back from the map
importedCred, err := ImportCredential(exportedData)
require.NoError(t, err)
require.NotNil(t, importedCred)
// Verify that the imported credential is identical to the original
assert.Equal(t, originalCred.Username, importedCred.Username)
assert.Equal(t, originalCred.Salt, importedCred.Salt)
assert.Equal(t, originalCred.ArgonTime, importedCred.ArgonTime)
assert.Equal(t, originalCred.ArgonMemory, importedCred.ArgonMemory)
assert.Equal(t, originalCred.ArgonThreads, importedCred.ArgonThreads)
assert.Equal(t, originalCred.StoredKey, importedCred.StoredKey)
assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey)
t.Log("SCRAM credential import/export successful")
}