v0.2.0 restructured and generalized to be more modular, added golang-jwt dependency
This commit is contained in:
283
jwt_test.go
283
jwt_test.go
@ -4,19 +4,20 @@ package auth
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJWTHS256(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
jwtMgr, err := NewJWT(secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "user123"
|
||||
@ -26,54 +27,17 @@ func TestJWTHS256(t *testing.T) {
|
||||
}
|
||||
|
||||
// Generate token
|
||||
token, err := auth.GenerateToken(userID, claims)
|
||||
require.NoError(t, err, "Failed to generate token")
|
||||
token, err := jwtMgr.GenerateToken(userID, claims)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Validate token
|
||||
extractedUserID, extractedClaims, err := auth.ValidateToken(token)
|
||||
require.NoError(t, err, "Failed to validate token")
|
||||
extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userID, extractedUserID)
|
||||
assert.Equal(t, "test@example.com", extractedClaims["email"])
|
||||
assert.Equal(t, "admin", extractedClaims["role"])
|
||||
|
||||
// Test invalid token
|
||||
_, _, err = auth.ValidateToken("invalid.token.here")
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrTokenInvalidJSON))
|
||||
|
||||
// Test tampered token
|
||||
parts := strings.Split(token, ".")
|
||||
require.Len(t, parts, 3, "JWT should have 3 parts")
|
||||
|
||||
tampered := parts[0] + "." + parts[1] + ".invalidsignature"
|
||||
_, _, err = auth.ValidateToken(tampered)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrTokenInvalidSignature))
|
||||
|
||||
// Test reserved claims cannot be overridden
|
||||
overrideClaims := map[string]any{
|
||||
"sub": "override",
|
||||
"iat": 12345,
|
||||
"exp": 67890,
|
||||
"nbf": 11111,
|
||||
"iss": "attacker",
|
||||
"aud": "victim",
|
||||
"jti": "fake",
|
||||
}
|
||||
|
||||
token, err = auth.GenerateToken(userID, overrideClaims)
|
||||
require.NoError(t, err)
|
||||
|
||||
extractedUserID, extractedClaims, err = auth.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userID, extractedUserID, "UserID should not be overridden")
|
||||
assert.NotEqual(t, 12345, extractedClaims["iat"], "iat should not be overridden")
|
||||
assert.NotEqual(t, 67890, extractedClaims["exp"], "exp should not be overridden")
|
||||
assert.NotContains(t, extractedClaims, "nbf", "nbf should not be added from user claims")
|
||||
assert.NotContains(t, extractedClaims, "iss", "iss should not be added from user claims")
|
||||
}
|
||||
|
||||
func TestJWTRS256(t *testing.T) {
|
||||
@ -82,95 +46,214 @@ func TestJWTRS256(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with private key (can sign and verify)
|
||||
authPriv, err := NewAuthenticator(privateKey, "RS256")
|
||||
jwtMgr, err := NewJWTRSA(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "user456"
|
||||
claims := map[string]any{
|
||||
"email": "rs256@example.com",
|
||||
"scope": "read:all",
|
||||
}
|
||||
|
||||
// Generate token
|
||||
token, err := authPriv.GenerateToken(userID, claims)
|
||||
token, err := jwtMgr.GenerateToken(userID, claims)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Validate token with private key auth (has public key too)
|
||||
extractedUserID, extractedClaims, err := authPriv.ValidateToken(token)
|
||||
// Validate with same manager
|
||||
extractedUserID, extractedClaims, err := jwtMgr.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userID, extractedUserID)
|
||||
assert.Equal(t, "rs256@example.com", extractedClaims["email"])
|
||||
assert.Equal(t, "read:all", extractedClaims["scope"])
|
||||
|
||||
// Test with public key only (can only verify)
|
||||
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
|
||||
// Test with verifier only (public key)
|
||||
verifier, err := NewJWTVerifier(&privateKey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be able to validate token
|
||||
extractedUserID, extractedClaims, err = authPub.ValidateToken(token)
|
||||
// Should validate token
|
||||
extractedUserID, _, err = verifier.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userID, extractedUserID)
|
||||
|
||||
// Should not be able to generate token
|
||||
_, err = authPub.GenerateToken(userID, claims)
|
||||
assert.Error(t, err, "Public key only auth should not generate tokens")
|
||||
// Should not generate token
|
||||
_, err = verifier.GenerateToken(userID, claims)
|
||||
assert.Equal(t, ErrTokenNoPrivateKey, err)
|
||||
|
||||
// Test algorithm mismatch
|
||||
authHS256, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = authHS256.ValidateToken(token)
|
||||
assert.Error(t, err, "HS256 auth should not validate RS256 token")
|
||||
// assert.True(t, errors.Is(err, ErrInvalidToken))
|
||||
assert.True(t, errors.Is(err, ErrTokenAlgorithmMismatch))
|
||||
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
func TestExpiredToken(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
func TestJWTOptions(t *testing.T) {
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
|
||||
// Test custom lifetime
|
||||
jwtMgr, err := NewJWT(secret,
|
||||
WithTokenLifetime(1*time.Hour),
|
||||
WithIssuer("test-issuer"),
|
||||
WithAudience([]string{"api.example.com"}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "user123"
|
||||
|
||||
// Generate normal token (should have 7 days expiry)
|
||||
token, err := auth.GenerateToken(userID, nil)
|
||||
token, err := jwtMgr.GenerateToken("user1", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, extractedClaims, err := auth.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
// Parse token to check claims
|
||||
parsed, _ := jwt.Parse(token, func(token *jwt.Token) (any, error) {
|
||||
return secret, nil
|
||||
})
|
||||
|
||||
// Check expiry is in future (approximately 7 days)
|
||||
expiry := extractedClaims["exp"].(float64)
|
||||
now := time.Now().Unix()
|
||||
claims := parsed.Claims.(jwt.MapClaims)
|
||||
|
||||
assert.Greater(t, expiry, float64(now), "Token expiry should be in future")
|
||||
assert.InDelta(t, expiry, float64(now+7*24*60*60), 10,
|
||||
"Token expiry should be approximately 7 days from now")
|
||||
// Check issuer
|
||||
assert.Equal(t, "test-issuer", claims["iss"])
|
||||
|
||||
// Check audience
|
||||
aud := claims["aud"].([]any)
|
||||
assert.Contains(t, aud, "api.example.com")
|
||||
|
||||
// Check expiration is ~1 hour
|
||||
exp := int64(claims["exp"].(float64))
|
||||
iat := int64(claims["iat"].(float64))
|
||||
assert.InDelta(t, 3600, exp-iat, 10)
|
||||
}
|
||||
|
||||
func TestCorruptJWTParts(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
func TestJWTErrors(t *testing.T) {
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
jwtMgr, err := NewJWT(secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with missing parts
|
||||
_, _, err = auth.ValidateToken("only.two")
|
||||
assert.True(t, errors.Is(err, ErrTokenMalformed))
|
||||
// Empty user ID
|
||||
_, err = jwtMgr.GenerateToken("", nil)
|
||||
assert.Equal(t, ErrTokenEmptyUserID, err)
|
||||
|
||||
// Test with invalid header encoding
|
||||
_, _, err = auth.ValidateToken("not-base64!.valid.valid")
|
||||
assert.Error(t, err)
|
||||
// Invalid token format
|
||||
_, _, err = jwtMgr.ValidateToken("invalid.token")
|
||||
assert.ErrorIs(t, err, ErrTokenMalformed)
|
||||
|
||||
// Test with invalid claims encoding
|
||||
validHeader := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
_, _, err = auth.ValidateToken(validHeader + ".not-base64!.valid")
|
||||
assert.Error(t, err)
|
||||
// Tampered signature
|
||||
token, _ := jwtMgr.GenerateToken("user1", nil)
|
||||
parts := strings.Split(token, ".")
|
||||
tampered := parts[0] + "." + parts[1] + ".invalidsignature"
|
||||
_, _, err = jwtMgr.ValidateToken(tampered)
|
||||
assert.ErrorIs(t, err, ErrTokenInvalidSignature)
|
||||
|
||||
// Test with invalid JSON in claims
|
||||
invalidJSON := base64.RawURLEncoding.EncodeToString([]byte("{invalid json"))
|
||||
_, _, err = auth.ValidateToken(validHeader + "." + invalidJSON + ".signature")
|
||||
assert.Error(t, err)
|
||||
// Wrong algorithm
|
||||
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
rsaMgr, _ := NewJWTRSA(rsaKey)
|
||||
rsaToken, _ := rsaMgr.GenerateToken("user1", nil)
|
||||
|
||||
_, _, err = jwtMgr.ValidateToken(rsaToken)
|
||||
assert.ErrorIs(t, err, ErrTokenInvalidSignature)
|
||||
}
|
||||
|
||||
func TestJWTExpiration(t *testing.T) {
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
|
||||
// Create token with 1 second lifetime
|
||||
jwtMgr, err := NewJWT(secret, WithTokenLifetime(1*time.Second), WithLeeway(0))
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := jwtMgr.GenerateToken("user1", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be valid immediately
|
||||
_, _, err = jwtMgr.ValidateToken(token)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Should be expired
|
||||
_, _, err = jwtMgr.ValidateToken(token)
|
||||
assert.ErrorIs(t, err, ErrTokenExpired)
|
||||
}
|
||||
|
||||
func TestLeeway(t *testing.T) {
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
|
||||
// Create manager with no leeway
|
||||
jwtMgr, err := NewJWT(secret, WithLeeway(0))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually create a token with NotBefore in future
|
||||
now := time.Now()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": "user1",
|
||||
"nbf": now.Add(2 * time.Second).Unix(),
|
||||
"exp": now.Add(1 * time.Hour).Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString(secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should fail immediately (not valid yet)
|
||||
_, _, err = jwtMgr.ValidateToken(tokenString)
|
||||
assert.ErrorIs(t, err, ErrTokenNotYetValid)
|
||||
|
||||
// Create manager with leeway
|
||||
jwtMgrWithLeeway, err := NewJWT(secret, WithLeeway(5*time.Second))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should pass with leeway
|
||||
_, _, err = jwtMgrWithLeeway.ValidateToken(tokenString)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStandaloneFunctions(t *testing.T) {
|
||||
secret := []byte("test-secret-key-must-be-32-bytes")
|
||||
userID := "standalone-user"
|
||||
claims := map[string]any{"test": "value"}
|
||||
|
||||
// Generate token
|
||||
token, err := GenerateHS256Token(secret, userID, claims, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate token
|
||||
extractedUserID, extractedClaims, err := ValidateHS256Token(secret, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userID, extractedUserID)
|
||||
assert.Equal(t, "value", extractedClaims["test"])
|
||||
|
||||
// Test with short secret
|
||||
_, err = GenerateHS256Token([]byte("short"), userID, claims, 1*time.Hour)
|
||||
assert.Equal(t, ErrSecretTooShort, err)
|
||||
}
|
||||
|
||||
func TestJWTRSAFromPEM(t *testing.T) {
|
||||
// 1. Generate a new RSA key pair for this test
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Encode the private key to PEM format
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
|
||||
// 3. Encode the public key to PEM format
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: publicKeyBytes,
|
||||
})
|
||||
|
||||
// 4. Test the PEM constructor for the signer
|
||||
jwtMgr, err := NewJWTRSAFromPEM(privateKeyPEM)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := jwtMgr.GenerateToken("user-from-pem", nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// 5. Test the PEM constructor for the verifier
|
||||
verifier, err := NewJWTVerifierFromPEM(publicKeyPEM)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID, _, err := verifier.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "user-from-pem", userID)
|
||||
|
||||
// 6. Test failure cases with invalid data
|
||||
_, err = NewJWTRSAFromPEM([]byte("invalid pem data"))
|
||||
assert.ErrorIs(t, err, ErrRSAInvalidPEM)
|
||||
|
||||
_, err = NewJWTVerifierFromPEM([]byte("invalid pem data"))
|
||||
assert.ErrorIs(t, err, ErrRSAInvalidPEM)
|
||||
}
|
||||
Reference in New Issue
Block a user