v0.2.0 restructured and generalized to be more modular, added golang-jwt dependency
This commit is contained in:
41
README.md
41
README.md
@ -1,38 +1,45 @@
|
|||||||
# Auth Package
|
# Auth Package
|
||||||
|
|
||||||
Pluggable authentication utilities for Go applications.
|
Modular authentication utilities for Go applications.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Password Hashing**: Argon2id with PHC format
|
- **Password Hashing**: Standalone Argon2id hashing with PHC format.
|
||||||
- **JWT**: HS256/RS256 token generation and validation
|
- **JWT**: HS256/RS256 token management via a simple facade over `golang-jwt`.
|
||||||
- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF
|
- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF.
|
||||||
- **HTTP Auth**: Basic/Bearer header parsing
|
- **HTTP Auth**: Helpers for parsing Basic and Bearer authentication headers.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
```go
|
```go
|
||||||
|
// Argon2 Password Hashing
|
||||||
|
hash, _ := auth.HashPassword("password123")
|
||||||
|
err := auth.VerifyPassword("password123", hash)
|
||||||
|
|
||||||
// JWT with HS256
|
// JWT with HS256
|
||||||
auth, _ := auth.NewAuthenticator([]byte("32-byte-secret-key..."))
|
jwtMgr, _ := auth.NewJWT([]byte("a-very-secure-32-byte-secret-key"))
|
||||||
token, _ := auth.GenerateToken("user123", map[string]interface{}{"role": "admin"})
|
token, _ := jwtMgr.GenerateToken("user123", map[string]any{"role": "admin"})
|
||||||
userID, claims, _ := auth.ValidateToken(token)
|
userID, claims, _ := jwtMgr.ValidateToken(token)
|
||||||
|
|
||||||
// SCRAM authentication
|
// SCRAM authentication
|
||||||
server := auth.NewScramServer()
|
server := auth.NewScramServer()
|
||||||
cred, _ := auth.DeriveCredential("user", "password", salt, 1, 65536, 4)
|
phcHash, _ := auth.HashPassword("password123")
|
||||||
|
cred, _ := auth.MigrateFromPHC("user", "password123", phcHash)
|
||||||
server.AddCredential(cred)
|
server.AddCredential(cred)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Package Structure
|
## Package Structure
|
||||||
|
|
||||||
- `interfaces.go` - Core interfaces
|
- `doc.go` - Overview and package documentation
|
||||||
- `jwt.go` - JWT token operations
|
- `argon2.go` - Standalone Argon2id password hashing
|
||||||
- `argon2.go` - Password hashing
|
- `jwt.go` - JWT manager (HS256/RS256) wrapping `golang-jwt`
|
||||||
- `scram.go` - SCRAM-SHA256 protocol
|
- `scram.go` - SCRAM-SHA256 client/server protocol
|
||||||
- `token.go` - Token validation utilities
|
- `http.go` - HTTP Basic/Bearer header parsing
|
||||||
- `http.go` - HTTP header parsing
|
- `token.go` - Simple in-memory token validator
|
||||||
- `errors.go` - Error definitions
|
- `error.go` - Centralized error definitions
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
go test -v ./auth
|
go test -v ./
|
||||||
```
|
```
|
||||||
93
argon2.go
93
argon2.go
@ -13,40 +13,85 @@ import (
|
|||||||
|
|
||||||
// Default Argon2id parameters
|
// Default Argon2id parameters
|
||||||
const (
|
const (
|
||||||
DefaultArgonTime = 3 // iterations (reduce for faster but less secure auth)
|
DefaultArgonTime = 3 // iterations
|
||||||
DefaultArgonMemory = 64 * 1024 // 64 MB
|
DefaultArgonMemory = 64 * 1024 // 64 MB
|
||||||
DefaultArgonThreads = 4
|
DefaultArgonThreads = 4
|
||||||
DefaultArgonSaltLen = 16
|
DefaultArgonSaltLen = 16
|
||||||
DefaultArgonKeyLen = 32
|
DefaultArgonKeyLen = 32
|
||||||
)
|
)
|
||||||
|
|
||||||
// HashPassword creates an Argon2id PHC-format hash
|
// argonParams holds configurable Argon2id parameters
|
||||||
func (a *Authenticator) HashPassword(password string) (string, error) {
|
type argonParams struct {
|
||||||
|
time uint32
|
||||||
|
memory uint32
|
||||||
|
threads uint8
|
||||||
|
keyLen uint32
|
||||||
|
saltLen uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures Argon2id hashing parameters
|
||||||
|
type Option func(*argonParams)
|
||||||
|
|
||||||
|
// WithTime sets Argon2 iterations
|
||||||
|
func WithTime(t uint32) Option {
|
||||||
|
return func(p *argonParams) {
|
||||||
|
if t > 0 {
|
||||||
|
p.time = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMemory sets Argon2 memory in KiB
|
||||||
|
func WithMemory(m uint32) Option {
|
||||||
|
return func(p *argonParams) {
|
||||||
|
if m > 0 {
|
||||||
|
p.memory = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithThreads sets Argon2 parallelism
|
||||||
|
func WithThreads(t uint8) Option {
|
||||||
|
return func(p *argonParams) {
|
||||||
|
if t > 0 {
|
||||||
|
p.threads = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashPassword creates Argon2id PHC-format hash (standalone)
|
||||||
|
func HashPassword(password string, opts ...Option) (string, error) {
|
||||||
if len(password) < 8 {
|
if len(password) < 8 {
|
||||||
return "", ErrWeakPassword
|
return "", ErrWeakPassword
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate salt
|
params := &argonParams{
|
||||||
salt := make([]byte, DefaultArgonSaltLen)
|
time: DefaultArgonTime,
|
||||||
|
memory: DefaultArgonMemory,
|
||||||
|
threads: DefaultArgonThreads,
|
||||||
|
keyLen: DefaultArgonKeyLen,
|
||||||
|
saltLen: DefaultArgonSaltLen,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
salt := make([]byte, params.saltLen)
|
||||||
if _, err := rand.Read(salt); err != nil {
|
if _, err := rand.Read(salt); err != nil {
|
||||||
return "", fmt.Errorf("%w: %v", ErrSaltGenerationFailed, err)
|
return "", fmt.Errorf("%w: %v", ErrSaltGenerationFailed, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Derive key using Argon2id
|
hash := argon2.IDKey([]byte(password), salt, params.time, params.memory, params.threads, params.keyLen)
|
||||||
hash := argon2.IDKey([]byte(password), salt, a.argonTime, a.argonMemory, a.argonThreads, DefaultArgonKeyLen)
|
|
||||||
|
|
||||||
// Construct PHC format
|
|
||||||
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
||||||
hashB64 := base64.RawStdEncoding.EncodeToString(hash)
|
hashB64 := base64.RawStdEncoding.EncodeToString(hash)
|
||||||
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||||
argon2.Version, a.argonMemory, a.argonTime, a.argonThreads, saltB64, hashB64)
|
argon2.Version, params.memory, params.time, params.threads, saltB64, hashB64), nil
|
||||||
|
|
||||||
return phcHash, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyPassword checks password against PHC-format hash
|
// VerifyPassword checks password against PHC-format hash (standalone)
|
||||||
func (a *Authenticator) VerifyPassword(password, phcHash string) error {
|
func VerifyPassword(password, phcHash string) error {
|
||||||
// Parse PHC format
|
|
||||||
parts := strings.Split(phcHash, "$")
|
parts := strings.Split(phcHash, "$")
|
||||||
if len(parts) != 6 || parts[1] != "argon2id" {
|
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||||
return ErrPHCInvalidFormat
|
return ErrPHCInvalidFormat
|
||||||
@ -66,10 +111,8 @@ func (a *Authenticator) VerifyPassword(password, phcHash string) error {
|
|||||||
return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err)
|
return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute hash with same parameters
|
|
||||||
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
|
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
|
||||||
|
|
||||||
// Constant-time comparison
|
|
||||||
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
|
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
|
||||||
return ErrInvalidCredentials
|
return ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
@ -77,9 +120,8 @@ func (a *Authenticator) VerifyPassword(password, phcHash string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MigrateFromPHC converts existing Argon2 PHC hash to SCRAM credential
|
// MigrateFromPHC converts PHC hash to SCRAM credential
|
||||||
func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
|
func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
|
||||||
// Parse PHC format
|
|
||||||
parts := strings.Split(phcHash, "$")
|
parts := strings.Split(phcHash, "$")
|
||||||
if len(parts) != 6 || parts[1] != "argon2id" {
|
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||||
return nil, ErrPHCInvalidFormat
|
return nil, ErrPHCInvalidFormat
|
||||||
@ -94,17 +136,10 @@ func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
|
|||||||
return nil, ErrPHCInvalidSalt
|
return nil, ErrPHCInvalidSalt
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
// Use standalone function for verification
|
||||||
if err != nil {
|
if err := VerifyPassword(password, phcHash); err != nil {
|
||||||
return nil, ErrPHCInvalidHash
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify password against hash
|
|
||||||
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
|
|
||||||
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
|
|
||||||
return nil, ErrInvalidCredentials
|
|
||||||
}
|
|
||||||
|
|
||||||
// Derive SCRAM credential with same parameters
|
|
||||||
return DeriveCredential(username, password, salt, time, memory, threads)
|
return DeriveCredential(username, password, salt, time, memory, threads)
|
||||||
}
|
}
|
||||||
@ -11,13 +11,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPasswordHashing(t *testing.T) {
|
func TestPasswordHashing(t *testing.T) {
|
||||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
|
||||||
require.NoError(t, err, "Failed to create authenticator")
|
|
||||||
|
|
||||||
password := "testPassword123"
|
password := "testPassword123"
|
||||||
|
|
||||||
// Test hashing
|
// Test hashing with default parameters
|
||||||
hash, err := auth.HashPassword(password)
|
hash, err := HashPassword(password)
|
||||||
require.NoError(t, err, "Failed to hash password")
|
require.NoError(t, err, "Failed to hash password")
|
||||||
|
|
||||||
// Verify PHC format
|
// Verify PHC format
|
||||||
@ -25,20 +22,30 @@ func TestPasswordHashing(t *testing.T) {
|
|||||||
"Hash should have argon2id prefix, got: %s", hash)
|
"Hash should have argon2id prefix, got: %s", hash)
|
||||||
|
|
||||||
// Test verification with correct password
|
// Test verification with correct password
|
||||||
err = auth.VerifyPassword(password, hash)
|
err = VerifyPassword(password, hash)
|
||||||
assert.NoError(t, err, "Failed to verify correct password")
|
assert.NoError(t, err, "Failed to verify correct password")
|
||||||
|
|
||||||
// Test verification with incorrect password
|
// Test verification with incorrect password
|
||||||
err = auth.VerifyPassword("wrongPassword", hash)
|
err = VerifyPassword("wrongPassword", hash)
|
||||||
assert.Error(t, err, "Verification should fail for incorrect password")
|
assert.Error(t, err, "Verification should fail for incorrect password")
|
||||||
assert.Equal(t, ErrInvalidCredentials, err)
|
assert.Equal(t, ErrInvalidCredentials, err)
|
||||||
|
|
||||||
// Test weak password
|
// Test weak password
|
||||||
_, err = auth.HashPassword("weak")
|
_, err = HashPassword("weak")
|
||||||
assert.Equal(t, ErrWeakPassword, err, "Should reject weak password")
|
assert.Equal(t, ErrWeakPassword, err, "Should reject weak password")
|
||||||
|
|
||||||
|
// Test with custom options
|
||||||
|
hash, err = HashPassword(password,
|
||||||
|
WithTime(5),
|
||||||
|
WithMemory(128*1024),
|
||||||
|
WithThreads(8))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = VerifyPassword(password, hash)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Test malformed PHC hash
|
// Test malformed PHC hash
|
||||||
err = auth.VerifyPassword(password, "$invalid$format")
|
err = VerifyPassword(password, "$invalid$format")
|
||||||
assert.Error(t, err, "Should reject malformed hash")
|
assert.Error(t, err, "Should reject malformed hash")
|
||||||
|
|
||||||
// Test corrupted salt
|
// Test corrupted salt
|
||||||
@ -47,33 +54,27 @@ func TestPasswordHashing(t *testing.T) {
|
|||||||
if len(parts) == 6 {
|
if len(parts) == 6 {
|
||||||
parts[4] = "invalid!base64"
|
parts[4] = "invalid!base64"
|
||||||
corruptedHash = strings.Join(parts, "$")
|
corruptedHash = strings.Join(parts, "$")
|
||||||
err = auth.VerifyPassword(password, corruptedHash)
|
err = VerifyPassword(password, corruptedHash)
|
||||||
assert.Error(t, err, "Should reject corrupted salt")
|
assert.Error(t, err, "Should reject corrupted salt")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEmptyPasswordAfterValidation(t *testing.T) {
|
func TestEmptyPasswordAfterValidation(t *testing.T) {
|
||||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Empty password should be rejected by length check
|
// Empty password should be rejected by length check
|
||||||
_, err = auth.HashPassword("")
|
_, err := HashPassword("")
|
||||||
assert.Equal(t, ErrWeakPassword, err)
|
assert.Equal(t, ErrWeakPassword, err)
|
||||||
|
|
||||||
// 8-character password should pass
|
// 8-character password should pass
|
||||||
hash, err := auth.HashPassword("12345678")
|
hash, err := HashPassword("12345678")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = auth.VerifyPassword("12345678", hash)
|
err = VerifyPassword("12345678", hash)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConcurrentPasswordOperations(t *testing.T) {
|
func TestConcurrentPasswordOperations(t *testing.T) {
|
||||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
password := "testPassword123"
|
password := "testPassword123"
|
||||||
hash, err := auth.HashPassword(password)
|
hash, err := HashPassword(password)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test concurrent verification
|
// Test concurrent verification
|
||||||
@ -82,7 +83,7 @@ func TestConcurrentPasswordOperations(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err := auth.VerifyPassword(password, hash)
|
err := VerifyPassword(password, hash)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@ -90,14 +91,11 @@ func TestConcurrentPasswordOperations(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPHCMigration(t *testing.T) {
|
func TestPHCMigration(t *testing.T) {
|
||||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
password := "testPassword123"
|
password := "testPassword123"
|
||||||
username := "migrationUser"
|
username := "migrationUser"
|
||||||
|
|
||||||
// Generate PHC hash
|
// Generate PHC hash
|
||||||
phcHash, err := auth.HashPassword(password)
|
phcHash, err := HashPassword(password)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Migrate to SCRAM credential
|
// Migrate to SCRAM credential
|
||||||
|
|||||||
71
auth.go
71
auth.go
@ -1,71 +0,0 @@
|
|||||||
// FILE: auth/auth.go
|
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rsa"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Authenticator provides password hashing and JWT operations
|
|
||||||
type Authenticator struct {
|
|
||||||
algorithm string
|
|
||||||
jwtSecret []byte // For HS256
|
|
||||||
privateKey *rsa.PrivateKey // For RS256
|
|
||||||
publicKey *rsa.PublicKey // For RS256
|
|
||||||
argonTime uint32
|
|
||||||
argonMemory uint32
|
|
||||||
argonThreads uint8
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAuthenticator creates a new authenticator with specified algorithm
|
|
||||||
func NewAuthenticator(key any, algorithm ...string) (*Authenticator, error) {
|
|
||||||
alg := "HS256"
|
|
||||||
if len(algorithm) > 0 && algorithm[0] != "" {
|
|
||||||
alg = algorithm[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
auth := &Authenticator{
|
|
||||||
algorithm: alg,
|
|
||||||
argonTime: DefaultArgonTime,
|
|
||||||
argonMemory: DefaultArgonMemory,
|
|
||||||
argonThreads: DefaultArgonThreads,
|
|
||||||
}
|
|
||||||
|
|
||||||
switch alg {
|
|
||||||
case "HS256":
|
|
||||||
secret, ok := key.([]byte)
|
|
||||||
if !ok {
|
|
||||||
return nil, ErrInvalidKeyType
|
|
||||||
}
|
|
||||||
if len(secret) < 32 {
|
|
||||||
return nil, ErrSecretTooShort
|
|
||||||
}
|
|
||||||
auth.jwtSecret = secret
|
|
||||||
|
|
||||||
case "RS256":
|
|
||||||
switch k := key.(type) {
|
|
||||||
case *rsa.PrivateKey:
|
|
||||||
auth.privateKey = k
|
|
||||||
auth.publicKey = &k.PublicKey
|
|
||||||
case *rsa.PublicKey:
|
|
||||||
auth.publicKey = k
|
|
||||||
case []byte:
|
|
||||||
// Try parsing as PEM
|
|
||||||
if privKey, err := parseRSAPrivateKey(k); err == nil {
|
|
||||||
auth.privateKey = privKey
|
|
||||||
auth.publicKey = &privKey.PublicKey
|
|
||||||
} else if pubKey, err := parseRSAPublicKey(k); err == nil {
|
|
||||||
auth.publicKey = pubKey
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("failed to parse RSA key: %w", err)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, ErrInvalidKeyType
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, ErrInvalidAlgorithm
|
|
||||||
}
|
|
||||||
|
|
||||||
return auth, nil
|
|
||||||
}
|
|
||||||
68
auth_test.go
68
auth_test.go
@ -1,68 +0,0 @@
|
|||||||
// FILE: auth/auth_test.go
|
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewAuthenticator(t *testing.T) {
|
|
||||||
// Test HS256 creation
|
|
||||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
|
||||||
require.NoError(t, err, "Failed to create HS256 authenticator")
|
|
||||||
assert.Equal(t, "HS256", auth.algorithm)
|
|
||||||
|
|
||||||
// Test with short secret
|
|
||||||
_, err = NewAuthenticator([]byte("short"))
|
|
||||||
assert.Equal(t, ErrSecretTooShort, err)
|
|
||||||
|
|
||||||
// Test RS256 with private key
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
authRS, err := NewAuthenticator(privateKey, "RS256")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "RS256", authRS.algorithm)
|
|
||||||
assert.NotNil(t, authRS.privateKey)
|
|
||||||
assert.NotNil(t, authRS.publicKey)
|
|
||||||
|
|
||||||
// Test RS256 with public key only
|
|
||||||
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "RS256", authPub.algorithm)
|
|
||||||
assert.Nil(t, authPub.privateKey)
|
|
||||||
assert.NotNil(t, authPub.publicKey)
|
|
||||||
|
|
||||||
// Test invalid algorithm
|
|
||||||
_, err = NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"), "INVALID")
|
|
||||||
assert.Equal(t, ErrInvalidAlgorithm, err)
|
|
||||||
|
|
||||||
// Test invalid key type for HS256
|
|
||||||
_, err = NewAuthenticator(privateKey, "HS256")
|
|
||||||
assert.Equal(t, ErrInvalidKeyType, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInterfaceCompliance(t *testing.T) {
|
|
||||||
// Verify Authenticator implements AuthenticatorInterface
|
|
||||||
auth, _ := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
|
||||||
|
|
||||||
var _ AuthenticatorInterface = auth
|
|
||||||
|
|
||||||
// Test interface methods work
|
|
||||||
hash, err := auth.HashPassword("testpass123")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = auth.VerifyPassword("testpass123", hash)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
token, err := auth.GenerateToken("user1", nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
userID, _, err := auth.ValidateToken(token)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "user1", userID)
|
|
||||||
}
|
|
||||||
53
doc.go
Normal file
53
doc.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
// FILE: auth/doc.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
/*
|
||||||
|
Package auth provides modular authentication components:
|
||||||
|
|
||||||
|
# Argon2 Password Hashing
|
||||||
|
|
||||||
|
Standalone password hashing using Argon2id:
|
||||||
|
|
||||||
|
hash, err := auth.HashPassword("password123")
|
||||||
|
err = auth.VerifyPassword("password123", hash)
|
||||||
|
|
||||||
|
// With custom parameters
|
||||||
|
hash, err := auth.HashPassword("password123",
|
||||||
|
auth.WithTime(5),
|
||||||
|
auth.WithMemory(128*1024))
|
||||||
|
|
||||||
|
# JWT Token Management
|
||||||
|
|
||||||
|
JSON Web Token generation and validation:
|
||||||
|
|
||||||
|
// HS256 (symmetric)
|
||||||
|
jwtMgr, _ := auth.NewJWT(secret)
|
||||||
|
token, _ := jwtMgr.GenerateToken("user1", claims)
|
||||||
|
userID, claims, _ := jwtMgr.ValidateToken(token)
|
||||||
|
|
||||||
|
// RS256 (asymmetric)
|
||||||
|
jwtMgr, _ := auth.NewJWTRSA(privateKey)
|
||||||
|
|
||||||
|
// One-off operations
|
||||||
|
token, _ := auth.GenerateHS256Token(secret, "user1", claims, 1*time.Hour)
|
||||||
|
|
||||||
|
# SCRAM-SHA256 Authentication
|
||||||
|
|
||||||
|
Server and client implementation for SCRAM:
|
||||||
|
|
||||||
|
// Server
|
||||||
|
server := auth.NewScramServer()
|
||||||
|
server.AddCredential(credential)
|
||||||
|
|
||||||
|
// Client
|
||||||
|
client := auth.NewScramClient(username, password)
|
||||||
|
|
||||||
|
# HTTP Authentication Parsing
|
||||||
|
|
||||||
|
Utility functions for HTTP headers:
|
||||||
|
|
||||||
|
username, password, _ := auth.ParseBasicAuth(header)
|
||||||
|
token, _ := auth.ParseBearerToken(header)
|
||||||
|
|
||||||
|
Each module can be used independently without initializing other components.
|
||||||
|
*/
|
||||||
5
error.go
5
error.go
@ -10,8 +10,6 @@ import (
|
|||||||
var (
|
var (
|
||||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||||
ErrWeakPassword = errors.New("password must be at least 8 characters")
|
ErrWeakPassword = errors.New("password must be at least 8 characters")
|
||||||
ErrInvalidAlgorithm = errors.New("invalid algorithm")
|
|
||||||
ErrInvalidKeyType = errors.New("invalid key type for algorithm")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// JWT-specific errors
|
// JWT-specific errors
|
||||||
@ -22,9 +20,6 @@ var (
|
|||||||
ErrTokenInvalidSignature = errors.New("token: invalid signature")
|
ErrTokenInvalidSignature = errors.New("token: invalid signature")
|
||||||
ErrTokenAlgorithmMismatch = errors.New("token: algorithm mismatch")
|
ErrTokenAlgorithmMismatch = errors.New("token: algorithm mismatch")
|
||||||
ErrTokenMissingClaim = errors.New("token: missing required claim")
|
ErrTokenMissingClaim = errors.New("token: missing required claim")
|
||||||
ErrTokenInvalidHeader = errors.New("token: invalid header encoding")
|
|
||||||
ErrTokenInvalidClaims = errors.New("token: invalid claims encoding")
|
|
||||||
ErrTokenInvalidJSON = errors.New("token: malformed JSON")
|
|
||||||
ErrTokenEmptyUserID = errors.New("token: empty user ID")
|
ErrTokenEmptyUserID = errors.New("token: empty user ID")
|
||||||
ErrTokenNoPrivateKey = errors.New("token: private key required for signing")
|
ErrTokenNoPrivateKey = errors.New("token: private key required for signing")
|
||||||
ErrTokenNoPublicKey = errors.New("token: public key required for verification")
|
ErrTokenNoPublicKey = errors.New("token: public key required for verification")
|
||||||
|
|||||||
3
go.mod
3
go.mod
@ -1,8 +1,9 @@
|
|||||||
module auth
|
module github.com/lixenwraith/auth
|
||||||
|
|
||||||
go 1.25.3
|
go 1.25.3
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
golang.org/x/crypto v0.43.0
|
golang.org/x/crypto v0.43.0
|
||||||
)
|
)
|
||||||
|
|||||||
2
go.sum
2
go.sum
@ -1,5 +1,7 @@
|
|||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
|||||||
17
interface.go
17
interface.go
@ -1,17 +0,0 @@
|
|||||||
// FILE: auth/interface.go
|
|
||||||
package auth
|
|
||||||
|
|
||||||
// AuthenticatorInterface defines the authentication operations
|
|
||||||
type AuthenticatorInterface interface {
|
|
||||||
HashPassword(password string) (hash string, err error)
|
|
||||||
VerifyPassword(password, hash string) (err error)
|
|
||||||
GenerateToken(userID string, claims map[string]any) (token string, err error)
|
|
||||||
ValidateToken(token string) (userID string, claims map[string]any, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenValidator validates bearer tokens
|
|
||||||
type TokenValidator interface {
|
|
||||||
ValidateToken(token string) (valid bool)
|
|
||||||
AddToken(token string)
|
|
||||||
RemoveToken(token string)
|
|
||||||
}
|
|
||||||
376
jwt.go
376
jwt.go
@ -2,179 +2,289 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/subtle"
|
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GenerateToken creates a JWT token with user claims
|
// JWT configuration defaults
|
||||||
func (a *Authenticator) GenerateToken(userID string, claims map[string]any) (string, error) {
|
const (
|
||||||
|
DefaultTokenLifetime = 24 * time.Hour
|
||||||
|
DefaultLeeway = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// customClaims extends RegisteredClaims with arbitrary user data
|
||||||
|
type customClaims struct {
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
Extra map[string]any `json:"extra,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWT manages token generation and validation
|
||||||
|
type JWT struct {
|
||||||
|
algorithm jwt.SigningMethod
|
||||||
|
signKey any // []byte for HMAC, *rsa.PrivateKey for RSA
|
||||||
|
verifyKey any // []byte for HMAC, *rsa.PublicKey for RSA
|
||||||
|
tokenLifetime time.Duration
|
||||||
|
leeway time.Duration
|
||||||
|
issuer string
|
||||||
|
audience []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTOption configures JWT behavior
|
||||||
|
type JWTOption func(*JWT)
|
||||||
|
|
||||||
|
// WithTokenLifetime sets token expiration duration
|
||||||
|
func WithTokenLifetime(d time.Duration) JWTOption {
|
||||||
|
return func(j *JWT) {
|
||||||
|
if d > 0 {
|
||||||
|
j.tokenLifetime = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLeeway sets clock skew tolerance
|
||||||
|
func WithLeeway(d time.Duration) JWTOption {
|
||||||
|
return func(j *JWT) {
|
||||||
|
if d >= 0 {
|
||||||
|
j.leeway = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithIssuer sets token issuer claim
|
||||||
|
func WithIssuer(iss string) JWTOption {
|
||||||
|
return func(j *JWT) {
|
||||||
|
j.issuer = iss
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAudience sets token audience claim
|
||||||
|
func WithAudience(aud []string) JWTOption {
|
||||||
|
return func(j *JWT) {
|
||||||
|
j.audience = aud
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJWT creates JWT manager for HS256 (symmetric)
|
||||||
|
func NewJWT(secret []byte, opts ...JWTOption) (*JWT, error) {
|
||||||
|
if len(secret) < 32 {
|
||||||
|
return nil, ErrSecretTooShort
|
||||||
|
}
|
||||||
|
|
||||||
|
j := &JWT{
|
||||||
|
algorithm: jwt.SigningMethodHS256,
|
||||||
|
signKey: secret,
|
||||||
|
verifyKey: secret,
|
||||||
|
tokenLifetime: DefaultTokenLifetime,
|
||||||
|
leeway: DefaultLeeway,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(j)
|
||||||
|
}
|
||||||
|
|
||||||
|
return j, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJWTRSA creates JWT manager for RS256 (asymmetric)
|
||||||
|
func NewJWTRSA(privateKey *rsa.PrivateKey, opts ...JWTOption) (*JWT, error) {
|
||||||
|
if privateKey == nil {
|
||||||
|
return nil, ErrTokenNoPrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
j := &JWT{
|
||||||
|
algorithm: jwt.SigningMethodRS256,
|
||||||
|
signKey: privateKey,
|
||||||
|
verifyKey: &privateKey.PublicKey,
|
||||||
|
tokenLifetime: DefaultTokenLifetime,
|
||||||
|
leeway: DefaultLeeway,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(j)
|
||||||
|
}
|
||||||
|
|
||||||
|
return j, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJWTVerifier creates JWT manager for verification only (RS256)
|
||||||
|
func NewJWTVerifier(publicKey *rsa.PublicKey, opts ...JWTOption) (*JWT, error) {
|
||||||
|
if publicKey == nil {
|
||||||
|
return nil, ErrTokenNoPublicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
j := &JWT{
|
||||||
|
algorithm: jwt.SigningMethodRS256,
|
||||||
|
signKey: nil, // Cannot sign
|
||||||
|
verifyKey: publicKey,
|
||||||
|
tokenLifetime: DefaultTokenLifetime,
|
||||||
|
leeway: DefaultLeeway,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(j)
|
||||||
|
}
|
||||||
|
|
||||||
|
return j, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateToken creates signed JWT with claims
|
||||||
|
func (j *JWT) GenerateToken(userID string, claims map[string]any) (string, error) {
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
return "", ErrTokenEmptyUserID
|
return "", ErrTokenEmptyUserID
|
||||||
}
|
}
|
||||||
|
|
||||||
if a.algorithm == "RS256" && a.privateKey == nil {
|
if j.signKey == nil {
|
||||||
return "", ErrTokenNoPrivateKey
|
return "", ErrTokenNoPrivateKey
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build JWT claims
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
jwtClaims := map[string]any{
|
registeredClaims := jwt.RegisteredClaims{
|
||||||
"sub": userID,
|
Subject: userID,
|
||||||
"iat": now.Unix(),
|
Issuer: j.issuer,
|
||||||
"exp": now.Add(7 * 24 * time.Hour).Unix(), // 7 days expiry
|
Audience: j.audience,
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(j.tokenLifetime)),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reserved claims that cannot be overridden
|
token := jwt.NewWithClaims(j.algorithm, customClaims{
|
||||||
reservedClaims := map[string]bool{
|
RegisteredClaims: registeredClaims,
|
||||||
"sub": true, "iat": true, "exp": true, "nbf": true,
|
Extra: claims,
|
||||||
"iss": true, "aud": true, "jti": true, "typ": true,
|
})
|
||||||
"alg": true,
|
|
||||||
|
return token.SignedString(j.signKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge custom claims
|
// ValidateToken verifies JWT and extracts claims
|
||||||
for k, v := range claims {
|
func (j *JWT) ValidateToken(tokenString string) (string, map[string]any, error) {
|
||||||
if !reservedClaims[k] {
|
parser := jwt.NewParser(
|
||||||
jwtClaims[k] = v
|
jwt.WithLeeway(j.leeway),
|
||||||
}
|
jwt.WithAudience(j.audience...),
|
||||||
}
|
jwt.WithIssuer(j.issuer),
|
||||||
|
jwt.WithValidMethods([]string{j.algorithm.Alg()}),
|
||||||
|
jwt.WithExpirationRequired(),
|
||||||
|
)
|
||||||
|
|
||||||
// Create JWT header
|
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
|
||||||
header := map[string]any{
|
// Algorithm already validated by WithValidMethods
|
||||||
"alg": a.algorithm,
|
return j.verifyKey, nil
|
||||||
"typ": "JWT",
|
})
|
||||||
}
|
|
||||||
|
|
||||||
// Encode header and payload
|
|
||||||
headerJSON, _ := json.Marshal(header)
|
|
||||||
claimsJSON, _ := json.Marshal(jwtClaims)
|
|
||||||
|
|
||||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
|
||||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
|
||||||
|
|
||||||
// Create signature
|
|
||||||
signingInput := headerB64 + "." + claimsB64
|
|
||||||
var signature string
|
|
||||||
|
|
||||||
switch a.algorithm {
|
|
||||||
case "HS256":
|
|
||||||
h := hmac.New(sha256.New, a.jwtSecret)
|
|
||||||
h.Write([]byte(signingInput))
|
|
||||||
signature = base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
|
||||||
|
|
||||||
case "RS256":
|
|
||||||
hashed := sha256.Sum256([]byte(signingInput))
|
|
||||||
sig, err := rsa.SignPKCS1v15(rand.Reader, a.privateKey, crypto.SHA256, hashed[:])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
return "", nil, mapJWTError(err)
|
||||||
}
|
|
||||||
signature = base64.RawURLEncoding.EncodeToString(sig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Combine to form JWT
|
claims, ok := token.Claims.(*customClaims)
|
||||||
token := signingInput + "." + signature
|
if !ok || !token.Valid {
|
||||||
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateToken verifies JWT and returns userID and claims
|
|
||||||
func (a *Authenticator) ValidateToken(token string) (string, map[string]any, error) {
|
|
||||||
// Split token
|
|
||||||
parts := strings.Split(token, ".")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return "", nil, ErrTokenMalformed
|
return "", nil, ErrTokenMalformed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode header to check algorithm
|
return claims.Subject, claims.Extra, nil
|
||||||
headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
|
}
|
||||||
|
|
||||||
|
// mapJWTError translates jwt library errors to auth package errors
|
||||||
|
func mapJWTError(err error) error {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, jwt.ErrTokenMalformed):
|
||||||
|
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
||||||
|
case errors.Is(err, jwt.ErrTokenUnverifiable):
|
||||||
|
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
||||||
|
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
|
||||||
|
return fmt.Errorf("%w : %w", ErrTokenInvalidSignature, err)
|
||||||
|
case errors.Is(err, jwt.ErrTokenExpired):
|
||||||
|
return fmt.Errorf("%w : %w", ErrTokenExpired, err)
|
||||||
|
case errors.Is(err, jwt.ErrTokenNotValidYet):
|
||||||
|
return fmt.Errorf("%w : %w", ErrTokenNotYetValid, err)
|
||||||
|
case errors.Is(err, jwt.ErrTokenInvalidAudience):
|
||||||
|
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
|
||||||
|
case errors.Is(err, jwt.ErrTokenInvalidIssuer):
|
||||||
|
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
|
||||||
|
default:
|
||||||
|
// Check for algorithm mismatch in error message
|
||||||
|
if errors.Is(err, jwt.ErrTokenSignatureInvalid) {
|
||||||
|
return fmt.Errorf("%w : %w", ErrTokenAlgorithmMismatch, err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standalone helper functions for one-off operations
|
||||||
|
|
||||||
|
// GenerateHS256Token creates HS256 JWT without manager instance
|
||||||
|
func GenerateHS256Token(secret []byte, userID string, claims map[string]any, lifetime time.Duration) (string, error) {
|
||||||
|
if len(secret) < 32 {
|
||||||
|
return "", ErrSecretTooShort
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, customClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID,
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(lifetime)),
|
||||||
|
},
|
||||||
|
Extra: claims,
|
||||||
|
})
|
||||||
|
|
||||||
|
return token.SignedString(secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateHS256Token verifies HS256 JWT without manager instance
|
||||||
|
func ValidateHS256Token(secret []byte, tokenString string) (string, map[string]any, error) {
|
||||||
|
if len(secret) < 32 {
|
||||||
|
return "", nil, ErrSecretTooShort
|
||||||
|
}
|
||||||
|
|
||||||
|
parser := jwt.NewParser(
|
||||||
|
jwt.WithValidMethods([]string{"HS256"}),
|
||||||
|
jwt.WithLeeway(DefaultLeeway),
|
||||||
|
)
|
||||||
|
|
||||||
|
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
|
||||||
|
return secret, nil
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, ErrTokenInvalidHeader
|
return "", nil, mapJWTError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var header map[string]any
|
claims, ok := token.Claims.(*customClaims)
|
||||||
if err = json.Unmarshal(headerJSON, &header); err != nil {
|
if !ok || !token.Valid {
|
||||||
return "", nil, ErrTokenInvalidJSON
|
return "", nil, ErrTokenMalformed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify algorithm matches
|
return claims.Subject, claims.Extra, nil
|
||||||
if alg, ok := header["alg"].(string); !ok || alg != a.algorithm {
|
|
||||||
return "", nil, ErrTokenAlgorithmMismatch
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify signature
|
// RSA Utilities
|
||||||
signingInput := parts[0] + "." + parts[1]
|
|
||||||
|
|
||||||
switch a.algorithm {
|
// NewJWTRSAFromPEM creates a JWT manager for RS256 from raw PEM-encoded private key data.
|
||||||
case "HS256":
|
func NewJWTRSAFromPEM(privateKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
|
||||||
h := hmac.New(sha256.New, a.jwtSecret)
|
privateKey, err := parseRSAPrivateKey(privateKeyPEM)
|
||||||
h.Write([]byte(signingInput))
|
|
||||||
expectedSig := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
|
||||||
|
|
||||||
if subtle.ConstantTimeCompare([]byte(parts[2]), []byte(expectedSig)) != 1 {
|
|
||||||
return "", nil, ErrTokenInvalidSignature
|
|
||||||
}
|
|
||||||
|
|
||||||
case "RS256":
|
|
||||||
if a.publicKey == nil {
|
|
||||||
return "", nil, ErrTokenNoPublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := base64.RawURLEncoding.DecodeString(parts[2])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, ErrTokenInvalidSignature
|
return nil, err
|
||||||
|
}
|
||||||
|
// Call the original constructor with the now-parsed key
|
||||||
|
return NewJWTRSA(privateKey, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
hashed := sha256.Sum256([]byte(signingInput))
|
// NewJWTVerifierFromPEM creates a JWT manager for verification from raw PEM-encoded public key data.
|
||||||
if err := rsa.VerifyPKCS1v15(a.publicKey, crypto.SHA256, hashed[:], sig); err != nil {
|
func NewJWTVerifierFromPEM(publicKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
|
||||||
return "", nil, ErrTokenInvalidSignature
|
publicKey, err := parseRSAPublicKey(publicKeyPEM)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode claims
|
|
||||||
claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, ErrTokenInvalidClaims
|
return nil, err
|
||||||
|
}
|
||||||
|
// Call the original constructor with the now-parsed key
|
||||||
|
return NewJWTVerifier(publicKey, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
var claims map[string]any
|
// parseRSAPrivateKey parses a PEM-encoded RSA private key.
|
||||||
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
|
|
||||||
return "", nil, ErrTokenInvalidJSON
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check expiration
|
|
||||||
if exp, ok := claims["exp"].(float64); ok {
|
|
||||||
if time.Now().Unix() > int64(exp) {
|
|
||||||
return "", nil, ErrTokenExpired
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check not before
|
|
||||||
if nbf, ok := claims["nbf"].(float64); ok {
|
|
||||||
if time.Now().Unix() < int64(nbf) {
|
|
||||||
return "", nil, ErrTokenNotYetValid
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract userID
|
|
||||||
userID, ok := claims["sub"].(string)
|
|
||||||
if !ok {
|
|
||||||
return "", nil, ErrTokenMissingClaim
|
|
||||||
}
|
|
||||||
|
|
||||||
return userID, claims, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseRSAPrivateKey parses PEM encoded RSA private key
|
|
||||||
func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
|
func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
|
||||||
block, _ := pem.Decode(pemBytes)
|
block, _ := pem.Decode(pemBytes)
|
||||||
if block == nil {
|
if block == nil {
|
||||||
@ -187,7 +297,7 @@ func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
|
|||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseRSAPublicKey parses PEM encoded RSA public key
|
// parseRSAPublicKey parses a PEM-encoded RSA public key.
|
||||||
func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) {
|
func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) {
|
||||||
block, _ := pem.Decode(pemBytes)
|
block, _ := pem.Decode(pemBytes)
|
||||||
if block == nil {
|
if block == nil {
|
||||||
|
|||||||
283
jwt_test.go
283
jwt_test.go
@ -4,19 +4,20 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"encoding/base64"
|
"crypto/x509"
|
||||||
"errors"
|
"encoding/pem"
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJWTHS256(t *testing.T) {
|
func TestJWTHS256(t *testing.T) {
|
||||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||||
|
jwtMgr, err := NewJWT(secret)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
userID := "user123"
|
userID := "user123"
|
||||||
@ -26,54 +27,17 @@ func TestJWTHS256(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate token
|
// Generate token
|
||||||
token, err := auth.GenerateToken(userID, claims)
|
token, err := jwtMgr.GenerateToken(userID, claims)
|
||||||
require.NoError(t, err, "Failed to generate token")
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, token)
|
assert.NotEmpty(t, token)
|
||||||
|
|
||||||
// Validate token
|
// Validate token
|
||||||
extractedUserID, extractedClaims, err := auth.ValidateToken(token)
|
extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token)
|
||||||
require.NoError(t, err, "Failed to validate token")
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, userID, extractedUserID)
|
assert.Equal(t, userID, extractedUserID)
|
||||||
assert.Equal(t, "test@example.com", extractedClaims["email"])
|
assert.Equal(t, "test@example.com", extractedClaims["email"])
|
||||||
assert.Equal(t, "admin", extractedClaims["role"])
|
assert.Equal(t, "admin", extractedClaims["role"])
|
||||||
|
|
||||||
// Test invalid token
|
|
||||||
_, _, err = auth.ValidateToken("invalid.token.here")
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.True(t, errors.Is(err, ErrTokenInvalidJSON))
|
|
||||||
|
|
||||||
// Test tampered token
|
|
||||||
parts := strings.Split(token, ".")
|
|
||||||
require.Len(t, parts, 3, "JWT should have 3 parts")
|
|
||||||
|
|
||||||
tampered := parts[0] + "." + parts[1] + ".invalidsignature"
|
|
||||||
_, _, err = auth.ValidateToken(tampered)
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.True(t, errors.Is(err, ErrTokenInvalidSignature))
|
|
||||||
|
|
||||||
// Test reserved claims cannot be overridden
|
|
||||||
overrideClaims := map[string]any{
|
|
||||||
"sub": "override",
|
|
||||||
"iat": 12345,
|
|
||||||
"exp": 67890,
|
|
||||||
"nbf": 11111,
|
|
||||||
"iss": "attacker",
|
|
||||||
"aud": "victim",
|
|
||||||
"jti": "fake",
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err = auth.GenerateToken(userID, overrideClaims)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
extractedUserID, extractedClaims, err = auth.ValidateToken(token)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, userID, extractedUserID, "UserID should not be overridden")
|
|
||||||
assert.NotEqual(t, 12345, extractedClaims["iat"], "iat should not be overridden")
|
|
||||||
assert.NotEqual(t, 67890, extractedClaims["exp"], "exp should not be overridden")
|
|
||||||
assert.NotContains(t, extractedClaims, "nbf", "nbf should not be added from user claims")
|
|
||||||
assert.NotContains(t, extractedClaims, "iss", "iss should not be added from user claims")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJWTRS256(t *testing.T) {
|
func TestJWTRS256(t *testing.T) {
|
||||||
@ -82,95 +46,214 @@ func TestJWTRS256(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test with private key (can sign and verify)
|
// Test with private key (can sign and verify)
|
||||||
authPriv, err := NewAuthenticator(privateKey, "RS256")
|
jwtMgr, err := NewJWTRSA(privateKey)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
userID := "user456"
|
userID := "user456"
|
||||||
claims := map[string]any{
|
claims := map[string]any{
|
||||||
"email": "rs256@example.com",
|
|
||||||
"scope": "read:all",
|
"scope": "read:all",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate token
|
// Generate token
|
||||||
token, err := authPriv.GenerateToken(userID, claims)
|
token, err := jwtMgr.GenerateToken(userID, claims)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, token)
|
assert.NotEmpty(t, token)
|
||||||
|
|
||||||
// Validate token with private key auth (has public key too)
|
// Validate with same manager
|
||||||
extractedUserID, extractedClaims, err := authPriv.ValidateToken(token)
|
extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, userID, extractedUserID)
|
assert.Equal(t, userID, extractedUserID)
|
||||||
assert.Equal(t, "rs256@example.com", extractedClaims["email"])
|
|
||||||
assert.Equal(t, "read:all", extractedClaims["scope"])
|
assert.Equal(t, "read:all", extractedClaims["scope"])
|
||||||
|
|
||||||
// Test with public key only (can only verify)
|
// Test with verifier only (public key)
|
||||||
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
|
verifier, err := NewJWTVerifier(&privateKey.PublicKey)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Should be able to validate token
|
// Should validate token
|
||||||
extractedUserID, extractedClaims, err = authPub.ValidateToken(token)
|
extractedUserID, _, err = verifier.ValidateToken(token)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, userID, extractedUserID)
|
assert.Equal(t, userID, extractedUserID)
|
||||||
|
|
||||||
// Should not be able to generate token
|
// Should not generate token
|
||||||
_, err = authPub.GenerateToken(userID, claims)
|
_, err = verifier.GenerateToken(userID, claims)
|
||||||
assert.Error(t, err, "Public key only auth should not generate tokens")
|
|
||||||
assert.Equal(t, ErrTokenNoPrivateKey, err)
|
assert.Equal(t, ErrTokenNoPrivateKey, err)
|
||||||
|
|
||||||
// Test algorithm mismatch
|
|
||||||
authHS256, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, _, err = authHS256.ValidateToken(token)
|
|
||||||
assert.Error(t, err, "HS256 auth should not validate RS256 token")
|
|
||||||
// assert.True(t, errors.Is(err, ErrInvalidToken))
|
|
||||||
assert.True(t, errors.Is(err, ErrTokenAlgorithmMismatch))
|
|
||||||
|
|
||||||
fmt.Println(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExpiredToken(t *testing.T) {
|
func TestJWTOptions(t *testing.T) {
|
||||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||||
|
|
||||||
|
// Test custom lifetime
|
||||||
|
jwtMgr, err := NewJWT(secret,
|
||||||
|
WithTokenLifetime(1*time.Hour),
|
||||||
|
WithIssuer("test-issuer"),
|
||||||
|
WithAudience([]string{"api.example.com"}),
|
||||||
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
userID := "user123"
|
token, err := jwtMgr.GenerateToken("user1", nil)
|
||||||
|
|
||||||
// Generate normal token (should have 7 days expiry)
|
|
||||||
token, err := auth.GenerateToken(userID, nil)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, extractedClaims, err := auth.ValidateToken(token)
|
// Parse token to check claims
|
||||||
require.NoError(t, err)
|
parsed, _ := jwt.Parse(token, func(token *jwt.Token) (any, error) {
|
||||||
|
return secret, nil
|
||||||
|
})
|
||||||
|
|
||||||
// Check expiry is in future (approximately 7 days)
|
claims := parsed.Claims.(jwt.MapClaims)
|
||||||
expiry := extractedClaims["exp"].(float64)
|
|
||||||
now := time.Now().Unix()
|
|
||||||
|
|
||||||
assert.Greater(t, expiry, float64(now), "Token expiry should be in future")
|
// Check issuer
|
||||||
assert.InDelta(t, expiry, float64(now+7*24*60*60), 10,
|
assert.Equal(t, "test-issuer", claims["iss"])
|
||||||
"Token expiry should be approximately 7 days from now")
|
|
||||||
|
// Check audience
|
||||||
|
aud := claims["aud"].([]any)
|
||||||
|
assert.Contains(t, aud, "api.example.com")
|
||||||
|
|
||||||
|
// Check expiration is ~1 hour
|
||||||
|
exp := int64(claims["exp"].(float64))
|
||||||
|
iat := int64(claims["iat"].(float64))
|
||||||
|
assert.InDelta(t, 3600, exp-iat, 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCorruptJWTParts(t *testing.T) {
|
func TestJWTErrors(t *testing.T) {
|
||||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||||
|
jwtMgr, err := NewJWT(secret)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test with missing parts
|
// Empty user ID
|
||||||
_, _, err = auth.ValidateToken("only.two")
|
_, err = jwtMgr.GenerateToken("", nil)
|
||||||
assert.True(t, errors.Is(err, ErrTokenMalformed))
|
assert.Equal(t, ErrTokenEmptyUserID, err)
|
||||||
|
|
||||||
// Test with invalid header encoding
|
// Invalid token format
|
||||||
_, _, err = auth.ValidateToken("not-base64!.valid.valid")
|
_, _, err = jwtMgr.ValidateToken("invalid.token")
|
||||||
assert.Error(t, err)
|
assert.ErrorIs(t, err, ErrTokenMalformed)
|
||||||
|
|
||||||
// Test with invalid claims encoding
|
// Tampered signature
|
||||||
validHeader := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
|
token, _ := jwtMgr.GenerateToken("user1", nil)
|
||||||
_, _, err = auth.ValidateToken(validHeader + ".not-base64!.valid")
|
parts := strings.Split(token, ".")
|
||||||
assert.Error(t, err)
|
tampered := parts[0] + "." + parts[1] + ".invalidsignature"
|
||||||
|
_, _, err = jwtMgr.ValidateToken(tampered)
|
||||||
|
assert.ErrorIs(t, err, ErrTokenInvalidSignature)
|
||||||
|
|
||||||
// Test with invalid JSON in claims
|
// Wrong algorithm
|
||||||
invalidJSON := base64.RawURLEncoding.EncodeToString([]byte("{invalid json"))
|
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
_, _, err = auth.ValidateToken(validHeader + "." + invalidJSON + ".signature")
|
rsaMgr, _ := NewJWTRSA(rsaKey)
|
||||||
assert.Error(t, err)
|
rsaToken, _ := rsaMgr.GenerateToken("user1", nil)
|
||||||
|
|
||||||
|
_, _, err = jwtMgr.ValidateToken(rsaToken)
|
||||||
|
assert.ErrorIs(t, err, ErrTokenInvalidSignature)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTExpiration(t *testing.T) {
|
||||||
|
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||||
|
|
||||||
|
// Create token with 1 second lifetime
|
||||||
|
jwtMgr, err := NewJWT(secret, WithTokenLifetime(1*time.Second), WithLeeway(0))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, err := jwtMgr.GenerateToken("user1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Should be valid immediately
|
||||||
|
_, _, err = jwtMgr.ValidateToken(token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Wait for expiration
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
// Should be expired
|
||||||
|
_, _, err = jwtMgr.ValidateToken(token)
|
||||||
|
assert.ErrorIs(t, err, ErrTokenExpired)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLeeway(t *testing.T) {
|
||||||
|
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||||
|
|
||||||
|
// Create manager with no leeway
|
||||||
|
jwtMgr, err := NewJWT(secret, WithLeeway(0))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Manually create a token with NotBefore in future
|
||||||
|
now := time.Now()
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||||
|
"sub": "user1",
|
||||||
|
"nbf": now.Add(2 * time.Second).Unix(),
|
||||||
|
"exp": now.Add(1 * time.Hour).Unix(),
|
||||||
|
})
|
||||||
|
tokenString, err := token.SignedString(secret)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Should fail immediately (not valid yet)
|
||||||
|
_, _, err = jwtMgr.ValidateToken(tokenString)
|
||||||
|
assert.ErrorIs(t, err, ErrTokenNotYetValid)
|
||||||
|
|
||||||
|
// Create manager with leeway
|
||||||
|
jwtMgrWithLeeway, err := NewJWT(secret, WithLeeway(5*time.Second))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Should pass with leeway
|
||||||
|
_, _, err = jwtMgrWithLeeway.ValidateToken(tokenString)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStandaloneFunctions(t *testing.T) {
|
||||||
|
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||||
|
userID := "standalone-user"
|
||||||
|
claims := map[string]any{"test": "value"}
|
||||||
|
|
||||||
|
// Generate token
|
||||||
|
token, err := GenerateHS256Token(secret, userID, claims, 1*time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Validate token
|
||||||
|
extractedUserID, extractedClaims, err := ValidateHS256Token(secret, token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, userID, extractedUserID)
|
||||||
|
assert.Equal(t, "value", extractedClaims["test"])
|
||||||
|
|
||||||
|
// Test with short secret
|
||||||
|
_, err = GenerateHS256Token([]byte("short"), userID, claims, 1*time.Hour)
|
||||||
|
assert.Equal(t, ErrSecretTooShort, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTRSAFromPEM(t *testing.T) {
|
||||||
|
// 1. Generate a new RSA key pair for this test
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 2. Encode the private key to PEM format
|
||||||
|
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||||
|
})
|
||||||
|
|
||||||
|
// 3. Encode the public key to PEM format
|
||||||
|
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "PUBLIC KEY",
|
||||||
|
Bytes: publicKeyBytes,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 4. Test the PEM constructor for the signer
|
||||||
|
jwtMgr, err := NewJWTRSAFromPEM(privateKeyPEM)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, err := jwtMgr.GenerateToken("user-from-pem", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, token)
|
||||||
|
|
||||||
|
// 5. Test the PEM constructor for the verifier
|
||||||
|
verifier, err := NewJWTVerifierFromPEM(publicKeyPEM)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userID, _, err := verifier.ValidateToken(token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "user-from-pem", userID)
|
||||||
|
|
||||||
|
// 6. Test failure cases with invalid data
|
||||||
|
_, err = NewJWTRSAFromPEM([]byte("invalid pem data"))
|
||||||
|
assert.ErrorIs(t, err, ErrRSAInvalidPEM)
|
||||||
|
|
||||||
|
_, err = NewJWTVerifierFromPEM([]byte("invalid pem data"))
|
||||||
|
assert.ErrorIs(t, err, ErrRSAInvalidPEM)
|
||||||
}
|
}
|
||||||
4
scram.go
4
scram.go
@ -143,6 +143,10 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
|
|||||||
return nil, ErrSCRAMSaltTooShort
|
return nil, ErrSCRAMSaltTooShort
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if time == 0 || memory == 0 || threads == 0 {
|
||||||
|
return nil, ErrSCRAMZeroParams
|
||||||
|
}
|
||||||
|
|
||||||
// Derive salted password using Argon2id
|
// Derive salted password using Argon2id
|
||||||
saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, DefaultArgonKeyLen)
|
saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, DefaultArgonKeyLen)
|
||||||
|
|
||||||
|
|||||||
146
scram_test.go
Normal file
146
scram_test.go
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
// FILE: auth/scram_test.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setupScramTest is a helper to initialize a server and user credential for testing.
|
||||||
|
// It performs the full Argon2 -> SCRAM migration workflow.
|
||||||
|
func setupScramTest(t *testing.T) (server *ScramServer, username, password string, cred *Credential) {
|
||||||
|
username = "testuser"
|
||||||
|
password = "SecurePassword123"
|
||||||
|
|
||||||
|
// 1. Start with an Argon2 PHC hash, as a real application would.
|
||||||
|
phcHash, err := HashPassword(password)
|
||||||
|
require.NoError(t, err, "Setup failed: could not hash password")
|
||||||
|
|
||||||
|
// 2. Migrate the PHC hash to a SCRAM credential.
|
||||||
|
cred, err = MigrateFromPHC(username, password, phcHash)
|
||||||
|
require.NoError(t, err, "Setup failed: could not migrate from PHC hash")
|
||||||
|
|
||||||
|
// 3. Create a server and add the new credential.
|
||||||
|
server = NewScramServer()
|
||||||
|
server.AddCredential(cred)
|
||||||
|
|
||||||
|
return server, username, password, cred
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestScram_FullRoundtrip_Success simulates a complete, successful authentication handshake.
|
||||||
|
func TestScram_FullRoundtrip_Success(t *testing.T) {
|
||||||
|
server, username, password, _ := setupScramTest(t)
|
||||||
|
client := NewScramClient(username, password)
|
||||||
|
|
||||||
|
// --- Step 1: Client sends its first message ---
|
||||||
|
clientFirst, err := client.StartAuthentication()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// --- Step 2: Server receives client's message and responds ---
|
||||||
|
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||||
|
require.NoError(t, err, "Server failed to process client's first message")
|
||||||
|
|
||||||
|
// --- Step 3: Client receives server's message, computes proof ---
|
||||||
|
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
||||||
|
require.NoError(t, err, "Client failed to process server's first message")
|
||||||
|
|
||||||
|
// --- Step 4: Server receives client's proof and verifies it ---
|
||||||
|
serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
||||||
|
require.NoError(t, err, "Server failed to verify client's final proof")
|
||||||
|
assert.NotEmpty(t, serverFinal.ServerSignature, "Server signature should not be empty")
|
||||||
|
|
||||||
|
// --- Step 5: Client verifies server's signature (mutual authentication) ---
|
||||||
|
err = client.VerifyServerFinalMessage(serverFinal)
|
||||||
|
assert.NoError(t, err, "Client failed to verify server's final signature")
|
||||||
|
|
||||||
|
t.Log("SCRAM full roundtrip successful")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestScram_FullRoundtrip_WrongPassword ensures authentication fails with an incorrect password.
|
||||||
|
func TestScram_FullRoundtrip_WrongPassword(t *testing.T) {
|
||||||
|
server, username, _, _ := setupScramTest(t)
|
||||||
|
// Create a client with the WRONG password
|
||||||
|
client := NewScramClient(username, "WrongPassword!!!")
|
||||||
|
|
||||||
|
// Steps 1-3 will appear to succeed, as the client doesn't know the password is wrong yet.
|
||||||
|
clientFirst, err := client.StartAuthentication()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// --- Step 4: Server verification should fail here ---
|
||||||
|
_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should reject proof from wrong password")
|
||||||
|
|
||||||
|
t.Log("SCRAM correctly failed for wrong password")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestScram_FullRoundtrip_UserNotFound tests for user enumeration protection.
|
||||||
|
// The server should not reveal whether a user exists or not in its first message.
|
||||||
|
func TestScram_FullRoundtrip_UserNotFound(t *testing.T) {
|
||||||
|
server, _, _, _ := setupScramTest(t)
|
||||||
|
client := NewScramClient("unknown_user", "any_password")
|
||||||
|
|
||||||
|
clientFirst, err := client.StartAuthentication()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// --- Step 2: Server should return an error, but also a FAKE response ---
|
||||||
|
// This prevents an attacker from knowing if the user exists based on the response structure.
|
||||||
|
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should return an error for an unknown user")
|
||||||
|
assert.NotEmpty(t, serverFirst.FullNonce, "Server must still provide a nonce to prevent enumeration")
|
||||||
|
assert.NotEmpty(t, serverFirst.Salt, "Server must still provide a salt to prevent enumeration")
|
||||||
|
|
||||||
|
t.Log("SCRAM correctly protected against user enumeration")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestScram_InvalidNonce simulates a replay attack or message mismatch.
|
||||||
|
func TestScram_InvalidNonce(t *testing.T) {
|
||||||
|
server, username, password, _ := setupScramTest(t)
|
||||||
|
client := NewScramClient(username, password)
|
||||||
|
|
||||||
|
// Perform the first part of the handshake
|
||||||
|
clientFirst, _ := client.StartAuthentication()
|
||||||
|
serverFirst, _ := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||||
|
clientFinal, _ := client.ProcessServerFirstMessage(serverFirst)
|
||||||
|
|
||||||
|
// Attempt to finalize with a completely different nonce
|
||||||
|
_, err := server.ProcessClientFinalMessage("this-is-a-bad-nonce", clientFinal.ClientProof)
|
||||||
|
assert.ErrorIs(t, err, ErrSCRAMInvalidNonce, "Server should reject a final message with an unknown nonce")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestScram_CredentialImportExport verifies that credentials can be serialized and deserialized correctly.
|
||||||
|
func TestScram_CredentialImportExport(t *testing.T) {
|
||||||
|
_, _, _, originalCred := setupScramTest(t)
|
||||||
|
|
||||||
|
// Export the credential to a map
|
||||||
|
exportedData := originalCred.Export()
|
||||||
|
require.NotNil(t, exportedData)
|
||||||
|
|
||||||
|
// Assert that required fields exist and are strings (as they are base64 encoded)
|
||||||
|
assert.IsType(t, "", exportedData["salt"])
|
||||||
|
assert.IsType(t, "", exportedData["stored_key"])
|
||||||
|
assert.IsType(t, "", exportedData["server_key"])
|
||||||
|
|
||||||
|
// Import the credential back from the map
|
||||||
|
importedCred, err := ImportCredential(exportedData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, importedCred)
|
||||||
|
|
||||||
|
// Verify that the imported credential is identical to the original
|
||||||
|
assert.Equal(t, originalCred.Username, importedCred.Username)
|
||||||
|
assert.Equal(t, originalCred.Salt, importedCred.Salt)
|
||||||
|
assert.Equal(t, originalCred.ArgonTime, importedCred.ArgonTime)
|
||||||
|
assert.Equal(t, originalCred.ArgonMemory, importedCred.ArgonMemory)
|
||||||
|
assert.Equal(t, originalCred.ArgonThreads, importedCred.ArgonThreads)
|
||||||
|
assert.Equal(t, originalCred.StoredKey, importedCred.StoredKey)
|
||||||
|
assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey)
|
||||||
|
|
||||||
|
t.Log("SCRAM credential import/export successful")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user