v0.1.0 initial commit, auth features exatracted from logwisp to be a standalone utility package
This commit is contained in:
12
.gitignore
vendored
Normal file
12
.gitignore
vendored
Normal file
@ -0,0 +1,12 @@
|
||||
.idea
|
||||
data
|
||||
dev
|
||||
log
|
||||
logs
|
||||
cert
|
||||
bin
|
||||
script
|
||||
build
|
||||
*.log
|
||||
*.toml
|
||||
build.sh
|
||||
28
LICENSE
Normal file
28
LICENSE
Normal file
@ -0,0 +1,28 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2025, Lixen Wraith
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
38
README.md
Normal file
38
README.md
Normal file
@ -0,0 +1,38 @@
|
||||
# Auth Package
|
||||
|
||||
Pluggable authentication utilities for Go applications.
|
||||
|
||||
## Features
|
||||
|
||||
- **Password Hashing**: Argon2id with PHC format
|
||||
- **JWT**: HS256/RS256 token generation and validation
|
||||
- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF
|
||||
- **HTTP Auth**: Basic/Bearer header parsing
|
||||
|
||||
## Usage
|
||||
```go
|
||||
// JWT with HS256
|
||||
auth, _ := auth.NewAuthenticator([]byte("32-byte-secret-key..."))
|
||||
token, _ := auth.GenerateToken("user123", map[string]interface{}{"role": "admin"})
|
||||
userID, claims, _ := auth.ValidateToken(token)
|
||||
|
||||
// SCRAM authentication
|
||||
server := auth.NewScramServer()
|
||||
cred, _ := auth.DeriveCredential("user", "password", salt, 1, 65536, 4)
|
||||
server.AddCredential(cred)
|
||||
```
|
||||
|
||||
## Package Structure
|
||||
|
||||
- `interfaces.go` - Core interfaces
|
||||
- `jwt.go` - JWT token operations
|
||||
- `argon2.go` - Password hashing
|
||||
- `scram.go` - SCRAM-SHA256 protocol
|
||||
- `token.go` - Token validation utilities
|
||||
- `http.go` - HTTP header parsing
|
||||
- `errors.go` - Error definitions
|
||||
|
||||
## Testing
|
||||
```bash
|
||||
go test -v ./auth
|
||||
```
|
||||
110
argon2.go
Normal file
110
argon2.go
Normal file
@ -0,0 +1,110 @@
|
||||
// FILE: auth/argon2.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
// Default Argon2id parameters
|
||||
const (
|
||||
DefaultArgonTime = 3 // iterations (reduce for faster but less secure auth)
|
||||
DefaultArgonMemory = 64 * 1024 // 64 MB
|
||||
DefaultArgonThreads = 4
|
||||
DefaultArgonSaltLen = 16
|
||||
DefaultArgonKeyLen = 32
|
||||
)
|
||||
|
||||
// HashPassword creates an Argon2id PHC-format hash
|
||||
func (a *Authenticator) HashPassword(password string) (string, error) {
|
||||
if len(password) < 8 {
|
||||
return "", ErrWeakPassword
|
||||
}
|
||||
|
||||
// Generate salt
|
||||
salt := make([]byte, DefaultArgonSaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", fmt.Errorf("%w: %v", ErrSaltGenerationFailed, err)
|
||||
}
|
||||
|
||||
// Derive key using Argon2id
|
||||
hash := argon2.IDKey([]byte(password), salt, a.argonTime, a.argonMemory, a.argonThreads, DefaultArgonKeyLen)
|
||||
|
||||
// Construct PHC format
|
||||
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
||||
hashB64 := base64.RawStdEncoding.EncodeToString(hash)
|
||||
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, a.argonMemory, a.argonTime, a.argonThreads, saltB64, hashB64)
|
||||
|
||||
return phcHash, nil
|
||||
}
|
||||
|
||||
// VerifyPassword checks password against PHC-format hash
|
||||
func (a *Authenticator) VerifyPassword(password, phcHash string) error {
|
||||
// Parse PHC format
|
||||
parts := strings.Split(phcHash, "$")
|
||||
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||
return ErrPHCInvalidFormat
|
||||
}
|
||||
|
||||
var memory, time uint32
|
||||
var threads uint8
|
||||
fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrPHCInvalidSalt, err)
|
||||
}
|
||||
|
||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err)
|
||||
}
|
||||
|
||||
// Compute hash with same parameters
|
||||
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
|
||||
|
||||
// Constant-time comparison
|
||||
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateFromPHC converts existing Argon2 PHC hash to SCRAM credential
|
||||
func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
|
||||
// Parse PHC format
|
||||
parts := strings.Split(phcHash, "$")
|
||||
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||
return nil, ErrPHCInvalidFormat
|
||||
}
|
||||
|
||||
var memory, time uint32
|
||||
var threads uint8
|
||||
fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return nil, ErrPHCInvalidSalt
|
||||
}
|
||||
|
||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return nil, ErrPHCInvalidHash
|
||||
}
|
||||
|
||||
// Verify password against hash
|
||||
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
|
||||
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// Derive SCRAM credential with same parameters
|
||||
return DeriveCredential(username, password, salt, time, memory, threads)
|
||||
}
|
||||
117
argon2_test.go
Normal file
117
argon2_test.go
Normal file
@ -0,0 +1,117 @@
|
||||
// FILE: auth/argon2_test.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPasswordHashing(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err, "Failed to create authenticator")
|
||||
|
||||
password := "testPassword123"
|
||||
|
||||
// Test hashing
|
||||
hash, err := auth.HashPassword(password)
|
||||
require.NoError(t, err, "Failed to hash password")
|
||||
|
||||
// Verify PHC format
|
||||
assert.True(t, strings.HasPrefix(hash, "$argon2id$"),
|
||||
"Hash should have argon2id prefix, got: %s", hash)
|
||||
|
||||
// Test verification with correct password
|
||||
err = auth.VerifyPassword(password, hash)
|
||||
assert.NoError(t, err, "Failed to verify correct password")
|
||||
|
||||
// Test verification with incorrect password
|
||||
err = auth.VerifyPassword("wrongPassword", hash)
|
||||
assert.Error(t, err, "Verification should fail for incorrect password")
|
||||
assert.Equal(t, ErrInvalidCredentials, err)
|
||||
|
||||
// Test weak password
|
||||
_, err = auth.HashPassword("weak")
|
||||
assert.Equal(t, ErrWeakPassword, err, "Should reject weak password")
|
||||
|
||||
// Test malformed PHC hash
|
||||
err = auth.VerifyPassword(password, "$invalid$format")
|
||||
assert.Error(t, err, "Should reject malformed hash")
|
||||
|
||||
// Test corrupted salt
|
||||
corruptedHash := strings.Replace(hash, "$argon2id$", "$argon2id$", 1)
|
||||
parts := strings.Split(corruptedHash, "$")
|
||||
if len(parts) == 6 {
|
||||
parts[4] = "invalid!base64"
|
||||
corruptedHash = strings.Join(parts, "$")
|
||||
err = auth.VerifyPassword(password, corruptedHash)
|
||||
assert.Error(t, err, "Should reject corrupted salt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyPasswordAfterValidation(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Empty password should be rejected by length check
|
||||
_, err = auth.HashPassword("")
|
||||
assert.Equal(t, ErrWeakPassword, err)
|
||||
|
||||
// 8-character password should pass
|
||||
hash, err := auth.HashPassword("12345678")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = auth.VerifyPassword("12345678", hash)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestConcurrentPasswordOperations(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
password := "testPassword123"
|
||||
hash, err := auth.HashPassword(password)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test concurrent verification
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := auth.VerifyPassword(password, hash)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPHCMigration(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
password := "testPassword123"
|
||||
username := "migrationUser"
|
||||
|
||||
// Generate PHC hash
|
||||
phcHash, err := auth.HashPassword(password)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Migrate to SCRAM credential
|
||||
cred, err := MigrateFromPHC(username, password, phcHash)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, username, cred.Username)
|
||||
assert.NotNil(t, cred.StoredKey)
|
||||
assert.NotNil(t, cred.ServerKey)
|
||||
|
||||
// Test with wrong password
|
||||
_, err = MigrateFromPHC(username, "wrongPassword", phcHash)
|
||||
assert.Equal(t, ErrInvalidCredentials, err)
|
||||
|
||||
// Test with invalid PHC format
|
||||
_, err = MigrateFromPHC(username, password, "$invalid$format")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
71
auth.go
Normal file
71
auth.go
Normal file
@ -0,0 +1,71 @@
|
||||
// FILE: auth/auth.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Authenticator provides password hashing and JWT operations
|
||||
type Authenticator struct {
|
||||
algorithm string
|
||||
jwtSecret []byte // For HS256
|
||||
privateKey *rsa.PrivateKey // For RS256
|
||||
publicKey *rsa.PublicKey // For RS256
|
||||
argonTime uint32
|
||||
argonMemory uint32
|
||||
argonThreads uint8
|
||||
}
|
||||
|
||||
// NewAuthenticator creates a new authenticator with specified algorithm
|
||||
func NewAuthenticator(key any, algorithm ...string) (*Authenticator, error) {
|
||||
alg := "HS256"
|
||||
if len(algorithm) > 0 && algorithm[0] != "" {
|
||||
alg = algorithm[0]
|
||||
}
|
||||
|
||||
auth := &Authenticator{
|
||||
algorithm: alg,
|
||||
argonTime: DefaultArgonTime,
|
||||
argonMemory: DefaultArgonMemory,
|
||||
argonThreads: DefaultArgonThreads,
|
||||
}
|
||||
|
||||
switch alg {
|
||||
case "HS256":
|
||||
secret, ok := key.([]byte)
|
||||
if !ok {
|
||||
return nil, ErrInvalidKeyType
|
||||
}
|
||||
if len(secret) < 32 {
|
||||
return nil, ErrSecretTooShort
|
||||
}
|
||||
auth.jwtSecret = secret
|
||||
|
||||
case "RS256":
|
||||
switch k := key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
auth.privateKey = k
|
||||
auth.publicKey = &k.PublicKey
|
||||
case *rsa.PublicKey:
|
||||
auth.publicKey = k
|
||||
case []byte:
|
||||
// Try parsing as PEM
|
||||
if privKey, err := parseRSAPrivateKey(k); err == nil {
|
||||
auth.privateKey = privKey
|
||||
auth.publicKey = &privKey.PublicKey
|
||||
} else if pubKey, err := parseRSAPublicKey(k); err == nil {
|
||||
auth.publicKey = pubKey
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to parse RSA key: %w", err)
|
||||
}
|
||||
default:
|
||||
return nil, ErrInvalidKeyType
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, ErrInvalidAlgorithm
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
68
auth_test.go
Normal file
68
auth_test.go
Normal file
@ -0,0 +1,68 @@
|
||||
// FILE: auth/auth_test.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewAuthenticator(t *testing.T) {
|
||||
// Test HS256 creation
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err, "Failed to create HS256 authenticator")
|
||||
assert.Equal(t, "HS256", auth.algorithm)
|
||||
|
||||
// Test with short secret
|
||||
_, err = NewAuthenticator([]byte("short"))
|
||||
assert.Equal(t, ErrSecretTooShort, err)
|
||||
|
||||
// Test RS256 with private key
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
authRS, err := NewAuthenticator(privateKey, "RS256")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RS256", authRS.algorithm)
|
||||
assert.NotNil(t, authRS.privateKey)
|
||||
assert.NotNil(t, authRS.publicKey)
|
||||
|
||||
// Test RS256 with public key only
|
||||
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RS256", authPub.algorithm)
|
||||
assert.Nil(t, authPub.privateKey)
|
||||
assert.NotNil(t, authPub.publicKey)
|
||||
|
||||
// Test invalid algorithm
|
||||
_, err = NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"), "INVALID")
|
||||
assert.Equal(t, ErrInvalidAlgorithm, err)
|
||||
|
||||
// Test invalid key type for HS256
|
||||
_, err = NewAuthenticator(privateKey, "HS256")
|
||||
assert.Equal(t, ErrInvalidKeyType, err)
|
||||
}
|
||||
|
||||
func TestInterfaceCompliance(t *testing.T) {
|
||||
// Verify Authenticator implements AuthenticatorInterface
|
||||
auth, _ := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
|
||||
var _ AuthenticatorInterface = auth
|
||||
|
||||
// Test interface methods work
|
||||
hash, err := auth.HashPassword("testpass123")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = auth.VerifyPassword("testpass123", hash)
|
||||
assert.NoError(t, err)
|
||||
|
||||
token, err := auth.GenerateToken("user1", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID, _, err := auth.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "user1", userID)
|
||||
}
|
||||
100
error.go
Normal file
100
error.go
Normal file
@ -0,0 +1,100 @@
|
||||
// FILE: auth/errors.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Base authentication errors
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
ErrWeakPassword = errors.New("password must be at least 8 characters")
|
||||
ErrInvalidAlgorithm = errors.New("invalid algorithm")
|
||||
ErrInvalidKeyType = errors.New("invalid key type for algorithm")
|
||||
)
|
||||
|
||||
// JWT-specific errors
|
||||
var (
|
||||
ErrTokenMalformed = errors.New("token: malformed structure")
|
||||
ErrTokenExpired = errors.New("token: expired")
|
||||
ErrTokenNotYetValid = errors.New("token: not yet valid")
|
||||
ErrTokenInvalidSignature = errors.New("token: invalid signature")
|
||||
ErrTokenAlgorithmMismatch = errors.New("token: algorithm mismatch")
|
||||
ErrTokenMissingClaim = errors.New("token: missing required claim")
|
||||
ErrTokenInvalidHeader = errors.New("token: invalid header encoding")
|
||||
ErrTokenInvalidClaims = errors.New("token: invalid claims encoding")
|
||||
ErrTokenInvalidJSON = errors.New("token: malformed JSON")
|
||||
ErrTokenEmptyUserID = errors.New("token: empty user ID")
|
||||
ErrTokenNoPrivateKey = errors.New("token: private key required for signing")
|
||||
ErrTokenNoPublicKey = errors.New("token: public key required for verification")
|
||||
)
|
||||
|
||||
// JWT secret errors
|
||||
var (
|
||||
ErrSecretTooShort = errors.New("JWT secret must be at least 32 bytes")
|
||||
)
|
||||
|
||||
// RSA key parsing errors
|
||||
var (
|
||||
ErrRSAInvalidPEM = errors.New("rsa: failed to parse PEM block")
|
||||
ErrRSAInvalidPrivateKey = errors.New("rsa: invalid private key format")
|
||||
ErrRSAInvalidPublicKey = errors.New("rsa: invalid public key format")
|
||||
ErrRSANotPublicKey = errors.New("rsa: not an RSA public key")
|
||||
)
|
||||
|
||||
// PHC format errors
|
||||
var (
|
||||
ErrPHCInvalidFormat = errors.New("phc: invalid format")
|
||||
ErrPHCInvalidSalt = errors.New("phc: invalid salt encoding")
|
||||
ErrPHCInvalidHash = errors.New("phc: invalid hash encoding")
|
||||
)
|
||||
|
||||
// SCRAM-specific errors
|
||||
var (
|
||||
ErrSCRAMInvalidNonce = errors.New("scram: invalid nonce or expired handshake")
|
||||
ErrSCRAMTimeout = errors.New("scram: handshake timeout")
|
||||
ErrSCRAMVerifyInProgress = errors.New("scram: verification already in progress")
|
||||
ErrSCRAMInvalidProof = errors.New("scram: invalid proof encoding")
|
||||
ErrSCRAMInvalidProofLen = errors.New("scram: invalid proof length")
|
||||
ErrSCRAMServerAuthFailed = errors.New("scram: server authentication failed")
|
||||
ErrSCRAMInvalidState = errors.New("scram: invalid handshake state")
|
||||
ErrSCRAMInvalidSalt = errors.New("scram: invalid salt encoding")
|
||||
ErrSCRAMZeroParams = errors.New("scram: invalid Argon2 parameters")
|
||||
ErrSCRAMSaltTooShort = errors.New("scram: salt must be at least 16 bytes")
|
||||
ErrSCRAMNonceGenFailed = errors.New("scram: failed to generate nonce")
|
||||
)
|
||||
|
||||
// Credential import/export errors
|
||||
var (
|
||||
ErrCredMissingUsername = errors.New("credential: missing username")
|
||||
ErrCredMissingSalt = errors.New("credential: missing salt")
|
||||
ErrCredInvalidSalt = errors.New("credential: invalid salt encoding")
|
||||
ErrCredMissingTime = errors.New("credential: missing argon_time")
|
||||
ErrCredMissingMemory = errors.New("credential: missing argon_memory")
|
||||
ErrCredMissingThreads = errors.New("credential: missing argon_threads")
|
||||
ErrCredMissingStoredKey = errors.New("credential: missing stored_key")
|
||||
ErrCredInvalidStoredKey = errors.New("credential: invalid stored_key encoding")
|
||||
ErrCredMissingServerKey = errors.New("credential: missing server_key")
|
||||
ErrCredInvalidServerKey = errors.New("credential: invalid server_key encoding")
|
||||
ErrCredInvalidType = fmt.Errorf("credential: invalid type for field")
|
||||
)
|
||||
|
||||
// HTTP auth parsing errors
|
||||
var (
|
||||
ErrAuthInvalidBasicFormat = errors.New("auth: invalid Basic auth format")
|
||||
ErrAuthInvalidBasicEncoding = errors.New("auth: invalid Basic auth base64 encoding")
|
||||
ErrAuthInvalidBasicCreds = errors.New("auth: invalid Basic auth credentials format")
|
||||
ErrAuthInvalidBearerFormat = errors.New("auth: invalid Bearer auth format")
|
||||
ErrAuthEmptyBearerToken = errors.New("auth: empty Bearer token")
|
||||
)
|
||||
|
||||
// Salt generation errors
|
||||
var (
|
||||
ErrSaltGenerationFailed = errors.New("failed to generate salt")
|
||||
)
|
||||
|
||||
// Key generation errors
|
||||
var (
|
||||
ErrRSAKeyGenFailed = errors.New("failed to generate RSA key")
|
||||
)
|
||||
15
go.mod
Normal file
15
go.mod
Normal file
@ -0,0 +1,15 @@
|
||||
module auth
|
||||
|
||||
go 1.25.3
|
||||
|
||||
require (
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/crypto v0.43.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
14
go.sum
Normal file
14
go.sum
Normal file
@ -0,0 +1,14 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
61
http.go
Normal file
61
http.go
Normal file
@ -0,0 +1,61 @@
|
||||
// FILE: auth/http.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseBasicAuth extracts username/password from Basic auth header
|
||||
func ParseBasicAuth(header string) (username, password string, err error) {
|
||||
const prefix = "Basic "
|
||||
if !strings.HasPrefix(header, prefix) {
|
||||
return "", "", ErrAuthInvalidBasicFormat
|
||||
}
|
||||
|
||||
encoded := strings.TrimPrefix(header, prefix)
|
||||
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return "", "", ErrAuthInvalidBasicEncoding
|
||||
}
|
||||
|
||||
credentials := string(decoded)
|
||||
idx := strings.IndexByte(credentials, ':')
|
||||
if idx < 0 {
|
||||
return "", "", ErrAuthInvalidBasicCreds
|
||||
}
|
||||
|
||||
return credentials[:idx], credentials[idx+1:], nil
|
||||
}
|
||||
|
||||
// ParseBearerToken extracts token from Bearer auth header
|
||||
func ParseBearerToken(header string) (token string, err error) {
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(header, prefix) {
|
||||
return "", ErrAuthInvalidBearerFormat
|
||||
}
|
||||
|
||||
token = strings.TrimPrefix(header, prefix)
|
||||
if token == "" {
|
||||
return "", ErrAuthEmptyBearerToken
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// ExtractAuthType returns authentication type from header
|
||||
func ExtractAuthType(header string) string {
|
||||
if strings.HasPrefix(header, "Basic ") {
|
||||
return "Basic"
|
||||
}
|
||||
if strings.HasPrefix(header, "Bearer ") {
|
||||
return "Bearer"
|
||||
}
|
||||
|
||||
// Extract first word as auth type
|
||||
idx := strings.IndexByte(header, ' ')
|
||||
if idx > 0 {
|
||||
return header[:idx]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
54
http_test.go
Normal file
54
http_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
// FILE: auth/http_test.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHTTPAuthParsing(t *testing.T) {
|
||||
// Test Basic Auth
|
||||
basicHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass"))
|
||||
username, password, err := ParseBasicAuth(basicHeader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "user", username)
|
||||
assert.Equal(t, "pass", password)
|
||||
|
||||
// Test Bearer Token
|
||||
bearerHeader := "Bearer test-token-xyz"
|
||||
token, err := ParseBearerToken(bearerHeader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test-token-xyz", token)
|
||||
|
||||
// Test ExtractAuthType
|
||||
assert.Equal(t, "Basic", ExtractAuthType(basicHeader))
|
||||
assert.Equal(t, "Bearer", ExtractAuthType(bearerHeader))
|
||||
assert.Equal(t, "Custom", ExtractAuthType("Custom somedata"))
|
||||
assert.Equal(t, "", ExtractAuthType("InvalidHeader"))
|
||||
|
||||
// Test invalid formats
|
||||
_, _, err = ParseBasicAuth("Invalid header")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrAuthInvalidBasicFormat, err)
|
||||
|
||||
_, err = ParseBearerToken("Invalid header")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrAuthInvalidBearerFormat, err)
|
||||
|
||||
// Test malformed Basic auth
|
||||
_, _, err = ParseBasicAuth("Basic not-base64!")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrAuthInvalidBasicEncoding, err)
|
||||
|
||||
_, _, err = ParseBasicAuth("Basic " + base64.StdEncoding.EncodeToString([]byte("no-colon")))
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrAuthInvalidBasicCreds, err)
|
||||
|
||||
// Test empty Bearer token
|
||||
_, err = ParseBearerToken("Bearer ")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrAuthEmptyBearerToken, err)
|
||||
}
|
||||
17
interface.go
Normal file
17
interface.go
Normal file
@ -0,0 +1,17 @@
|
||||
// FILE: auth/interface.go
|
||||
package auth
|
||||
|
||||
// AuthenticatorInterface defines the authentication operations
|
||||
type AuthenticatorInterface interface {
|
||||
HashPassword(password string) (hash string, err error)
|
||||
VerifyPassword(password, hash string) (err error)
|
||||
GenerateToken(userID string, claims map[string]any) (token string, err error)
|
||||
ValidateToken(token string) (userID string, claims map[string]any, err error)
|
||||
}
|
||||
|
||||
// TokenValidator validates bearer tokens
|
||||
type TokenValidator interface {
|
||||
ValidateToken(token string) (valid bool)
|
||||
AddToken(token string)
|
||||
RemoveToken(token string)
|
||||
}
|
||||
205
jwt.go
Normal file
205
jwt.go
Normal file
@ -0,0 +1,205 @@
|
||||
// FILE: auth/jwt.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GenerateToken creates a JWT token with user claims
|
||||
func (a *Authenticator) GenerateToken(userID string, claims map[string]any) (string, error) {
|
||||
if userID == "" {
|
||||
return "", ErrTokenEmptyUserID
|
||||
}
|
||||
|
||||
if a.algorithm == "RS256" && a.privateKey == nil {
|
||||
return "", ErrTokenNoPrivateKey
|
||||
}
|
||||
|
||||
// Build JWT claims
|
||||
now := time.Now()
|
||||
jwtClaims := map[string]any{
|
||||
"sub": userID,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(7 * 24 * time.Hour).Unix(), // 7 days expiry
|
||||
}
|
||||
|
||||
// Reserved claims that cannot be overridden
|
||||
reservedClaims := map[string]bool{
|
||||
"sub": true, "iat": true, "exp": true, "nbf": true,
|
||||
"iss": true, "aud": true, "jti": true, "typ": true,
|
||||
"alg": true,
|
||||
}
|
||||
|
||||
// Merge custom claims
|
||||
for k, v := range claims {
|
||||
if !reservedClaims[k] {
|
||||
jwtClaims[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Create JWT header
|
||||
header := map[string]any{
|
||||
"alg": a.algorithm,
|
||||
"typ": "JWT",
|
||||
}
|
||||
|
||||
// Encode header and payload
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
claimsJSON, _ := json.Marshal(jwtClaims)
|
||||
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
// Create signature
|
||||
signingInput := headerB64 + "." + claimsB64
|
||||
var signature string
|
||||
|
||||
switch a.algorithm {
|
||||
case "HS256":
|
||||
h := hmac.New(sha256.New, a.jwtSecret)
|
||||
h.Write([]byte(signingInput))
|
||||
signature = base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
case "RS256":
|
||||
hashed := sha256.Sum256([]byte(signingInput))
|
||||
sig, err := rsa.SignPKCS1v15(rand.Reader, a.privateKey, crypto.SHA256, hashed[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
signature = base64.RawURLEncoding.EncodeToString(sig)
|
||||
}
|
||||
|
||||
// Combine to form JWT
|
||||
token := signingInput + "." + signature
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// ValidateToken verifies JWT and returns userID and claims
|
||||
func (a *Authenticator) ValidateToken(token string) (string, map[string]any, error) {
|
||||
// Split token
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", nil, ErrTokenMalformed
|
||||
}
|
||||
|
||||
// Decode header to check algorithm
|
||||
headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return "", nil, ErrTokenInvalidHeader
|
||||
}
|
||||
|
||||
var header map[string]any
|
||||
if err = json.Unmarshal(headerJSON, &header); err != nil {
|
||||
return "", nil, ErrTokenInvalidJSON
|
||||
}
|
||||
|
||||
// Verify algorithm matches
|
||||
if alg, ok := header["alg"].(string); !ok || alg != a.algorithm {
|
||||
return "", nil, ErrTokenAlgorithmMismatch
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
signingInput := parts[0] + "." + parts[1]
|
||||
|
||||
switch a.algorithm {
|
||||
case "HS256":
|
||||
h := hmac.New(sha256.New, a.jwtSecret)
|
||||
h.Write([]byte(signingInput))
|
||||
expectedSig := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
if subtle.ConstantTimeCompare([]byte(parts[2]), []byte(expectedSig)) != 1 {
|
||||
return "", nil, ErrTokenInvalidSignature
|
||||
}
|
||||
|
||||
case "RS256":
|
||||
if a.publicKey == nil {
|
||||
return "", nil, ErrTokenNoPublicKey
|
||||
}
|
||||
|
||||
sig, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return "", nil, ErrTokenInvalidSignature
|
||||
}
|
||||
|
||||
hashed := sha256.Sum256([]byte(signingInput))
|
||||
if err := rsa.VerifyPKCS1v15(a.publicKey, crypto.SHA256, hashed[:], sig); err != nil {
|
||||
return "", nil, ErrTokenInvalidSignature
|
||||
}
|
||||
}
|
||||
|
||||
// Decode claims
|
||||
claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", nil, ErrTokenInvalidClaims
|
||||
}
|
||||
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
|
||||
return "", nil, ErrTokenInvalidJSON
|
||||
}
|
||||
|
||||
// Check expiration
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Now().Unix() > int64(exp) {
|
||||
return "", nil, ErrTokenExpired
|
||||
}
|
||||
}
|
||||
|
||||
// Check not before
|
||||
if nbf, ok := claims["nbf"].(float64); ok {
|
||||
if time.Now().Unix() < int64(nbf) {
|
||||
return "", nil, ErrTokenNotYetValid
|
||||
}
|
||||
}
|
||||
|
||||
// Extract userID
|
||||
userID, ok := claims["sub"].(string)
|
||||
if !ok {
|
||||
return "", nil, ErrTokenMissingClaim
|
||||
}
|
||||
|
||||
return userID, claims, nil
|
||||
}
|
||||
|
||||
// parseRSAPrivateKey parses PEM encoded RSA private key
|
||||
func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
return nil, ErrRSAInvalidPEM
|
||||
}
|
||||
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, ErrRSAInvalidPrivateKey
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// parseRSAPublicKey parses PEM encoded RSA public key
|
||||
func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) {
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
return nil, ErrRSAInvalidPEM
|
||||
}
|
||||
pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, ErrRSAInvalidPublicKey
|
||||
}
|
||||
pubKey, ok := pubInterface.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, ErrRSANotPublicKey
|
||||
}
|
||||
return pubKey, nil
|
||||
}
|
||||
176
jwt_test.go
Normal file
176
jwt_test.go
Normal file
@ -0,0 +1,176 @@
|
||||
// FILE: auth/jwt_test.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJWTHS256(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "user123"
|
||||
claims := map[string]any{
|
||||
"email": "test@example.com",
|
||||
"role": "admin",
|
||||
}
|
||||
|
||||
// Generate token
|
||||
token, err := auth.GenerateToken(userID, claims)
|
||||
require.NoError(t, err, "Failed to generate token")
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Validate token
|
||||
extractedUserID, extractedClaims, err := auth.ValidateToken(token)
|
||||
require.NoError(t, err, "Failed to validate token")
|
||||
|
||||
assert.Equal(t, userID, extractedUserID)
|
||||
assert.Equal(t, "test@example.com", extractedClaims["email"])
|
||||
assert.Equal(t, "admin", extractedClaims["role"])
|
||||
|
||||
// Test invalid token
|
||||
_, _, err = auth.ValidateToken("invalid.token.here")
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrTokenInvalidJSON))
|
||||
|
||||
// Test tampered token
|
||||
parts := strings.Split(token, ".")
|
||||
require.Len(t, parts, 3, "JWT should have 3 parts")
|
||||
|
||||
tampered := parts[0] + "." + parts[1] + ".invalidsignature"
|
||||
_, _, err = auth.ValidateToken(tampered)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrTokenInvalidSignature))
|
||||
|
||||
// Test reserved claims cannot be overridden
|
||||
overrideClaims := map[string]any{
|
||||
"sub": "override",
|
||||
"iat": 12345,
|
||||
"exp": 67890,
|
||||
"nbf": 11111,
|
||||
"iss": "attacker",
|
||||
"aud": "victim",
|
||||
"jti": "fake",
|
||||
}
|
||||
|
||||
token, err = auth.GenerateToken(userID, overrideClaims)
|
||||
require.NoError(t, err)
|
||||
|
||||
extractedUserID, extractedClaims, err = auth.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userID, extractedUserID, "UserID should not be overridden")
|
||||
assert.NotEqual(t, 12345, extractedClaims["iat"], "iat should not be overridden")
|
||||
assert.NotEqual(t, 67890, extractedClaims["exp"], "exp should not be overridden")
|
||||
assert.NotContains(t, extractedClaims, "nbf", "nbf should not be added from user claims")
|
||||
assert.NotContains(t, extractedClaims, "iss", "iss should not be added from user claims")
|
||||
}
|
||||
|
||||
func TestJWTRS256(t *testing.T) {
|
||||
// Generate RSA key pair
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with private key (can sign and verify)
|
||||
authPriv, err := NewAuthenticator(privateKey, "RS256")
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "user456"
|
||||
claims := map[string]any{
|
||||
"email": "rs256@example.com",
|
||||
"scope": "read:all",
|
||||
}
|
||||
|
||||
// Generate token
|
||||
token, err := authPriv.GenerateToken(userID, claims)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Validate token with private key auth (has public key too)
|
||||
extractedUserID, extractedClaims, err := authPriv.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userID, extractedUserID)
|
||||
assert.Equal(t, "rs256@example.com", extractedClaims["email"])
|
||||
assert.Equal(t, "read:all", extractedClaims["scope"])
|
||||
|
||||
// Test with public key only (can only verify)
|
||||
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be able to validate token
|
||||
extractedUserID, extractedClaims, err = authPub.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userID, extractedUserID)
|
||||
|
||||
// Should not be able to generate token
|
||||
_, err = authPub.GenerateToken(userID, claims)
|
||||
assert.Error(t, err, "Public key only auth should not generate tokens")
|
||||
assert.Equal(t, ErrTokenNoPrivateKey, err)
|
||||
|
||||
// Test algorithm mismatch
|
||||
authHS256, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = authHS256.ValidateToken(token)
|
||||
assert.Error(t, err, "HS256 auth should not validate RS256 token")
|
||||
// assert.True(t, errors.Is(err, ErrInvalidToken))
|
||||
assert.True(t, errors.Is(err, ErrTokenAlgorithmMismatch))
|
||||
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
func TestExpiredToken(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "user123"
|
||||
|
||||
// Generate normal token (should have 7 days expiry)
|
||||
token, err := auth.GenerateToken(userID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, extractedClaims, err := auth.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check expiry is in future (approximately 7 days)
|
||||
expiry := extractedClaims["exp"].(float64)
|
||||
now := time.Now().Unix()
|
||||
|
||||
assert.Greater(t, expiry, float64(now), "Token expiry should be in future")
|
||||
assert.InDelta(t, expiry, float64(now+7*24*60*60), 10,
|
||||
"Token expiry should be approximately 7 days from now")
|
||||
}
|
||||
|
||||
func TestCorruptJWTParts(t *testing.T) {
|
||||
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with missing parts
|
||||
_, _, err = auth.ValidateToken("only.two")
|
||||
assert.True(t, errors.Is(err, ErrTokenMalformed))
|
||||
|
||||
// Test with invalid header encoding
|
||||
_, _, err = auth.ValidateToken("not-base64!.valid.valid")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test with invalid claims encoding
|
||||
validHeader := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
_, _, err = auth.ValidateToken(validHeader + ".not-base64!.valid")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test with invalid JSON in claims
|
||||
invalidJSON := base64.RawURLEncoding.EncodeToString([]byte("{invalid json"))
|
||||
_, _, err = auth.ValidateToken(validHeader + "." + invalidJSON + ".signature")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
498
scram.go
Normal file
498
scram.go
Normal file
@ -0,0 +1,498 @@
|
||||
// FILE: auth/scram.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
// SCRAM-SHA256 implementation
|
||||
|
||||
// Credential stores SCRAM authentication data
|
||||
type Credential struct {
|
||||
Username string
|
||||
Salt []byte
|
||||
ArgonTime uint32
|
||||
ArgonMemory uint32
|
||||
ArgonThreads uint8
|
||||
StoredKey []byte // SHA256(ClientKey)
|
||||
ServerKey []byte
|
||||
}
|
||||
|
||||
// Export returns credential as config-friendly map
|
||||
func (c *Credential) Export() map[string]any {
|
||||
return map[string]any{
|
||||
"username": c.Username,
|
||||
"salt": base64.StdEncoding.EncodeToString(c.Salt),
|
||||
"argon_time": c.ArgonTime,
|
||||
"argon_memory": c.ArgonMemory,
|
||||
"argon_threads": c.ArgonThreads,
|
||||
"stored_key": base64.StdEncoding.EncodeToString(c.StoredKey),
|
||||
"server_key": base64.StdEncoding.EncodeToString(c.ServerKey),
|
||||
}
|
||||
}
|
||||
|
||||
// ImportCredential creates credential from map
|
||||
func ImportCredential(data map[string]any) (*Credential, error) {
|
||||
username, ok := data["username"].(string)
|
||||
if !ok {
|
||||
return nil, ErrCredMissingUsername
|
||||
}
|
||||
|
||||
saltStr, ok := data["salt"].(string)
|
||||
if !ok {
|
||||
return nil, ErrCredMissingSalt
|
||||
}
|
||||
salt, err := base64.StdEncoding.DecodeString(saltStr)
|
||||
if err != nil {
|
||||
return nil, ErrCredInvalidSalt
|
||||
}
|
||||
|
||||
// Handle both float64 (from JSON) and int types
|
||||
getUint32 := func(key string) (uint32, error) {
|
||||
val, ok := data[key]
|
||||
if !ok {
|
||||
switch key {
|
||||
case "argon_time":
|
||||
return 0, ErrCredMissingTime
|
||||
case "argon_memory":
|
||||
return 0, ErrCredMissingMemory
|
||||
default:
|
||||
return 0, fmt.Errorf("missing %s", key)
|
||||
}
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case float64:
|
||||
return uint32(v), nil
|
||||
case int:
|
||||
return uint32(v), nil
|
||||
case uint32:
|
||||
return v, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid type for %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
argonTime, err := getUint32("argon_time")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
argonMemory, err := getUint32("argon_memory")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
threadsVal, ok := data["argon_threads"]
|
||||
if !ok {
|
||||
return nil, ErrCredMissingThreads
|
||||
}
|
||||
var argonThreads uint8
|
||||
switch v := threadsVal.(type) {
|
||||
case float64:
|
||||
argonThreads = uint8(v)
|
||||
case int:
|
||||
argonThreads = uint8(v)
|
||||
case uint8:
|
||||
argonThreads = v
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: argon_threads", ErrCredInvalidType)
|
||||
}
|
||||
|
||||
storedKeyStr, ok := data["stored_key"].(string)
|
||||
if !ok {
|
||||
return nil, ErrCredMissingStoredKey
|
||||
}
|
||||
storedKey, err := base64.StdEncoding.DecodeString(storedKeyStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrCredInvalidStoredKey, err)
|
||||
}
|
||||
|
||||
serverKeyStr, ok := data["server_key"].(string)
|
||||
if !ok {
|
||||
return nil, ErrCredMissingServerKey
|
||||
}
|
||||
serverKey, err := base64.StdEncoding.DecodeString(serverKeyStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrCredInvalidServerKey, err)
|
||||
}
|
||||
|
||||
return &Credential{
|
||||
Username: username,
|
||||
Salt: salt,
|
||||
ArgonTime: argonTime,
|
||||
ArgonMemory: argonMemory,
|
||||
ArgonThreads: argonThreads,
|
||||
StoredKey: storedKey,
|
||||
ServerKey: serverKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeriveCredential creates SCRAM credential from password
|
||||
func DeriveCredential(username, password string, salt []byte, time, memory uint32, threads uint8) (*Credential, error) {
|
||||
if len(salt) < 16 {
|
||||
return nil, ErrSCRAMSaltTooShort
|
||||
}
|
||||
|
||||
// Derive salted password using Argon2id
|
||||
saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, DefaultArgonKeyLen)
|
||||
|
||||
// Derive keys
|
||||
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||
storedKey := sha256.Sum256(clientKey)
|
||||
|
||||
return &Credential{
|
||||
Username: username,
|
||||
Salt: salt,
|
||||
ArgonTime: time,
|
||||
ArgonMemory: memory,
|
||||
ArgonThreads: threads,
|
||||
StoredKey: storedKey[:],
|
||||
ServerKey: serverKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ScramServer handles server-side SCRAM authentication
|
||||
type ScramServer struct {
|
||||
credentials map[string]*Credential
|
||||
handshakes map[string]*HandshakeState
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// HandshakeState tracks ongoing authentication
|
||||
type HandshakeState struct {
|
||||
Username string
|
||||
ClientNonce string
|
||||
ServerNonce string
|
||||
FullNonce string
|
||||
Credential *Credential
|
||||
CreatedAt time.Time
|
||||
verifying int32 // Atomic flag to prevent race during verification
|
||||
}
|
||||
|
||||
// NewScramServer creates SCRAM server
|
||||
func NewScramServer() *ScramServer {
|
||||
return &ScramServer{
|
||||
credentials: make(map[string]*Credential),
|
||||
handshakes: make(map[string]*HandshakeState),
|
||||
}
|
||||
}
|
||||
|
||||
// AddCredential registers user credential
|
||||
func (s *ScramServer) AddCredential(cred *Credential) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.credentials[cred.Username] = cred
|
||||
}
|
||||
|
||||
// ProcessClientFirstMessage processes initial auth request
|
||||
func (s *ScramServer) ProcessClientFirstMessage(username, clientNonce string) (ServerFirstMessage, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if user exists
|
||||
cred, exists := s.credentials[username]
|
||||
if !exists {
|
||||
// Prevent user enumeration - still generate response
|
||||
salt := make([]byte, 16)
|
||||
rand.Read(salt)
|
||||
serverNonce := generateNonce()
|
||||
|
||||
return ServerFirstMessage{
|
||||
FullNonce: clientNonce + serverNonce,
|
||||
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||
ArgonTime: DefaultArgonTime,
|
||||
ArgonMemory: DefaultArgonMemory,
|
||||
ArgonThreads: DefaultArgonThreads,
|
||||
}, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// Generate server nonce
|
||||
serverNonce := generateNonce()
|
||||
fullNonce := clientNonce + serverNonce
|
||||
|
||||
// Store handshake state
|
||||
state := &HandshakeState{
|
||||
Username: username,
|
||||
ClientNonce: clientNonce,
|
||||
ServerNonce: serverNonce,
|
||||
FullNonce: fullNonce,
|
||||
Credential: cred,
|
||||
CreatedAt: time.Now(),
|
||||
verifying: 0,
|
||||
}
|
||||
s.handshakes[fullNonce] = state
|
||||
|
||||
// Cleanup old handshakes
|
||||
s.cleanupHandshakes()
|
||||
|
||||
return ServerFirstMessage{
|
||||
FullNonce: fullNonce,
|
||||
Salt: base64.StdEncoding.EncodeToString(cred.Salt),
|
||||
ArgonTime: cred.ArgonTime,
|
||||
ArgonMemory: cred.ArgonMemory,
|
||||
ArgonThreads: cred.ArgonThreads,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProcessClientFinalMessage verifies client proof
|
||||
func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (ServerFinalMessage, error) {
|
||||
s.mu.RLock()
|
||||
state, exists := s.handshakes[fullNonce]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return ServerFinalMessage{}, ErrSCRAMInvalidNonce
|
||||
}
|
||||
|
||||
// Mark as verifying to prevent deletion race
|
||||
if !atomic.CompareAndSwapInt32(&state.verifying, 0, 1) {
|
||||
return ServerFinalMessage{}, ErrSCRAMVerifyInProgress
|
||||
}
|
||||
defer func() {
|
||||
atomic.StoreInt32(&state.verifying, 0)
|
||||
// Safe to delete after verification completes
|
||||
s.mu.Lock()
|
||||
delete(s.handshakes, fullNonce)
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Check timeout
|
||||
if time.Since(state.CreatedAt) > 60*time.Second {
|
||||
return ServerFinalMessage{}, ErrSCRAMTimeout
|
||||
}
|
||||
|
||||
// Decode client proof
|
||||
clientProofBytes, err := base64.StdEncoding.DecodeString(clientProof)
|
||||
if err != nil {
|
||||
return ServerFinalMessage{}, ErrSCRAMInvalidProof
|
||||
}
|
||||
|
||||
// Build auth message
|
||||
clientFirstBare := fmt.Sprintf("u=%s,n=%s", state.Username, state.ClientNonce)
|
||||
serverFirst := ServerFirstMessage{
|
||||
FullNonce: state.FullNonce,
|
||||
Salt: base64.StdEncoding.EncodeToString(state.Credential.Salt),
|
||||
ArgonTime: state.Credential.ArgonTime,
|
||||
ArgonMemory: state.Credential.ArgonMemory,
|
||||
ArgonThreads: state.Credential.ArgonThreads,
|
||||
}
|
||||
clientFinalBare := fmt.Sprintf("r=%s", fullNonce)
|
||||
authMessage := clientFirstBare + "," + serverFirst.Marshal() + "," + clientFinalBare
|
||||
|
||||
// Compute client signature
|
||||
clientSignature := computeHMAC(state.Credential.StoredKey, []byte(authMessage))
|
||||
|
||||
// XOR to get ClientKey
|
||||
if len(clientProofBytes) != len(clientSignature) {
|
||||
return ServerFinalMessage{}, ErrSCRAMInvalidProofLen
|
||||
}
|
||||
clientKey := xorBytes(clientProofBytes, clientSignature)
|
||||
|
||||
// Verify by computing StoredKey
|
||||
computedStoredKey := sha256.Sum256(clientKey)
|
||||
if subtle.ConstantTimeCompare(computedStoredKey[:], state.Credential.StoredKey) != 1 {
|
||||
return ServerFinalMessage{}, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// Generate server signature for mutual auth
|
||||
serverSignature := computeHMAC(state.Credential.ServerKey, []byte(authMessage))
|
||||
|
||||
return ServerFinalMessage{
|
||||
ServerSignature: base64.StdEncoding.EncodeToString(serverSignature),
|
||||
Username: state.Username,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ScramServer) cleanupHandshakes() {
|
||||
cutoff := time.Now().Add(-60 * time.Second)
|
||||
for nonce, state := range s.handshakes {
|
||||
if state.CreatedAt.Before(cutoff) && atomic.LoadInt32(&state.verifying) == 0 {
|
||||
delete(s.handshakes, nonce)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ScramClient handles client-side SCRAM authentication
|
||||
type ScramClient struct {
|
||||
Username string
|
||||
Password string
|
||||
clientNonce string
|
||||
serverFirst *ServerFirstMessage
|
||||
authMessage string
|
||||
serverKey []byte
|
||||
startTime time.Time // Track handshake start
|
||||
}
|
||||
|
||||
// NewScramClient creates SCRAM client
|
||||
func NewScramClient(username, password string) *ScramClient {
|
||||
return &ScramClient{
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
}
|
||||
|
||||
// StartAuthentication generates initial client message
|
||||
func (c *ScramClient) StartAuthentication() (ClientFirstRequest, error) {
|
||||
c.startTime = time.Now()
|
||||
|
||||
// Generate client nonce
|
||||
nonce := make([]byte, 32)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return ClientFirstRequest{}, ErrSCRAMNonceGenFailed
|
||||
}
|
||||
c.clientNonce = base64.StdEncoding.EncodeToString(nonce)
|
||||
|
||||
return ClientFirstRequest{
|
||||
Username: c.Username,
|
||||
ClientNonce: c.clientNonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProcessServerFirstMessage handles server challenge
|
||||
func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientFinalRequest, error) {
|
||||
// Check timeout (30 seconds)
|
||||
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second {
|
||||
return ClientFinalRequest{}, ErrSCRAMTimeout
|
||||
}
|
||||
|
||||
c.serverFirst = &msg
|
||||
|
||||
// Handle enumeration prevention - server may send fake response
|
||||
// We still process it normally and let verification fail later
|
||||
|
||||
// Decode salt
|
||||
salt, err := base64.StdEncoding.DecodeString(msg.Salt)
|
||||
if err != nil {
|
||||
return ClientFinalRequest{}, ErrSCRAMInvalidSalt
|
||||
}
|
||||
|
||||
// Validate parameters
|
||||
if msg.ArgonTime == 0 || msg.ArgonMemory == 0 || msg.ArgonThreads == 0 {
|
||||
return ClientFinalRequest{}, ErrSCRAMZeroParams
|
||||
}
|
||||
|
||||
// Derive keys using Argon2id
|
||||
saltedPassword := argon2.IDKey([]byte(c.Password), salt, msg.ArgonTime, msg.ArgonMemory, msg.ArgonThreads, 32)
|
||||
|
||||
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||
storedKey := sha256.Sum256(clientKey)
|
||||
|
||||
// Build auth message
|
||||
clientFirstBare := fmt.Sprintf("u=%s,n=%s", c.Username, c.clientNonce)
|
||||
clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce)
|
||||
c.authMessage = clientFirstBare + "," + msg.Marshal() + "," + clientFinalBare
|
||||
|
||||
// Compute client proof
|
||||
clientSignature := computeHMAC(storedKey[:], []byte(c.authMessage))
|
||||
clientProof := xorBytes(clientKey, clientSignature)
|
||||
|
||||
// Store server key for verification
|
||||
c.serverKey = serverKey
|
||||
|
||||
return ClientFinalRequest{
|
||||
FullNonce: msg.FullNonce,
|
||||
ClientProof: base64.StdEncoding.EncodeToString(clientProof),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyServerFinalMessage validates server signature
|
||||
func (c *ScramClient) VerifyServerFinalMessage(msg ServerFinalMessage) error {
|
||||
// Check timeout
|
||||
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second {
|
||||
return ErrSCRAMTimeout
|
||||
}
|
||||
|
||||
if c.authMessage == "" || c.serverKey == nil {
|
||||
return ErrSCRAMInvalidState
|
||||
}
|
||||
|
||||
// Compute expected server signature
|
||||
expectedSig := computeHMAC(c.serverKey, []byte(c.authMessage))
|
||||
|
||||
// Decode received signature
|
||||
receivedSig, err := base64.StdEncoding.DecodeString(msg.ServerSignature)
|
||||
if err != nil {
|
||||
return ErrSCRAMServerAuthFailed
|
||||
}
|
||||
|
||||
// Constant-time comparison
|
||||
if subtle.ConstantTimeCompare(expectedSig, receivedSig) != 1 {
|
||||
return ErrSCRAMServerAuthFailed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset clears client state for retry
|
||||
func (c *ScramClient) Reset() {
|
||||
c.clientNonce = ""
|
||||
c.serverFirst = nil
|
||||
c.authMessage = ""
|
||||
c.serverKey = nil
|
||||
c.startTime = time.Time{}
|
||||
}
|
||||
|
||||
// SCRAM message types
|
||||
type ClientFirstRequest struct {
|
||||
Username string `json:"username"`
|
||||
ClientNonce string `json:"client_nonce"`
|
||||
}
|
||||
|
||||
type ServerFirstMessage struct {
|
||||
FullNonce string `json:"full_nonce"`
|
||||
Salt string `json:"salt"`
|
||||
ArgonTime uint32 `json:"argon_time"`
|
||||
ArgonMemory uint32 `json:"argon_memory"`
|
||||
ArgonThreads uint8 `json:"argon_threads"`
|
||||
}
|
||||
|
||||
func (s ServerFirstMessage) Marshal() string {
|
||||
return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d",
|
||||
s.FullNonce, s.Salt, s.ArgonTime, s.ArgonMemory, s.ArgonThreads)
|
||||
}
|
||||
|
||||
type ClientFinalRequest struct {
|
||||
FullNonce string `json:"full_nonce"`
|
||||
ClientProof string `json:"client_proof"`
|
||||
}
|
||||
|
||||
type ServerFinalMessage struct {
|
||||
ServerSignature string `json:"server_signature"`
|
||||
Username string `json:"username,omitempty"`
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func computeHMAC(key, message []byte) []byte {
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(message)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func xorBytes(a, b []byte) []byte {
|
||||
if len(a) != len(b) {
|
||||
panic("xor length mismatch")
|
||||
}
|
||||
result := make([]byte, len(a))
|
||||
for i := range a {
|
||||
result[i] = a[i] ^ b[i]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func generateNonce() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
50
token.go
Normal file
50
token.go
Normal file
@ -0,0 +1,50 @@
|
||||
// FILE: auth/token.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SimpleTokenValidator implements in-memory token validation
|
||||
type SimpleTokenValidator struct {
|
||||
tokens map[string]struct{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSimpleTokenValidator creates token validator
|
||||
func NewSimpleTokenValidator() *SimpleTokenValidator {
|
||||
return &SimpleTokenValidator{
|
||||
tokens: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateToken checks if token is valid
|
||||
func (v *SimpleTokenValidator) ValidateToken(token string) bool {
|
||||
v.mu.RLock()
|
||||
defer v.mu.RUnlock()
|
||||
|
||||
// Constant-time comparison for each stored token
|
||||
for storedToken := range v.tokens {
|
||||
if subtle.ConstantTimeEq(int32(len(token)), int32(len(storedToken))) == 1 {
|
||||
if subtle.ConstantTimeCompare([]byte(token), []byte(storedToken)) == 1 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AddToken adds token to validator
|
||||
func (v *SimpleTokenValidator) AddToken(token string) {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
v.tokens[token] = struct{}{}
|
||||
}
|
||||
|
||||
// RemoveToken removes token from validator
|
||||
func (v *SimpleTokenValidator) RemoveToken(token string) {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
delete(v.tokens, token)
|
||||
}
|
||||
81
token_test.go
Normal file
81
token_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
// FILE: auth/token_test.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSimpleTokenValidator(t *testing.T) {
|
||||
validator := NewSimpleTokenValidator()
|
||||
|
||||
token1 := "test-token-123"
|
||||
token2 := "test-token-456"
|
||||
|
||||
// Add tokens
|
||||
validator.AddToken(token1)
|
||||
validator.AddToken(token2)
|
||||
|
||||
// Validate existing tokens
|
||||
assert.True(t, validator.ValidateToken(token1))
|
||||
assert.True(t, validator.ValidateToken(token2))
|
||||
|
||||
// Invalid token
|
||||
assert.False(t, validator.ValidateToken("invalid-token"))
|
||||
|
||||
// Remove token
|
||||
validator.RemoveToken(token1)
|
||||
assert.False(t, validator.ValidateToken(token1))
|
||||
assert.True(t, validator.ValidateToken(token2))
|
||||
}
|
||||
|
||||
func TestConcurrentTokenValidator(t *testing.T) {
|
||||
validator := NewSimpleTokenValidator()
|
||||
|
||||
// Add tokens concurrently
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
token := fmt.Sprintf("token-%d", idx)
|
||||
validator.AddToken(token)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Validate concurrently
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
token := fmt.Sprintf("token-%d", idx)
|
||||
assert.True(t, validator.ValidateToken(token))
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Remove concurrently
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
token := fmt.Sprintf("token-%d", idx)
|
||||
validator.RemoveToken(token)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Verify removal
|
||||
for i := 0; i < 50; i++ {
|
||||
token := fmt.Sprintf("token-%d", i)
|
||||
assert.False(t, validator.ValidateToken(token))
|
||||
}
|
||||
for i := 50; i < 100; i++ {
|
||||
token := fmt.Sprintf("token-%d", i)
|
||||
assert.True(t, validator.ValidateToken(token))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user