v0.2.0 restructured and generalized to be more modular, added golang-jwt dependency
This commit is contained in:
408
jwt.go
408
jwt.go
@ -2,179 +2,289 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// GenerateToken creates a JWT token with user claims
|
||||
func (a *Authenticator) GenerateToken(userID string, claims map[string]any) (string, error) {
|
||||
// JWT configuration defaults
|
||||
const (
|
||||
DefaultTokenLifetime = 24 * time.Hour
|
||||
DefaultLeeway = 5 * time.Minute
|
||||
)
|
||||
|
||||
// customClaims extends RegisteredClaims with arbitrary user data
|
||||
type customClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
Extra map[string]any `json:"extra,omitempty"`
|
||||
}
|
||||
|
||||
// JWT manages token generation and validation
|
||||
type JWT struct {
|
||||
algorithm jwt.SigningMethod
|
||||
signKey any // []byte for HMAC, *rsa.PrivateKey for RSA
|
||||
verifyKey any // []byte for HMAC, *rsa.PublicKey for RSA
|
||||
tokenLifetime time.Duration
|
||||
leeway time.Duration
|
||||
issuer string
|
||||
audience []string
|
||||
}
|
||||
|
||||
// JWTOption configures JWT behavior
|
||||
type JWTOption func(*JWT)
|
||||
|
||||
// WithTokenLifetime sets token expiration duration
|
||||
func WithTokenLifetime(d time.Duration) JWTOption {
|
||||
return func(j *JWT) {
|
||||
if d > 0 {
|
||||
j.tokenLifetime = d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithLeeway sets clock skew tolerance
|
||||
func WithLeeway(d time.Duration) JWTOption {
|
||||
return func(j *JWT) {
|
||||
if d >= 0 {
|
||||
j.leeway = d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithIssuer sets token issuer claim
|
||||
func WithIssuer(iss string) JWTOption {
|
||||
return func(j *JWT) {
|
||||
j.issuer = iss
|
||||
}
|
||||
}
|
||||
|
||||
// WithAudience sets token audience claim
|
||||
func WithAudience(aud []string) JWTOption {
|
||||
return func(j *JWT) {
|
||||
j.audience = aud
|
||||
}
|
||||
}
|
||||
|
||||
// NewJWT creates JWT manager for HS256 (symmetric)
|
||||
func NewJWT(secret []byte, opts ...JWTOption) (*JWT, error) {
|
||||
if len(secret) < 32 {
|
||||
return nil, ErrSecretTooShort
|
||||
}
|
||||
|
||||
j := &JWT{
|
||||
algorithm: jwt.SigningMethodHS256,
|
||||
signKey: secret,
|
||||
verifyKey: secret,
|
||||
tokenLifetime: DefaultTokenLifetime,
|
||||
leeway: DefaultLeeway,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(j)
|
||||
}
|
||||
|
||||
return j, nil
|
||||
}
|
||||
|
||||
// NewJWTRSA creates JWT manager for RS256 (asymmetric)
|
||||
func NewJWTRSA(privateKey *rsa.PrivateKey, opts ...JWTOption) (*JWT, error) {
|
||||
if privateKey == nil {
|
||||
return nil, ErrTokenNoPrivateKey
|
||||
}
|
||||
|
||||
j := &JWT{
|
||||
algorithm: jwt.SigningMethodRS256,
|
||||
signKey: privateKey,
|
||||
verifyKey: &privateKey.PublicKey,
|
||||
tokenLifetime: DefaultTokenLifetime,
|
||||
leeway: DefaultLeeway,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(j)
|
||||
}
|
||||
|
||||
return j, nil
|
||||
}
|
||||
|
||||
// NewJWTVerifier creates JWT manager for verification only (RS256)
|
||||
func NewJWTVerifier(publicKey *rsa.PublicKey, opts ...JWTOption) (*JWT, error) {
|
||||
if publicKey == nil {
|
||||
return nil, ErrTokenNoPublicKey
|
||||
}
|
||||
|
||||
j := &JWT{
|
||||
algorithm: jwt.SigningMethodRS256,
|
||||
signKey: nil, // Cannot sign
|
||||
verifyKey: publicKey,
|
||||
tokenLifetime: DefaultTokenLifetime,
|
||||
leeway: DefaultLeeway,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(j)
|
||||
}
|
||||
|
||||
return j, nil
|
||||
}
|
||||
|
||||
// GenerateToken creates signed JWT with claims
|
||||
func (j *JWT) GenerateToken(userID string, claims map[string]any) (string, error) {
|
||||
if userID == "" {
|
||||
return "", ErrTokenEmptyUserID
|
||||
}
|
||||
|
||||
if a.algorithm == "RS256" && a.privateKey == nil {
|
||||
if j.signKey == nil {
|
||||
return "", ErrTokenNoPrivateKey
|
||||
}
|
||||
|
||||
// Build JWT claims
|
||||
now := time.Now()
|
||||
jwtClaims := map[string]any{
|
||||
"sub": userID,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(7 * 24 * time.Hour).Unix(), // 7 days expiry
|
||||
registeredClaims := jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
Issuer: j.issuer,
|
||||
Audience: j.audience,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(j.tokenLifetime)),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
}
|
||||
|
||||
// Reserved claims that cannot be overridden
|
||||
reservedClaims := map[string]bool{
|
||||
"sub": true, "iat": true, "exp": true, "nbf": true,
|
||||
"iss": true, "aud": true, "jti": true, "typ": true,
|
||||
"alg": true,
|
||||
}
|
||||
token := jwt.NewWithClaims(j.algorithm, customClaims{
|
||||
RegisteredClaims: registeredClaims,
|
||||
Extra: claims,
|
||||
})
|
||||
|
||||
// Merge custom claims
|
||||
for k, v := range claims {
|
||||
if !reservedClaims[k] {
|
||||
jwtClaims[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Create JWT header
|
||||
header := map[string]any{
|
||||
"alg": a.algorithm,
|
||||
"typ": "JWT",
|
||||
}
|
||||
|
||||
// Encode header and payload
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
claimsJSON, _ := json.Marshal(jwtClaims)
|
||||
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
// Create signature
|
||||
signingInput := headerB64 + "." + claimsB64
|
||||
var signature string
|
||||
|
||||
switch a.algorithm {
|
||||
case "HS256":
|
||||
h := hmac.New(sha256.New, a.jwtSecret)
|
||||
h.Write([]byte(signingInput))
|
||||
signature = base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
case "RS256":
|
||||
hashed := sha256.Sum256([]byte(signingInput))
|
||||
sig, err := rsa.SignPKCS1v15(rand.Reader, a.privateKey, crypto.SHA256, hashed[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
signature = base64.RawURLEncoding.EncodeToString(sig)
|
||||
}
|
||||
|
||||
// Combine to form JWT
|
||||
token := signingInput + "." + signature
|
||||
|
||||
return token, nil
|
||||
return token.SignedString(j.signKey)
|
||||
}
|
||||
|
||||
// ValidateToken verifies JWT and returns userID and claims
|
||||
func (a *Authenticator) ValidateToken(token string) (string, map[string]any, error) {
|
||||
// Split token
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
// ValidateToken verifies JWT and extracts claims
|
||||
func (j *JWT) ValidateToken(tokenString string) (string, map[string]any, error) {
|
||||
parser := jwt.NewParser(
|
||||
jwt.WithLeeway(j.leeway),
|
||||
jwt.WithAudience(j.audience...),
|
||||
jwt.WithIssuer(j.issuer),
|
||||
jwt.WithValidMethods([]string{j.algorithm.Alg()}),
|
||||
jwt.WithExpirationRequired(),
|
||||
)
|
||||
|
||||
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
|
||||
// Algorithm already validated by WithValidMethods
|
||||
return j.verifyKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", nil, mapJWTError(err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*customClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", nil, ErrTokenMalformed
|
||||
}
|
||||
|
||||
// Decode header to check algorithm
|
||||
headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return "", nil, ErrTokenInvalidHeader
|
||||
}
|
||||
|
||||
var header map[string]any
|
||||
if err = json.Unmarshal(headerJSON, &header); err != nil {
|
||||
return "", nil, ErrTokenInvalidJSON
|
||||
}
|
||||
|
||||
// Verify algorithm matches
|
||||
if alg, ok := header["alg"].(string); !ok || alg != a.algorithm {
|
||||
return "", nil, ErrTokenAlgorithmMismatch
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
signingInput := parts[0] + "." + parts[1]
|
||||
|
||||
switch a.algorithm {
|
||||
case "HS256":
|
||||
h := hmac.New(sha256.New, a.jwtSecret)
|
||||
h.Write([]byte(signingInput))
|
||||
expectedSig := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
if subtle.ConstantTimeCompare([]byte(parts[2]), []byte(expectedSig)) != 1 {
|
||||
return "", nil, ErrTokenInvalidSignature
|
||||
}
|
||||
|
||||
case "RS256":
|
||||
if a.publicKey == nil {
|
||||
return "", nil, ErrTokenNoPublicKey
|
||||
}
|
||||
|
||||
sig, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return "", nil, ErrTokenInvalidSignature
|
||||
}
|
||||
|
||||
hashed := sha256.Sum256([]byte(signingInput))
|
||||
if err := rsa.VerifyPKCS1v15(a.publicKey, crypto.SHA256, hashed[:], sig); err != nil {
|
||||
return "", nil, ErrTokenInvalidSignature
|
||||
}
|
||||
}
|
||||
|
||||
// Decode claims
|
||||
claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", nil, ErrTokenInvalidClaims
|
||||
}
|
||||
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
|
||||
return "", nil, ErrTokenInvalidJSON
|
||||
}
|
||||
|
||||
// Check expiration
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Now().Unix() > int64(exp) {
|
||||
return "", nil, ErrTokenExpired
|
||||
}
|
||||
}
|
||||
|
||||
// Check not before
|
||||
if nbf, ok := claims["nbf"].(float64); ok {
|
||||
if time.Now().Unix() < int64(nbf) {
|
||||
return "", nil, ErrTokenNotYetValid
|
||||
}
|
||||
}
|
||||
|
||||
// Extract userID
|
||||
userID, ok := claims["sub"].(string)
|
||||
if !ok {
|
||||
return "", nil, ErrTokenMissingClaim
|
||||
}
|
||||
|
||||
return userID, claims, nil
|
||||
return claims.Subject, claims.Extra, nil
|
||||
}
|
||||
|
||||
// parseRSAPrivateKey parses PEM encoded RSA private key
|
||||
// mapJWTError translates jwt library errors to auth package errors
|
||||
func mapJWTError(err error) error {
|
||||
switch {
|
||||
case errors.Is(err, jwt.ErrTokenMalformed):
|
||||
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
||||
case errors.Is(err, jwt.ErrTokenUnverifiable):
|
||||
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
||||
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
|
||||
return fmt.Errorf("%w : %w", ErrTokenInvalidSignature, err)
|
||||
case errors.Is(err, jwt.ErrTokenExpired):
|
||||
return fmt.Errorf("%w : %w", ErrTokenExpired, err)
|
||||
case errors.Is(err, jwt.ErrTokenNotValidYet):
|
||||
return fmt.Errorf("%w : %w", ErrTokenNotYetValid, err)
|
||||
case errors.Is(err, jwt.ErrTokenInvalidAudience):
|
||||
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
|
||||
case errors.Is(err, jwt.ErrTokenInvalidIssuer):
|
||||
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
|
||||
default:
|
||||
// Check for algorithm mismatch in error message
|
||||
if errors.Is(err, jwt.ErrTokenSignatureInvalid) {
|
||||
return fmt.Errorf("%w : %w", ErrTokenAlgorithmMismatch, err)
|
||||
}
|
||||
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Standalone helper functions for one-off operations
|
||||
|
||||
// GenerateHS256Token creates HS256 JWT without manager instance
|
||||
func GenerateHS256Token(secret []byte, userID string, claims map[string]any, lifetime time.Duration) (string, error) {
|
||||
if len(secret) < 32 {
|
||||
return "", ErrSecretTooShort
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, customClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(lifetime)),
|
||||
},
|
||||
Extra: claims,
|
||||
})
|
||||
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// ValidateHS256Token verifies HS256 JWT without manager instance
|
||||
func ValidateHS256Token(secret []byte, tokenString string) (string, map[string]any, error) {
|
||||
if len(secret) < 32 {
|
||||
return "", nil, ErrSecretTooShort
|
||||
}
|
||||
|
||||
parser := jwt.NewParser(
|
||||
jwt.WithValidMethods([]string{"HS256"}),
|
||||
jwt.WithLeeway(DefaultLeeway),
|
||||
)
|
||||
|
||||
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
|
||||
return secret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", nil, mapJWTError(err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*customClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", nil, ErrTokenMalformed
|
||||
}
|
||||
|
||||
return claims.Subject, claims.Extra, nil
|
||||
}
|
||||
|
||||
// RSA Utilities
|
||||
|
||||
// NewJWTRSAFromPEM creates a JWT manager for RS256 from raw PEM-encoded private key data.
|
||||
func NewJWTRSAFromPEM(privateKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
|
||||
privateKey, err := parseRSAPrivateKey(privateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Call the original constructor with the now-parsed key
|
||||
return NewJWTRSA(privateKey, opts...)
|
||||
}
|
||||
|
||||
// NewJWTVerifierFromPEM creates a JWT manager for verification from raw PEM-encoded public key data.
|
||||
func NewJWTVerifierFromPEM(publicKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
|
||||
publicKey, err := parseRSAPublicKey(publicKeyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Call the original constructor with the now-parsed key
|
||||
return NewJWTVerifier(publicKey, opts...)
|
||||
}
|
||||
|
||||
// parseRSAPrivateKey parses a PEM-encoded RSA private key.
|
||||
func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
@ -187,7 +297,7 @@ func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// parseRSAPublicKey parses PEM encoded RSA public key
|
||||
// parseRSAPublicKey parses a PEM-encoded RSA public key.
|
||||
func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) {
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
|
||||
Reference in New Issue
Block a user