315 lines
7.9 KiB
Go
315 lines
7.9 KiB
Go
// FILE: auth/jwt.go
|
|
package auth
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// JWT configuration defaults
|
|
const (
|
|
DefaultTokenLifetime = 24 * time.Hour
|
|
DefaultLeeway = 5 * time.Minute
|
|
)
|
|
|
|
// customClaims extends RegisteredClaims with arbitrary user data
|
|
type customClaims struct {
|
|
jwt.RegisteredClaims
|
|
Extra map[string]any `json:"extra,omitempty"`
|
|
}
|
|
|
|
// JWT manages token generation and validation
|
|
type JWT struct {
|
|
algorithm jwt.SigningMethod
|
|
signKey any // []byte for HMAC, *rsa.PrivateKey for RSA
|
|
verifyKey any // []byte for HMAC, *rsa.PublicKey for RSA
|
|
tokenLifetime time.Duration
|
|
leeway time.Duration
|
|
issuer string
|
|
audience []string
|
|
}
|
|
|
|
// JWTOption configures JWT behavior
|
|
type JWTOption func(*JWT)
|
|
|
|
// WithTokenLifetime sets token expiration duration
|
|
func WithTokenLifetime(d time.Duration) JWTOption {
|
|
return func(j *JWT) {
|
|
if d > 0 {
|
|
j.tokenLifetime = d
|
|
}
|
|
}
|
|
}
|
|
|
|
// WithLeeway sets clock skew tolerance
|
|
func WithLeeway(d time.Duration) JWTOption {
|
|
return func(j *JWT) {
|
|
if d >= 0 {
|
|
j.leeway = d
|
|
}
|
|
}
|
|
}
|
|
|
|
// WithIssuer sets token issuer claim
|
|
func WithIssuer(iss string) JWTOption {
|
|
return func(j *JWT) {
|
|
j.issuer = iss
|
|
}
|
|
}
|
|
|
|
// WithAudience sets token audience claim
|
|
func WithAudience(aud []string) JWTOption {
|
|
return func(j *JWT) {
|
|
j.audience = aud
|
|
}
|
|
}
|
|
|
|
// NewJWT creates JWT manager for HS256 (symmetric)
|
|
func NewJWT(secret []byte, opts ...JWTOption) (*JWT, error) {
|
|
if len(secret) < 32 {
|
|
return nil, ErrSecretTooShort
|
|
}
|
|
|
|
j := &JWT{
|
|
algorithm: jwt.SigningMethodHS256,
|
|
signKey: secret,
|
|
verifyKey: secret,
|
|
tokenLifetime: DefaultTokenLifetime,
|
|
leeway: DefaultLeeway,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(j)
|
|
}
|
|
|
|
return j, nil
|
|
}
|
|
|
|
// NewJWTRSA creates JWT manager for RS256 (asymmetric)
|
|
func NewJWTRSA(privateKey *rsa.PrivateKey, opts ...JWTOption) (*JWT, error) {
|
|
if privateKey == nil {
|
|
return nil, ErrTokenNoPrivateKey
|
|
}
|
|
|
|
j := &JWT{
|
|
algorithm: jwt.SigningMethodRS256,
|
|
signKey: privateKey,
|
|
verifyKey: &privateKey.PublicKey,
|
|
tokenLifetime: DefaultTokenLifetime,
|
|
leeway: DefaultLeeway,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(j)
|
|
}
|
|
|
|
return j, nil
|
|
}
|
|
|
|
// NewJWTVerifier creates JWT manager for verification only (RS256)
|
|
func NewJWTVerifier(publicKey *rsa.PublicKey, opts ...JWTOption) (*JWT, error) {
|
|
if publicKey == nil {
|
|
return nil, ErrTokenNoPublicKey
|
|
}
|
|
|
|
j := &JWT{
|
|
algorithm: jwt.SigningMethodRS256,
|
|
signKey: nil, // Cannot sign
|
|
verifyKey: publicKey,
|
|
tokenLifetime: DefaultTokenLifetime,
|
|
leeway: DefaultLeeway,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(j)
|
|
}
|
|
|
|
return j, nil
|
|
}
|
|
|
|
// GenerateToken creates signed JWT with claims
|
|
func (j *JWT) GenerateToken(userID string, claims map[string]any) (string, error) {
|
|
if userID == "" {
|
|
return "", ErrTokenEmptyUserID
|
|
}
|
|
|
|
if j.signKey == nil {
|
|
return "", ErrTokenNoPrivateKey
|
|
}
|
|
|
|
now := time.Now()
|
|
registeredClaims := jwt.RegisteredClaims{
|
|
Subject: userID,
|
|
Issuer: j.issuer,
|
|
Audience: j.audience,
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(j.tokenLifetime)),
|
|
NotBefore: jwt.NewNumericDate(now),
|
|
}
|
|
|
|
token := jwt.NewWithClaims(j.algorithm, customClaims{
|
|
RegisteredClaims: registeredClaims,
|
|
Extra: claims,
|
|
})
|
|
|
|
return token.SignedString(j.signKey)
|
|
}
|
|
|
|
// ValidateToken verifies JWT and extracts claims
|
|
func (j *JWT) ValidateToken(tokenString string) (string, map[string]any, error) {
|
|
parser := jwt.NewParser(
|
|
jwt.WithLeeway(j.leeway),
|
|
jwt.WithAudience(j.audience...),
|
|
jwt.WithIssuer(j.issuer),
|
|
jwt.WithValidMethods([]string{j.algorithm.Alg()}),
|
|
jwt.WithExpirationRequired(),
|
|
)
|
|
|
|
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
|
|
// Algorithm already validated by WithValidMethods
|
|
return j.verifyKey, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return "", nil, mapJWTError(err)
|
|
}
|
|
|
|
claims, ok := token.Claims.(*customClaims)
|
|
if !ok || !token.Valid {
|
|
return "", nil, ErrTokenMalformed
|
|
}
|
|
|
|
return claims.Subject, claims.Extra, nil
|
|
}
|
|
|
|
// mapJWTError translates jwt library errors to auth package errors
|
|
func mapJWTError(err error) error {
|
|
switch {
|
|
case errors.Is(err, jwt.ErrTokenMalformed):
|
|
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
|
case errors.Is(err, jwt.ErrTokenUnverifiable):
|
|
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
|
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
|
|
return fmt.Errorf("%w : %w", ErrTokenInvalidSignature, err)
|
|
case errors.Is(err, jwt.ErrTokenExpired):
|
|
return fmt.Errorf("%w : %w", ErrTokenExpired, err)
|
|
case errors.Is(err, jwt.ErrTokenNotValidYet):
|
|
return fmt.Errorf("%w : %w", ErrTokenNotYetValid, err)
|
|
case errors.Is(err, jwt.ErrTokenInvalidAudience):
|
|
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
|
|
case errors.Is(err, jwt.ErrTokenInvalidIssuer):
|
|
return fmt.Errorf("%w : %w", ErrTokenMissingClaim, err)
|
|
default:
|
|
// Check for algorithm mismatch in error message
|
|
if errors.Is(err, jwt.ErrTokenSignatureInvalid) {
|
|
return fmt.Errorf("%w : %w", ErrTokenAlgorithmMismatch, err)
|
|
}
|
|
return fmt.Errorf("%w : %w", ErrTokenMalformed, err)
|
|
}
|
|
}
|
|
|
|
// Standalone helper functions for one-off operations
|
|
|
|
// GenerateHS256Token creates HS256 JWT without manager instance
|
|
func GenerateHS256Token(secret []byte, userID string, claims map[string]any, lifetime time.Duration) (string, error) {
|
|
if len(secret) < 32 {
|
|
return "", ErrSecretTooShort
|
|
}
|
|
|
|
now := time.Now()
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, customClaims{
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: userID,
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(lifetime)),
|
|
},
|
|
Extra: claims,
|
|
})
|
|
|
|
return token.SignedString(secret)
|
|
}
|
|
|
|
// ValidateHS256Token verifies HS256 JWT without manager instance
|
|
func ValidateHS256Token(secret []byte, tokenString string) (string, map[string]any, error) {
|
|
if len(secret) < 32 {
|
|
return "", nil, ErrSecretTooShort
|
|
}
|
|
|
|
parser := jwt.NewParser(
|
|
jwt.WithValidMethods([]string{"HS256"}),
|
|
jwt.WithLeeway(DefaultLeeway),
|
|
)
|
|
|
|
token, err := parser.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (any, error) {
|
|
return secret, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return "", nil, mapJWTError(err)
|
|
}
|
|
|
|
claims, ok := token.Claims.(*customClaims)
|
|
if !ok || !token.Valid {
|
|
return "", nil, ErrTokenMalformed
|
|
}
|
|
|
|
return claims.Subject, claims.Extra, nil
|
|
}
|
|
|
|
// RSA Utilities
|
|
|
|
// NewJWTRSAFromPEM creates a JWT manager for RS256 from raw PEM-encoded private key data.
|
|
func NewJWTRSAFromPEM(privateKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
|
|
privateKey, err := parseRSAPrivateKey(privateKeyPEM)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Call the original constructor with the now-parsed key
|
|
return NewJWTRSA(privateKey, opts...)
|
|
}
|
|
|
|
// NewJWTVerifierFromPEM creates a JWT manager for verification from raw PEM-encoded public key data.
|
|
func NewJWTVerifierFromPEM(publicKeyPEM []byte, opts ...JWTOption) (*JWT, error) {
|
|
publicKey, err := parseRSAPublicKey(publicKeyPEM)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Call the original constructor with the now-parsed key
|
|
return NewJWTVerifier(publicKey, opts...)
|
|
}
|
|
|
|
// parseRSAPrivateKey parses a PEM-encoded RSA private key.
|
|
func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
|
|
block, _ := pem.Decode(pemBytes)
|
|
if block == nil {
|
|
return nil, ErrRSAInvalidPEM
|
|
}
|
|
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
return nil, ErrRSAInvalidPrivateKey
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
// parseRSAPublicKey parses a PEM-encoded RSA public key.
|
|
func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) {
|
|
block, _ := pem.Decode(pemBytes)
|
|
if block == nil {
|
|
return nil, ErrRSAInvalidPEM
|
|
}
|
|
pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
|
|
if err != nil {
|
|
return nil, ErrRSAInvalidPublicKey
|
|
}
|
|
pubKey, ok := pubInterface.(*rsa.PublicKey)
|
|
if !ok {
|
|
return nil, ErrRSANotPublicKey
|
|
}
|
|
return pubKey, nil
|
|
} |