v0.2.0 restructured and generalized to be more modular, added golang-jwt dependency
This commit is contained in:
41
README.md
41
README.md
@ -1,38 +1,45 @@
|
||||
# Auth Package
|
||||
|
||||
Pluggable authentication utilities for Go applications.
|
||||
Modular authentication utilities for Go applications.
|
||||
|
||||
## Features
|
||||
|
||||
- **Password Hashing**: Argon2id with PHC format
|
||||
- **JWT**: HS256/RS256 token generation and validation
|
||||
- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF
|
||||
- **HTTP Auth**: Basic/Bearer header parsing
|
||||
- **Password Hashing**: Standalone Argon2id hashing with PHC format.
|
||||
- **JWT**: HS256/RS256 token management via a simple facade over `golang-jwt`.
|
||||
- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF.
|
||||
- **HTTP Auth**: Helpers for parsing Basic and Bearer authentication headers.
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
// Argon2 Password Hashing
|
||||
hash, _ := auth.HashPassword("password123")
|
||||
err := auth.VerifyPassword("password123", hash)
|
||||
|
||||
// JWT with HS256
|
||||
auth, _ := auth.NewAuthenticator([]byte("32-byte-secret-key..."))
|
||||
token, _ := auth.GenerateToken("user123", map[string]interface{}{"role": "admin"})
|
||||
userID, claims, _ := auth.ValidateToken(token)
|
||||
jwtMgr, _ := auth.NewJWT([]byte("a-very-secure-32-byte-secret-key"))
|
||||
token, _ := jwtMgr.GenerateToken("user123", map[string]any{"role": "admin"})
|
||||
userID, claims, _ := jwtMgr.ValidateToken(token)
|
||||
|
||||
// SCRAM authentication
|
||||
server := auth.NewScramServer()
|
||||
cred, _ := auth.DeriveCredential("user", "password", salt, 1, 65536, 4)
|
||||
phcHash, _ := auth.HashPassword("password123")
|
||||
cred, _ := auth.MigrateFromPHC("user", "password123", phcHash)
|
||||
server.AddCredential(cred)
|
||||
```
|
||||
|
||||
## Package Structure
|
||||
|
||||
- `interfaces.go` - Core interfaces
|
||||
- `jwt.go` - JWT token operations
|
||||
- `argon2.go` - Password hashing
|
||||
- `scram.go` - SCRAM-SHA256 protocol
|
||||
- `token.go` - Token validation utilities
|
||||
- `http.go` - HTTP header parsing
|
||||
- `errors.go` - Error definitions
|
||||
- `doc.go` - Overview and package documentation
|
||||
- `argon2.go` - Standalone Argon2id password hashing
|
||||
- `jwt.go` - JWT manager (HS256/RS256) wrapping `golang-jwt`
|
||||
- `scram.go` - SCRAM-SHA256 client/server protocol
|
||||
- `http.go` - HTTP Basic/Bearer header parsing
|
||||
- `token.go` - Simple in-memory token validator
|
||||
- `error.go` - Centralized error definitions
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
go test -v ./auth
|
||||
go test -v ./
|
||||
```
|
||||
93
argon2.go
93
argon2.go
@ -13,40 +13,85 @@ import (
|
||||
|
||||
// Default Argon2id parameters
|
||||
const (
|
||||
DefaultArgonTime = 3 // iterations (reduce for faster but less secure auth)
|
||||
DefaultArgonTime = 3 // iterations
|
||||
DefaultArgonMemory = 64 * 1024 // 64 MB
|
||||
DefaultArgonThreads = 4
|
||||
DefaultArgonSaltLen = 16
|
||||
DefaultArgonKeyLen = 32
|
||||
)
|
||||
|
||||
// HashPassword creates an Argon2id PHC-format hash
|
||||
func (a *Authenticator) HashPassword(password string) (string, error) {
|
||||
// argonParams holds configurable Argon2id parameters
|
||||
type argonParams struct {
|
||||
time uint32
|
||||
memory uint32
|
||||
threads uint8
|
||||
keyLen uint32
|
||||
saltLen uint32
|
||||
}
|
||||
|
||||
// Option configures Argon2id hashing parameters
|
||||
type Option func(*argonParams)
|
||||
|
||||
// WithTime sets Argon2 iterations
|
||||
func WithTime(t uint32) Option {
|
||||
return func(p *argonParams) {
|
||||
if t > 0 {
|
||||
p.time = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithMemory sets Argon2 memory in KiB
|
||||
func WithMemory(m uint32) Option {
|
||||
return func(p *argonParams) {
|
||||
if m > 0 {
|
||||
p.memory = m
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithThreads sets Argon2 parallelism
|
||||
func WithThreads(t uint8) Option {
|
||||
return func(p *argonParams) {
|
||||
if t > 0 {
|
||||
p.threads = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HashPassword creates Argon2id PHC-format hash (standalone)
|
||||
func HashPassword(password string, opts ...Option) (string, error) {
|
||||
if len(password) < 8 {
|
||||
return "", ErrWeakPassword
|
||||
}
|
||||
|
||||
// Generate salt
|
||||
salt := make([]byte, DefaultArgonSaltLen)
|
||||
params := &argonParams{
|
||||
time: DefaultArgonTime,
|
||||
memory: DefaultArgonMemory,
|
||||
threads: DefaultArgonThreads,
|
||||
keyLen: DefaultArgonKeyLen,
|
||||
saltLen: DefaultArgonSaltLen,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(params)
|
||||
}
|
||||
|
||||
salt := make([]byte, params.saltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", fmt.Errorf("%w: %v", ErrSaltGenerationFailed, err)
|
||||
}
|
||||
|
||||
// Derive key using Argon2id
|
||||
hash := argon2.IDKey([]byte(password), salt, a.argonTime, a.argonMemory, a.argonThreads, DefaultArgonKeyLen)
|
||||
hash := argon2.IDKey([]byte(password), salt, params.time, params.memory, params.threads, params.keyLen)
|
||||
|
||||
// Construct PHC format
|
||||
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
||||
hashB64 := base64.RawStdEncoding.EncodeToString(hash)
|
||||
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, a.argonMemory, a.argonTime, a.argonThreads, saltB64, hashB64)
|
||||
|
||||
return phcHash, nil
|
||||
return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, params.memory, params.time, params.threads, saltB64, hashB64), nil
|
||||
}
|
||||
|
||||
// VerifyPassword checks password against PHC-format hash
|
||||
func (a *Authenticator) VerifyPassword(password, phcHash string) error {
|
||||
// Parse PHC format
|
||||
// VerifyPassword checks password against PHC-format hash (standalone)
|
||||
func VerifyPassword(password, phcHash string) error {
|
||||
parts := strings.Split(phcHash, "$")
|
||||
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||
return ErrPHCInvalidFormat
|
||||
@ -66,10 +111,8 @@ func (a *Authenticator) VerifyPassword(password, phcHash string) error {
|
||||
return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err)
|
||||
}
|
||||
|
||||
// Compute hash with same parameters
|
||||
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
|
||||
|
||||
// Constant-time comparison
|
||||
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
@ -77,9 +120,8 @@ func (a *Authenticator) VerifyPassword(password, phcHash string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateFromPHC converts existing Argon2 PHC hash to SCRAM credential
|
||||
// MigrateFromPHC converts PHC hash to SCRAM credential
|
||||
func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
|
||||
// Parse PHC format
|
||||
parts := strings.Split(phcHash, "$")
|
||||
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||
return nil, ErrPHCInvalidFormat
|
||||
@ -94,17 +136,10 @@ func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
|
||||
return nil, ErrPHCInvalidSalt
|
||||
}
|
||||
|
||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return nil, ErrPHCInvalidHash
|
||||
// Use standalone function for verification
|
||||
if err := VerifyPassword(password, phcHash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify password against hash
|
||||
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
|
||||
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// Derive SCRAM credential with same parameters
|
||||
return DeriveCredential(username, password, salt, time, memory, threads)
|
||||
}
|
||||
@ -11,13 +11,10 @@ import (
|
||||
)
|
||||
|
||||
func TestPasswordHashing(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err, "Failed to create authenticator")
|
||||
|
||||
password := "testPassword123"
|
||||
|
||||
// Test hashing
|
||||
hash, err := auth.HashPassword(password)
|
||||
// Test hashing with default parameters
|
||||
hash, err := HashPassword(password)
|
||||
require.NoError(t, err, "Failed to hash password")
|
||||
|
||||
// Verify PHC format
|
||||
@ -25,20 +22,30 @@ func TestPasswordHashing(t *testing.T) {
|
||||
"Hash should have argon2id prefix, got: %s", hash)
|
||||
|
||||
// Test verification with correct password
|
||||
err = auth.VerifyPassword(password, hash)
|
||||
err = VerifyPassword(password, hash)
|
||||
assert.NoError(t, err, "Failed to verify correct password")
|
||||
|
||||
// Test verification with incorrect password
|
||||
err = auth.VerifyPassword("wrongPassword", hash)
|
||||
err = VerifyPassword("wrongPassword", hash)
|
||||
assert.Error(t, err, "Verification should fail for incorrect password")
|
||||
assert.Equal(t, ErrInvalidCredentials, err)
|
||||
|
||||
// Test weak password
|
||||
_, err = auth.HashPassword("weak")
|
||||
_, err = HashPassword("weak")
|
||||
assert.Equal(t, ErrWeakPassword, err, "Should reject weak password")
|
||||
|
||||
// Test with custom options
|
||||
hash, err = HashPassword(password,
|
||||
WithTime(5),
|
||||
WithMemory(128*1024),
|
||||
WithThreads(8))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = VerifyPassword(password, hash)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test malformed PHC hash
|
||||
err = auth.VerifyPassword(password, "$invalid$format")
|
||||
err = VerifyPassword(password, "$invalid$format")
|
||||
assert.Error(t, err, "Should reject malformed hash")
|
||||
|
||||
// Test corrupted salt
|
||||
@ -47,33 +54,27 @@ func TestPasswordHashing(t *testing.T) {
|
||||
if len(parts) == 6 {
|
||||
parts[4] = "invalid!base64"
|
||||
corruptedHash = strings.Join(parts, "$")
|
||||
err = auth.VerifyPassword(password, corruptedHash)
|
||||
err = VerifyPassword(password, corruptedHash)
|
||||
assert.Error(t, err, "Should reject corrupted salt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyPasswordAfterValidation(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Empty password should be rejected by length check
|
||||
_, err = auth.HashPassword("")
|
||||
_, err := HashPassword("")
|
||||
assert.Equal(t, ErrWeakPassword, err)
|
||||
|
||||
// 8-character password should pass
|
||||
hash, err := auth.HashPassword("12345678")
|
||||
hash, err := HashPassword("12345678")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = auth.VerifyPassword("12345678", hash)
|
||||
err = VerifyPassword("12345678", hash)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestConcurrentPasswordOperations(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
password := "testPassword123"
|
||||
hash, err := auth.HashPassword(password)
|
||||
hash, err := HashPassword(password)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test concurrent verification
|
||||
@ -82,7 +83,7 @@ func TestConcurrentPasswordOperations(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := auth.VerifyPassword(password, hash)
|
||||
err := VerifyPassword(password, hash)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
}
|
||||
@ -90,14 +91,11 @@ func TestConcurrentPasswordOperations(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPHCMigration(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
password := "testPassword123"
|
||||
username := "migrationUser"
|
||||
|
||||
// Generate PHC hash
|
||||
phcHash, err := auth.HashPassword(password)
|
||||
phcHash, err := HashPassword(password)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Migrate to SCRAM credential
|
||||
|
||||
71
auth.go
71
auth.go
@ -1,71 +0,0 @@
|
||||
// FILE: auth/auth.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Authenticator provides password hashing and JWT operations
|
||||
type Authenticator struct {
|
||||
algorithm string
|
||||
jwtSecret []byte // For HS256
|
||||
privateKey *rsa.PrivateKey // For RS256
|
||||
publicKey *rsa.PublicKey // For RS256
|
||||
argonTime uint32
|
||||
argonMemory uint32
|
||||
argonThreads uint8
|
||||
}
|
||||
|
||||
// NewAuthenticator creates a new authenticator with specified algorithm
|
||||
func NewAuthenticator(key any, algorithm ...string) (*Authenticator, error) {
|
||||
alg := "HS256"
|
||||
if len(algorithm) > 0 && algorithm[0] != "" {
|
||||
alg = algorithm[0]
|
||||
}
|
||||
|
||||
auth := &Authenticator{
|
||||
algorithm: alg,
|
||||
argonTime: DefaultArgonTime,
|
||||
argonMemory: DefaultArgonMemory,
|
||||
argonThreads: DefaultArgonThreads,
|
||||
}
|
||||
|
||||
switch alg {
|
||||
case "HS256":
|
||||
secret, ok := key.([]byte)
|
||||
if !ok {
|
||||
return nil, ErrInvalidKeyType
|
||||
}
|
||||
if len(secret) < 32 {
|
||||
return nil, ErrSecretTooShort
|
||||
}
|
||||
auth.jwtSecret = secret
|
||||
|
||||
case "RS256":
|
||||
switch k := key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
auth.privateKey = k
|
||||
auth.publicKey = &k.PublicKey
|
||||
case *rsa.PublicKey:
|
||||
auth.publicKey = k
|
||||
case []byte:
|
||||
// Try parsing as PEM
|
||||
if privKey, err := parseRSAPrivateKey(k); err == nil {
|
||||
auth.privateKey = privKey
|
||||
auth.publicKey = &privKey.PublicKey
|
||||
} else if pubKey, err := parseRSAPublicKey(k); err == nil {
|
||||
auth.publicKey = pubKey
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to parse RSA key: %w", err)
|
||||
}
|
||||
default:
|
||||
return nil, ErrInvalidKeyType
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, ErrInvalidAlgorithm
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
68
auth_test.go
68
auth_test.go
@ -1,68 +0,0 @@
|
||||
// FILE: auth/auth_test.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewAuthenticator(t *testing.T) {
|
||||
// Test HS256 creation
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err, "Failed to create HS256 authenticator")
|
||||
assert.Equal(t, "HS256", auth.algorithm)
|
||||
|
||||
// Test with short secret
|
||||
_, err = NewAuthenticator([]byte("short"))
|
||||
assert.Equal(t, ErrSecretTooShort, err)
|
||||
|
||||
// Test RS256 with private key
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
authRS, err := NewAuthenticator(privateKey, "RS256")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RS256", authRS.algorithm)
|
||||
assert.NotNil(t, authRS.privateKey)
|
||||
assert.NotNil(t, authRS.publicKey)
|
||||
|
||||
// Test RS256 with public key only
|
||||
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RS256", authPub.algorithm)
|
||||
assert.Nil(t, authPub.privateKey)
|
||||
assert.NotNil(t, authPub.publicKey)
|
||||
|
||||
// Test invalid algorithm
|
||||
_, err = NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"), "INVALID")
|
||||
assert.Equal(t, ErrInvalidAlgorithm, err)
|
||||
|
||||
// Test invalid key type for HS256
|
||||
_, err = NewAuthenticator(privateKey, "HS256")
|
||||
assert.Equal(t, ErrInvalidKeyType, err)
|
||||
}
|
||||
|
||||
func TestInterfaceCompliance(t *testing.T) {
|
||||
// Verify Authenticator implements AuthenticatorInterface
|
||||
auth, _ := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
|
||||
var _ AuthenticatorInterface = auth
|
||||
|
||||
// Test interface methods work
|
||||
hash, err := auth.HashPassword("testpass123")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = auth.VerifyPassword("testpass123", hash)
|
||||
assert.NoError(t, err)
|
||||
|
||||
token, err := auth.GenerateToken("user1", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID, _, err := auth.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "user1", userID)
|
||||
}
|
||||
53
doc.go
Normal file
53
doc.go
Normal file
@ -0,0 +1,53 @@
|
||||
// FILE: auth/doc.go
|
||||
package auth
|
||||
|
||||
/*
|
||||
Package auth provides modular authentication components:
|
||||
|
||||
# Argon2 Password Hashing
|
||||
|
||||
Standalone password hashing using Argon2id:
|
||||
|
||||
hash, err := auth.HashPassword("password123")
|
||||
err = auth.VerifyPassword("password123", hash)
|
||||
|
||||
// With custom parameters
|
||||
hash, err := auth.HashPassword("password123",
|
||||
auth.WithTime(5),
|
||||
auth.WithMemory(128*1024))
|
||||
|
||||
# JWT Token Management
|
||||
|
||||
JSON Web Token generation and validation:
|
||||
|
||||
// HS256 (symmetric)
|
||||
jwtMgr, _ := auth.NewJWT(secret)
|
||||
token, _ := jwtMgr.GenerateToken("user1", claims)
|
||||
userID, claims, _ := jwtMgr.ValidateToken(token)
|
||||
|
||||
// RS256 (asymmetric)
|
||||
jwtMgr, _ := auth.NewJWTRSA(privateKey)
|
||||
|
||||
// One-off operations
|
||||
token, _ := auth.GenerateHS256Token(secret, "user1", claims, 1*time.Hour)
|
||||
|
||||
# SCRAM-SHA256 Authentication
|
||||
|
||||
Server and client implementation for SCRAM:
|
||||
|
||||
// Server
|
||||
server := auth.NewScramServer()
|
||||
server.AddCredential(credential)
|
||||
|
||||
// Client
|
||||
client := auth.NewScramClient(username, password)
|
||||
|
||||
# HTTP Authentication Parsing
|
||||
|
||||
Utility functions for HTTP headers:
|
||||
|
||||
username, password, _ := auth.ParseBasicAuth(header)
|
||||
token, _ := auth.ParseBearerToken(header)
|
||||
|
||||
Each module can be used independently without initializing other components.
|
||||
*/
|
||||
5
error.go
5
error.go
@ -10,8 +10,6 @@ import (
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
ErrWeakPassword = errors.New("password must be at least 8 characters")
|
||||
ErrInvalidAlgorithm = errors.New("invalid algorithm")
|
||||
ErrInvalidKeyType = errors.New("invalid key type for algorithm")
|
||||
)
|
||||
|
||||
// JWT-specific errors
|
||||
@ -22,9 +20,6 @@ var (
|
||||
ErrTokenInvalidSignature = errors.New("token: invalid signature")
|
||||
ErrTokenAlgorithmMismatch = errors.New("token: algorithm mismatch")
|
||||
ErrTokenMissingClaim = errors.New("token: missing required claim")
|
||||
ErrTokenInvalidHeader = errors.New("token: invalid header encoding")
|
||||
ErrTokenInvalidClaims = errors.New("token: invalid claims encoding")
|
||||
ErrTokenInvalidJSON = errors.New("token: malformed JSON")
|
||||
ErrTokenEmptyUserID = errors.New("token: empty user ID")
|
||||
ErrTokenNoPrivateKey = errors.New("token: private key required for signing")
|
||||
ErrTokenNoPublicKey = errors.New("token: public key required for verification")
|
||||
|
||||
3
go.mod
3
go.mod
@ -1,8 +1,9 @@
|
||||
module auth
|
||||
module github.com/lixenwraith/auth
|
||||
|
||||
go 1.25.3
|
||||
|
||||
require (
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/crypto v0.43.0
|
||||
)
|
||||
|
||||
2
go.sum
2
go.sum
@ -1,5 +1,7 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
|
||||
17
interface.go
17
interface.go
@ -1,17 +0,0 @@
|
||||
// FILE: auth/interface.go
|
||||
package auth
|
||||
|
||||
// AuthenticatorInterface defines the authentication operations
|
||||
type AuthenticatorInterface interface {
|
||||
HashPassword(password string) (hash string, err error)
|
||||
VerifyPassword(password, hash string) (err error)
|
||||
GenerateToken(userID string, claims map[string]any) (token string, err error)
|
||||
ValidateToken(token string) (userID string, claims map[string]any, err error)
|
||||
}
|
||||
|
||||
// TokenValidator validates bearer tokens
|
||||
type TokenValidator interface {
|
||||
ValidateToken(token string) (valid bool)
|
||||
AddToken(token string)
|
||||
RemoveToken(token string)
|
||||
}
|
||||
408
jwt.go
408
jwt.go
@ -2,179 +2,289 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// GenerateToken creates a JWT token with user claims
|
||||
func (a *Authenticator) GenerateToken(userID string, claims map[string]any) (string, error) {
|
||||
// JWT configuration defaults
|
||||
const (
|
||||
DefaultTokenLifetime = 24 * time.Hour
|
||||
DefaultLeeway = 5 * time.Minute
|
||||
)
|
||||
|
||||
// customClaims extends RegisteredClaims with arbitrary user data
|
||||
type customClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
Extra map[string]any `json:"extra,omitempty"`
|
||||
}
|
||||
|
||||
// JWT manages token generation and validation
|
||||
type JWT struct {
|
||||
algorithm jwt.SigningMethod
|
||||
signKey any // []byte for HMAC, *rsa.PrivateKey for RSA
|
||||
verifyKey any // []byte for HMAC, *rsa.PublicKey for RSA
|
||||
tokenLifetime time.Duration
|
||||
leeway time.Duration
|
||||
issuer string
|
||||
audience []string
|
||||
}
|
||||
|
||||
// JWTOption configures JWT behavior
|
||||
type JWTOption func(*JWT)
|
||||
|
||||
// WithTokenLifetime sets token expiration duration
|
||||
func WithTokenLifetime(d time.Duration) JWTOption {
|
||||
return func(j *JWT) {
|
||||
if d > 0 {
|
||||
j.tokenLifetime = d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithLeeway sets clock skew tolerance
|
||||
func WithLeeway(d time.Duration) JWTOption {
|
||||
return func(j *JWT) {
|
||||
if d >= 0 {
|
||||
j.leeway = d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithIssuer sets token issuer claim
|
||||
func WithIssuer(iss string) JWTOption {
|
||||
return func(j *JWT) {
|
||||
j.issuer = iss
|
||||
}
|
||||
}
|
||||
|
||||
// WithAudience sets token audience claim
|
||||
func WithAudience(aud []string) JWTOption {
|
||||
return func(j *JWT) {
|
||||
j.audience = aud
|
||||
}
|
||||
}
|
||||
|
||||
// NewJWT creates JWT manager for HS256 (symmetric)
|
||||
func NewJWT(secret []byte, opts ...JWTOption) (*JWT, error) {
|
||||
if len(secret) < 32 {
|
||||
return nil, ErrSecretTooShort
|
||||
}
|
||||
|
||||
j := &JWT{
|
||||
algorithm: jwt.SigningMethodHS256,
|
||||
signKey: secret,
|
||||
verifyKey: secret,
|
||||
tokenLifetime: DefaultTokenLifetime,
|
||||
leeway: DefaultLeeway,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(j)
|
||||
}
|
||||
|
||||
return j, nil
|
||||
}
|
||||
|
||||
// NewJWTRSA creates JWT manager for RS256 (asymmetric)
|
||||
func NewJWTRSA(privateKey *rsa.PrivateKey, opts ...JWTOption) (*JWT, error) {
|
||||
if privateKey == nil {
|
||||
return nil, ErrTokenNoPrivateKey
|
||||
}
|
||||
|
||||
j := &JWT{
|
||||
algorithm: jwt.SigningMethodRS256,
|
||||
signKey: privateKey,
|
||||
verifyKey: &privateKey.PublicKey,
|
||||
tokenLifetime: DefaultTokenLifetime,
|
||||
leeway: DefaultLeeway,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(j)
|
||||
}
|
||||
|
||||
return j, nil
|
||||
}
|
||||
|
||||
// NewJWTVerifier creates JWT manager for verification only (RS256)
|
||||
func NewJWTVerifier(publicKey *rsa.PublicKey, opts ...JWTOption) (*JWT, error) {
|
||||
if publicKey == nil {
|
||||
return nil, ErrTokenNoPublicKey
|
||||
}
|
||||
|
||||
j := &JWT{
|
||||
algorithm: jwt.SigningMethodRS256,
|
||||
signKey: nil, // Cannot sign
|
||||
verifyKey: publicKey,
|
||||
tokenLifetime: DefaultTokenLifetime,
|
||||
leeway: DefaultLeeway,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(j)
|
||||
}
|
||||
|
||||
return j, nil
|
||||
}
|
||||
|
||||
// GenerateToken creates signed JWT with claims
|
||||
func (j *JWT) GenerateToken(userID string, claims map[string]any) (string, error) {
|
||||
if userID == "" {
|
||||
return "", ErrTokenEmptyUserID
|
||||
}
|
||||
|
||||
if a.algorithm == "RS256" && a.privateKey == nil {
|
||||
if j.signKey == nil {
|
||||
return "", ErrTokenNoPrivateKey
|
||||
}
|
||||
|
||||
// Build JWT claims
|
||||
now := time.Now()
|
||||
jwtClaims := map[string]any{
|
||||
"sub": userID,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(7 * 24 * time.Hour).Unix(), // 7 days expiry
|
||||
registeredClaims := jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
Issuer: j.issuer,
|
||||
Audience: j.audience,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(j.tokenLifetime)),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
}
|
||||
|
||||
// Reserved claims that cannot be overridden
|
||||
reservedClaims := map[string]bool{
|
||||
"sub": true, "iat": true, "exp": true, "nbf": true,
|
||||
"iss": true, "aud": true, "jti": true, "typ": true,
|
||||
"alg": true,
|
||||
}
|
||||
token := jwt.NewWithClaims(j.algorithm, customClaims{
|
||||
RegisteredClaims: registeredClaims,
|
||||
Extra: claims,
|
||||
})
|
||||
|
||||
// Merge custom claims
|
||||
for k, v := range claims {
|
||||
if !reservedClaims[k] {
|
||||
jwtClaims[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Create JWT header
|
||||
header := map[string]any{
|
||||
"alg": a.algorithm,
|
||||
"typ": "JWT",
|
||||
}
|
||||
|
||||
// Encode header and payload
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
claimsJSON, _ := json.Marshal(jwtClaims)
|
||||
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
// Create signature
|
||||
signingInput := headerB64 + "." + claimsB64
|
||||
var signature string
|
||||
|
||||
switch a.algorithm {
|
||||
case "HS256":
|
||||
h := hmac.New(sha256.New, a.jwtSecret)
|
||||
h.Write([]byte(signingInput))
|
||||
signature = base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
case "RS256":
|
||||
hashed := sha256.Sum256([]byte(signingInput))
|
||||
sig, err := rsa.SignPKCS1v15(rand.Reader, a.privateKey, crypto.SHA256, hashed[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
signature = base64.RawURLEncoding.EncodeToString(sig)
|
||||
}
|
||||
|
||||
// Combine to form JWT
|
||||
token := signingInput + "." + signature
|
||||
|
||||
return token, nil
|
||||
return token.SignedString(j.signKey)
|
||||
}
|
||||
|
||||
// ValidateToken verifies JWT and returns userID and claims
|
||||
func (a *Authenticator) ValidateToken(token string) (string, map[string]any, error) {
|
||||
// Split token
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
// ValidateToken verifies JWT and extracts claims
|
||||
func (j *JWT) ValidateToken(tokenString string) (string, map[string]any, error) {
|
||||
parser := jwt.NewParser(
|
||||
jwt.WithLeeway(j.leeway),
|
||||
jwt.WithAudience(j.audience...),
|
||||
jwt.WithIssuer(j.issuer),
|
||||
jwt.WithValidMethods([]string{j.algorithm.Alg()}),
|
||||
jwt.WithExpirationRequired(),
|
||||
)
|
||||
|
||||
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
|
||||
// Algorithm already validated by WithValidMethods
|
||||
return j.verifyKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", nil, mapJWTError(err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*customClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", nil, ErrTokenMalformed
|
||||
}
|
||||
|
||||
// Decode header to check algorithm
|
||||
headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return "", nil, ErrTokenInvalidHeader
|
||||
}
|
||||
|
||||
var header map[string]any
|
||||
if err = json.Unmarshal(headerJSON, &header); err != nil {
|
||||
return "", nil, ErrTokenInvalidJSON
|
||||
}
|
||||
|
||||
// Verify algorithm matches
|
||||
if alg, ok := header["alg"].(string); !ok || alg != a.algorithm {
|
||||
return "", nil, ErrTokenAlgorithmMismatch
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
signingInput := parts[0] + "." + parts[1]
|
||||
|
||||
switch a.algorithm {
|
||||
case "HS256":
|
||||
h := hmac.New(sha256.New, a.jwtSecret)
|
||||
h.Write([]byte(signingInput))
|
||||
expectedSig := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
if subtle.ConstantTimeCompare([]byte(parts[2]), []byte(expectedSig)) != 1 {
|
||||
return "", nil, ErrTokenInvalidSignature
|
||||
}
|
||||
|
||||
case "RS256":
|
||||
if a.publicKey == nil {
|
||||
return "", nil, ErrTokenNoPublicKey
|
||||
}
|
||||
|
||||
sig, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return "", nil, ErrTokenInvalidSignature
|
||||
}
|
||||
|
||||
hashed := sha256.Sum256([]byte(signingInput))
|
||||
if err := rsa.VerifyPKCS1v15(a.publicKey, crypto.SHA256, hashed[:], sig); err != nil {
|
||||
return "", nil, ErrTokenInvalidSignature
|
||||
}
|
||||
}
|
||||
|
||||
// Decode claims
|
||||
claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", nil, ErrTokenInvalidClaims
|
||||
}
|
||||
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
|
||||
return "", nil, ErrTokenInvalidJSON
|
||||
}
|
||||
|
||||
// Check expiration
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Now().Unix() > int64(exp) {
|
||||
return "", nil, ErrTokenExpired
|
||||
}
|
||||
}
|
||||
|
||||
// Check not before
|
||||
if nbf, ok := claims["nbf"].(float64); ok {
|
||||
if time.Now().Unix() < int64(nbf) {
|
||||
return "", nil, ErrTokenNotYetValid
|
||||
}
|
||||
}
|
||||
|
||||
// Extract userID
|
||||
userID, ok := claims["sub"].(string)
|
||||
if !ok {
|
||||
return "", nil, ErrTokenMissingClaim
|
||||
}
|
||||
|
||||
return userID, claims, nil
|
||||
return claims.Subject, claims.Extra, nil
|
||||
}
|
||||
|
||||
// parseRSAPrivateKey parses PEM encoded RSA private key
|
||||
// mapJWTError translates jwt library errors to auth package errors
|
||||
func mapJWTError(err error) error {
|
||||
switch {
|
||||
case errors.Is(err, jwt.ErrTokenMalformed):
|
||||
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
||||
case errors.Is(err, jwt.ErrTokenUnverifiable):
|
||||
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
||||
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
|
||||
return fmt.Errorf("%w : %w", ErrTokenInvalidSignature, err)
|
||||
case errors.Is(err, jwt.ErrTokenExpired):
|
||||
return fmt.Errorf("%w : %w", ErrTokenExpired, err)
|
||||
case errors.Is(err, jwt.ErrTokenNotValidYet):
|
||||
return fmt.Errorf("%w : %w", ErrTokenNotYetValid, err)
|
||||
case errors.Is(err, jwt.ErrTokenInvalidAudience):
|
||||
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
|
||||
case errors.Is(err, jwt.ErrTokenInvalidIssuer):
|
||||
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
|
||||
default:
|
||||
// Check for algorithm mismatch in error message
|
||||
if errors.Is(err, jwt.ErrTokenSignatureInvalid) {
|
||||
return fmt.Errorf("%w : %w", ErrTokenAlgorithmMismatch, err)
|
||||
}
|
||||
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Standalone helper functions for one-off operations
|
||||
|
||||
// GenerateHS256Token creates HS256 JWT without manager instance
|
||||
func GenerateHS256Token(secret []byte, userID string, claims map[string]any, lifetime time.Duration) (string, error) {
|
||||
if len(secret) < 32 {
|
||||
return "", ErrSecretTooShort
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, customClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(lifetime)),
|
||||
},
|
||||
Extra: claims,
|
||||
})
|
||||
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// ValidateHS256Token verifies HS256 JWT without manager instance
|
||||
func ValidateHS256Token(secret []byte, tokenString string) (string, map[string]any, error) {
|
||||
if len(secret) < 32 {
|
||||
return "", nil, ErrSecretTooShort
|
||||
}
|
||||
|
||||
parser := jwt.NewParser(
|
||||
jwt.WithValidMethods([]string{"HS256"}),
|
||||
jwt.WithLeeway(DefaultLeeway),
|
||||
)
|
||||
|
||||
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
|
||||
return secret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", nil, mapJWTError(err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*customClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", nil, ErrTokenMalformed
|
||||
}
|
||||
|
||||
return claims.Subject, claims.Extra, nil
|
||||
}
|
||||
|
||||
// RSA Utilities
|
||||
|
||||
// NewJWTRSAFromPEM creates a JWT manager for RS256 from raw PEM-encoded private key data.
|
||||
func NewJWTRSAFromPEM(privateKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
|
||||
privateKey, err := parseRSAPrivateKey(privateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Call the original constructor with the now-parsed key
|
||||
return NewJWTRSA(privateKey, opts...)
|
||||
}
|
||||
|
||||
// NewJWTVerifierFromPEM creates a JWT manager for verification from raw PEM-encoded public key data.
|
||||
func NewJWTVerifierFromPEM(publicKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
|
||||
publicKey, err := parseRSAPublicKey(publicKeyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Call the original constructor with the now-parsed key
|
||||
return NewJWTVerifier(publicKey, opts...)
|
||||
}
|
||||
|
||||
// parseRSAPrivateKey parses a PEM-encoded RSA private key.
|
||||
func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
@ -187,7 +297,7 @@ func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// parseRSAPublicKey parses PEM encoded RSA public key
|
||||
// parseRSAPublicKey parses a PEM-encoded RSA public key.
|
||||
func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) {
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
|
||||
283
jwt_test.go
283
jwt_test.go
@ -4,19 +4,20 @@ package auth
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJWTHS256(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
jwtMgr, err := NewJWT(secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "user123"
|
||||
@ -26,54 +27,17 @@ func TestJWTHS256(t *testing.T) {
|
||||
}
|
||||
|
||||
// Generate token
|
||||
token, err := auth.GenerateToken(userID, claims)
|
||||
require.NoError(t, err, "Failed to generate token")
|
||||
token, err := jwtMgr.GenerateToken(userID, claims)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Validate token
|
||||
extractedUserID, extractedClaims, err := auth.ValidateToken(token)
|
||||
require.NoError(t, err, "Failed to validate token")
|
||||
extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userID, extractedUserID)
|
||||
assert.Equal(t, "test@example.com", extractedClaims["email"])
|
||||
assert.Equal(t, "admin", extractedClaims["role"])
|
||||
|
||||
// Test invalid token
|
||||
_, _, err = auth.ValidateToken("invalid.token.here")
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrTokenInvalidJSON))
|
||||
|
||||
// Test tampered token
|
||||
parts := strings.Split(token, ".")
|
||||
require.Len(t, parts, 3, "JWT should have 3 parts")
|
||||
|
||||
tampered := parts[0] + "." + parts[1] + ".invalidsignature"
|
||||
_, _, err = auth.ValidateToken(tampered)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrTokenInvalidSignature))
|
||||
|
||||
// Test reserved claims cannot be overridden
|
||||
overrideClaims := map[string]any{
|
||||
"sub": "override",
|
||||
"iat": 12345,
|
||||
"exp": 67890,
|
||||
"nbf": 11111,
|
||||
"iss": "attacker",
|
||||
"aud": "victim",
|
||||
"jti": "fake",
|
||||
}
|
||||
|
||||
token, err = auth.GenerateToken(userID, overrideClaims)
|
||||
require.NoError(t, err)
|
||||
|
||||
extractedUserID, extractedClaims, err = auth.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userID, extractedUserID, "UserID should not be overridden")
|
||||
assert.NotEqual(t, 12345, extractedClaims["iat"], "iat should not be overridden")
|
||||
assert.NotEqual(t, 67890, extractedClaims["exp"], "exp should not be overridden")
|
||||
assert.NotContains(t, extractedClaims, "nbf", "nbf should not be added from user claims")
|
||||
assert.NotContains(t, extractedClaims, "iss", "iss should not be added from user claims")
|
||||
}
|
||||
|
||||
func TestJWTRS256(t *testing.T) {
|
||||
@ -82,95 +46,214 @@ func TestJWTRS256(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with private key (can sign and verify)
|
||||
authPriv, err := NewAuthenticator(privateKey, "RS256")
|
||||
jwtMgr, err := NewJWTRSA(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "user456"
|
||||
claims := map[string]any{
|
||||
"email": "rs256@example.com",
|
||||
"scope": "read:all",
|
||||
}
|
||||
|
||||
// Generate token
|
||||
token, err := authPriv.GenerateToken(userID, claims)
|
||||
token, err := jwtMgr.GenerateToken(userID, claims)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Validate token with private key auth (has public key too)
|
||||
extractedUserID, extractedClaims, err := authPriv.ValidateToken(token)
|
||||
// Validate with same manager
|
||||
extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userID, extractedUserID)
|
||||
assert.Equal(t, "rs256@example.com", extractedClaims["email"])
|
||||
assert.Equal(t, "read:all", extractedClaims["scope"])
|
||||
|
||||
// Test with public key only (can only verify)
|
||||
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
|
||||
// Test with verifier only (public key)
|
||||
verifier, err := NewJWTVerifier(&privateKey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be able to validate token
|
||||
extractedUserID, extractedClaims, err = authPub.ValidateToken(token)
|
||||
// Should validate token
|
||||
extractedUserID, _, err = verifier.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userID, extractedUserID)
|
||||
|
||||
// Should not be able to generate token
|
||||
_, err = authPub.GenerateToken(userID, claims)
|
||||
assert.Error(t, err, "Public key only auth should not generate tokens")
|
||||
// Should not generate token
|
||||
_, err = verifier.GenerateToken(userID, claims)
|
||||
assert.Equal(t, ErrTokenNoPrivateKey, err)
|
||||
|
||||
// Test algorithm mismatch
|
||||
authHS256, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = authHS256.ValidateToken(token)
|
||||
assert.Error(t, err, "HS256 auth should not validate RS256 token")
|
||||
// assert.True(t, errors.Is(err, ErrInvalidToken))
|
||||
assert.True(t, errors.Is(err, ErrTokenAlgorithmMismatch))
|
||||
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
func TestExpiredToken(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
func TestJWTOptions(t *testing.T) {
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
|
||||
// Test custom lifetime
|
||||
jwtMgr, err := NewJWT(secret,
|
||||
WithTokenLifetime(1*time.Hour),
|
||||
WithIssuer("test-issuer"),
|
||||
WithAudience([]string{"api.example.com"}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "user123"
|
||||
|
||||
// Generate normal token (should have 7 days expiry)
|
||||
token, err := auth.GenerateToken(userID, nil)
|
||||
token, err := jwtMgr.GenerateToken("user1", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, extractedClaims, err := auth.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
// Parse token to check claims
|
||||
parsed, _ := jwt.Parse(token, func(token *jwt.Token) (any, error) {
|
||||
return secret, nil
|
||||
})
|
||||
|
||||
// Check expiry is in future (approximately 7 days)
|
||||
expiry := extractedClaims["exp"].(float64)
|
||||
now := time.Now().Unix()
|
||||
claims := parsed.Claims.(jwt.MapClaims)
|
||||
|
||||
assert.Greater(t, expiry, float64(now), "Token expiry should be in future")
|
||||
assert.InDelta(t, expiry, float64(now+7*24*60*60), 10,
|
||||
"Token expiry should be approximately 7 days from now")
|
||||
// Check issuer
|
||||
assert.Equal(t, "test-issuer", claims["iss"])
|
||||
|
||||
// Check audience
|
||||
aud := claims["aud"].([]any)
|
||||
assert.Contains(t, aud, "api.example.com")
|
||||
|
||||
// Check expiration is ~1 hour
|
||||
exp := int64(claims["exp"].(float64))
|
||||
iat := int64(claims["iat"].(float64))
|
||||
assert.InDelta(t, 3600, exp-iat, 10)
|
||||
}
|
||||
|
||||
func TestCorruptJWTParts(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
func TestJWTErrors(t *testing.T) {
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
jwtMgr, err := NewJWT(secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with missing parts
|
||||
_, _, err = auth.ValidateToken("only.two")
|
||||
assert.True(t, errors.Is(err, ErrTokenMalformed))
|
||||
// Empty user ID
|
||||
_, err = jwtMgr.GenerateToken("", nil)
|
||||
assert.Equal(t, ErrTokenEmptyUserID, err)
|
||||
|
||||
// Test with invalid header encoding
|
||||
_, _, err = auth.ValidateToken("not-base64!.valid.valid")
|
||||
assert.Error(t, err)
|
||||
// Invalid token format
|
||||
_, _, err = jwtMgr.ValidateToken("invalid.token")
|
||||
assert.ErrorIs(t, err, ErrTokenMalformed)
|
||||
|
||||
// Test with invalid claims encoding
|
||||
validHeader := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
_, _, err = auth.ValidateToken(validHeader + ".not-base64!.valid")
|
||||
assert.Error(t, err)
|
||||
// Tampered signature
|
||||
token, _ := jwtMgr.GenerateToken("user1", nil)
|
||||
parts := strings.Split(token, ".")
|
||||
tampered := parts[0] + "." + parts[1] + ".invalidsignature"
|
||||
_, _, err = jwtMgr.ValidateToken(tampered)
|
||||
assert.ErrorIs(t, err, ErrTokenInvalidSignature)
|
||||
|
||||
// Test with invalid JSON in claims
|
||||
invalidJSON := base64.RawURLEncoding.EncodeToString([]byte("{invalid json"))
|
||||
_, _, err = auth.ValidateToken(validHeader + "." + invalidJSON + ".signature")
|
||||
assert.Error(t, err)
|
||||
// Wrong algorithm
|
||||
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
rsaMgr, _ := NewJWTRSA(rsaKey)
|
||||
rsaToken, _ := rsaMgr.GenerateToken("user1", nil)
|
||||
|
||||
_, _, err = jwtMgr.ValidateToken(rsaToken)
|
||||
assert.ErrorIs(t, err, ErrTokenInvalidSignature)
|
||||
}
|
||||
|
||||
func TestJWTExpiration(t *testing.T) {
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
|
||||
// Create token with 1 second lifetime
|
||||
jwtMgr, err := NewJWT(secret, WithTokenLifetime(1*time.Second), WithLeeway(0))
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := jwtMgr.GenerateToken("user1", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be valid immediately
|
||||
_, _, err = jwtMgr.ValidateToken(token)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Should be expired
|
||||
_, _, err = jwtMgr.ValidateToken(token)
|
||||
assert.ErrorIs(t, err, ErrTokenExpired)
|
||||
}
|
||||
|
||||
func TestLeeway(t *testing.T) {
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
|
||||
// Create manager with no leeway
|
||||
jwtMgr, err := NewJWT(secret, WithLeeway(0))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually create a token with NotBefore in future
|
||||
now := time.Now()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": "user1",
|
||||
"nbf": now.Add(2 * time.Second).Unix(),
|
||||
"exp": now.Add(1 * time.Hour).Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString(secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should fail immediately (not valid yet)
|
||||
_, _, err = jwtMgr.ValidateToken(tokenString)
|
||||
assert.ErrorIs(t, err, ErrTokenNotYetValid)
|
||||
|
||||
// Create manager with leeway
|
||||
jwtMgrWithLeeway, err := NewJWT(secret, WithLeeway(5*time.Second))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should pass with leeway
|
||||
_, _, err = jwtMgrWithLeeway.ValidateToken(tokenString)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStandaloneFunctions(t *testing.T) {
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
userID := "standalone-user"
|
||||
claims := map[string]any{"test": "value"}
|
||||
|
||||
// Generate token
|
||||
token, err := GenerateHS256Token(secret, userID, claims, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate token
|
||||
extractedUserID, extractedClaims, err := ValidateHS256Token(secret, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userID, extractedUserID)
|
||||
assert.Equal(t, "value", extractedClaims["test"])
|
||||
|
||||
// Test with short secret
|
||||
_, err = GenerateHS256Token([]byte("short"), userID, claims, 1*time.Hour)
|
||||
assert.Equal(t, ErrSecretTooShort, err)
|
||||
}
|
||||
|
||||
func TestJWTRSAFromPEM(t *testing.T) {
|
||||
// 1. Generate a new RSA key pair for this test
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Encode the private key to PEM format
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
|
||||
// 3. Encode the public key to PEM format
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: publicKeyBytes,
|
||||
})
|
||||
|
||||
// 4. Test the PEM constructor for the signer
|
||||
jwtMgr, err := NewJWTRSAFromPEM(privateKeyPEM)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := jwtMgr.GenerateToken("user-from-pem", nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// 5. Test the PEM constructor for the verifier
|
||||
verifier, err := NewJWTVerifierFromPEM(publicKeyPEM)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID, _, err := verifier.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "user-from-pem", userID)
|
||||
|
||||
// 6. Test failure cases with invalid data
|
||||
_, err = NewJWTRSAFromPEM([]byte("invalid pem data"))
|
||||
assert.ErrorIs(t, err, ErrRSAInvalidPEM)
|
||||
|
||||
_, err = NewJWTVerifierFromPEM([]byte("invalid pem data"))
|
||||
assert.ErrorIs(t, err, ErrRSAInvalidPEM)
|
||||
}
|
||||
4
scram.go
4
scram.go
@ -143,6 +143,10 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
|
||||
return nil, ErrSCRAMSaltTooShort
|
||||
}
|
||||
|
||||
if time == 0 || memory == 0 || threads == 0 {
|
||||
return nil, ErrSCRAMZeroParams
|
||||
}
|
||||
|
||||
// Derive salted password using Argon2id
|
||||
saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, DefaultArgonKeyLen)
|
||||
|
||||
|
||||
146
scram_test.go
Normal file
146
scram_test.go
Normal file
@ -0,0 +1,146 @@
|
||||
// FILE: auth/scram_test.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// setupScramTest is a helper to initialize a server and user credential for testing.
|
||||
// It performs the full Argon2 -> SCRAM migration workflow.
|
||||
func setupScramTest(t *testing.T) (server *ScramServer, username, password string, cred *Credential) {
|
||||
username = "testuser"
|
||||
password = "SecurePassword123"
|
||||
|
||||
// 1. Start with an Argon2 PHC hash, as a real application would.
|
||||
phcHash, err := HashPassword(password)
|
||||
require.NoError(t, err, "Setup failed: could not hash password")
|
||||
|
||||
// 2. Migrate the PHC hash to a SCRAM credential.
|
||||
cred, err = MigrateFromPHC(username, password, phcHash)
|
||||
require.NoError(t, err, "Setup failed: could not migrate from PHC hash")
|
||||
|
||||
// 3. Create a server and add the new credential.
|
||||
server = NewScramServer()
|
||||
server.AddCredential(cred)
|
||||
|
||||
return server, username, password, cred
|
||||
}
|
||||
|
||||
// TestScram_FullRoundtrip_Success simulates a complete, successful authentication handshake.
|
||||
func TestScram_FullRoundtrip_Success(t *testing.T) {
|
||||
server, username, password, _ := setupScramTest(t)
|
||||
client := NewScramClient(username, password)
|
||||
|
||||
// --- Step 1: Client sends its first message ---
|
||||
clientFirst, err := client.StartAuthentication()
|
||||
require.NoError(t, err)
|
||||
|
||||
// --- Step 2: Server receives client's message and responds ---
|
||||
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||
require.NoError(t, err, "Server failed to process client's first message")
|
||||
|
||||
// --- Step 3: Client receives server's message, computes proof ---
|
||||
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
||||
require.NoError(t, err, "Client failed to process server's first message")
|
||||
|
||||
// --- Step 4: Server receives client's proof and verifies it ---
|
||||
serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
||||
require.NoError(t, err, "Server failed to verify client's final proof")
|
||||
assert.NotEmpty(t, serverFinal.ServerSignature, "Server signature should not be empty")
|
||||
|
||||
// --- Step 5: Client verifies server's signature (mutual authentication) ---
|
||||
err = client.VerifyServerFinalMessage(serverFinal)
|
||||
assert.NoError(t, err, "Client failed to verify server's final signature")
|
||||
|
||||
t.Log("SCRAM full roundtrip successful")
|
||||
}
|
||||
|
||||
// TestScram_FullRoundtrip_WrongPassword ensures authentication fails with an incorrect password.
|
||||
func TestScram_FullRoundtrip_WrongPassword(t *testing.T) {
|
||||
server, username, _, _ := setupScramTest(t)
|
||||
// Create a client with the WRONG password
|
||||
client := NewScramClient(username, "WrongPassword!!!")
|
||||
|
||||
// Steps 1-3 will appear to succeed, as the client doesn't know the password is wrong yet.
|
||||
clientFirst, err := client.StartAuthentication()
|
||||
require.NoError(t, err)
|
||||
|
||||
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// --- Step 4: Server verification should fail here ---
|
||||
_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
||||
assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should reject proof from wrong password")
|
||||
|
||||
t.Log("SCRAM correctly failed for wrong password")
|
||||
}
|
||||
|
||||
// TestScram_FullRoundtrip_UserNotFound tests for user enumeration protection.
|
||||
// The server should not reveal whether a user exists or not in its first message.
|
||||
func TestScram_FullRoundtrip_UserNotFound(t *testing.T) {
|
||||
server, _, _, _ := setupScramTest(t)
|
||||
client := NewScramClient("unknown_user", "any_password")
|
||||
|
||||
clientFirst, err := client.StartAuthentication()
|
||||
require.NoError(t, err)
|
||||
|
||||
// --- Step 2: Server should return an error, but also a FAKE response ---
|
||||
// This prevents an attacker from knowing if the user exists based on the response structure.
|
||||
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||
assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should return an error for an unknown user")
|
||||
assert.NotEmpty(t, serverFirst.FullNonce, "Server must still provide a nonce to prevent enumeration")
|
||||
assert.NotEmpty(t, serverFirst.Salt, "Server must still provide a salt to prevent enumeration")
|
||||
|
||||
t.Log("SCRAM correctly protected against user enumeration")
|
||||
}
|
||||
|
||||
// TestScram_InvalidNonce simulates a replay attack or message mismatch.
|
||||
func TestScram_InvalidNonce(t *testing.T) {
|
||||
server, username, password, _ := setupScramTest(t)
|
||||
client := NewScramClient(username, password)
|
||||
|
||||
// Perform the first part of the handshake
|
||||
clientFirst, _ := client.StartAuthentication()
|
||||
serverFirst, _ := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||
clientFinal, _ := client.ProcessServerFirstMessage(serverFirst)
|
||||
|
||||
// Attempt to finalize with a completely different nonce
|
||||
_, err := server.ProcessClientFinalMessage("this-is-a-bad-nonce", clientFinal.ClientProof)
|
||||
assert.ErrorIs(t, err, ErrSCRAMInvalidNonce, "Server should reject a final message with an unknown nonce")
|
||||
}
|
||||
|
||||
// TestScram_CredentialImportExport verifies that credentials can be serialized and deserialized correctly.
|
||||
func TestScram_CredentialImportExport(t *testing.T) {
|
||||
_, _, _, originalCred := setupScramTest(t)
|
||||
|
||||
// Export the credential to a map
|
||||
exportedData := originalCred.Export()
|
||||
require.NotNil(t, exportedData)
|
||||
|
||||
// Assert that required fields exist and are strings (as they are base64 encoded)
|
||||
assert.IsType(t, "", exportedData["salt"])
|
||||
assert.IsType(t, "", exportedData["stored_key"])
|
||||
assert.IsType(t, "", exportedData["server_key"])
|
||||
|
||||
// Import the credential back from the map
|
||||
importedCred, err := ImportCredential(exportedData)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, importedCred)
|
||||
|
||||
// Verify that the imported credential is identical to the original
|
||||
assert.Equal(t, originalCred.Username, importedCred.Username)
|
||||
assert.Equal(t, originalCred.Salt, importedCred.Salt)
|
||||
assert.Equal(t, originalCred.ArgonTime, importedCred.ArgonTime)
|
||||
assert.Equal(t, originalCred.ArgonMemory, importedCred.ArgonMemory)
|
||||
assert.Equal(t, originalCred.ArgonThreads, importedCred.ArgonThreads)
|
||||
assert.Equal(t, originalCred.StoredKey, importedCred.StoredKey)
|
||||
assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey)
|
||||
|
||||
t.Log("SCRAM credential import/export successful")
|
||||
}
|
||||
Reference in New Issue
Block a user