v0.1.0 initial commit, auth features extracted from logwisp to be a standalone utility package
This commit is contained in:
12
.gitignore
vendored
Normal file
12
.gitignore
vendored
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
.idea
|
||||||
|
data
|
||||||
|
dev
|
||||||
|
log
|
||||||
|
logs
|
||||||
|
cert
|
||||||
|
bin
|
||||||
|
script
|
||||||
|
build
|
||||||
|
*.log
|
||||||
|
*.toml
|
||||||
|
build.sh
|
||||||
28
LICENSE
Normal file
28
LICENSE
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
BSD 3-Clause License
|
||||||
|
|
||||||
|
Copyright (c) 2025, Lixen Wraith
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither the name of the copyright holder nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
38
README.md
Normal file
38
README.md
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
# Auth Package
|
||||||
|
|
||||||
|
Pluggable authentication utilities for Go applications.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Password Hashing**: Argon2id with PHC format
|
||||||
|
- **JWT**: HS256/RS256 token generation and validation
|
||||||
|
- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF
|
||||||
|
- **HTTP Auth**: Basic/Bearer header parsing
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
```go
|
||||||
|
// JWT with HS256
|
||||||
|
auth, _ := auth.NewAuthenticator([]byte("32-byte-secret-key..."))
|
||||||
|
token, _ := auth.GenerateToken("user123", map[string]interface{}{"role": "admin"})
|
||||||
|
userID, claims, _ := auth.ValidateToken(token)
|
||||||
|
|
||||||
|
// SCRAM authentication
|
||||||
|
server := auth.NewScramServer()
|
||||||
|
cred, _ := auth.DeriveCredential("user", "password", salt, 1, 65536, 4)
|
||||||
|
server.AddCredential(cred)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Package Structure
|
||||||
|
|
||||||
|
- `interfaces.go` - Core interfaces
|
||||||
|
- `jwt.go` - JWT token operations
|
||||||
|
- `argon2.go` - Password hashing
|
||||||
|
- `scram.go` - SCRAM-SHA256 protocol
|
||||||
|
- `token.go` - Token validation utilities
|
||||||
|
- `http.go` - HTTP header parsing
|
||||||
|
- `errors.go` - Error definitions
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
```bash
|
||||||
|
go test -v ./auth
|
||||||
|
```
|
||||||
110
argon2.go
Normal file
110
argon2.go
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
// FILE: auth/argon2.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/argon2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default Argon2id parameters
|
||||||
|
const (
|
||||||
|
DefaultArgonTime = 3 // iterations (reduce for faster but less secure auth)
|
||||||
|
DefaultArgonMemory = 64 * 1024 // 64 MB
|
||||||
|
DefaultArgonThreads = 4
|
||||||
|
DefaultArgonSaltLen = 16
|
||||||
|
DefaultArgonKeyLen = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
// HashPassword creates an Argon2id PHC-format hash
|
||||||
|
func (a *Authenticator) HashPassword(password string) (string, error) {
|
||||||
|
if len(password) < 8 {
|
||||||
|
return "", ErrWeakPassword
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate salt
|
||||||
|
salt := make([]byte, DefaultArgonSaltLen)
|
||||||
|
if _, err := rand.Read(salt); err != nil {
|
||||||
|
return "", fmt.Errorf("%w: %v", ErrSaltGenerationFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive key using Argon2id
|
||||||
|
hash := argon2.IDKey([]byte(password), salt, a.argonTime, a.argonMemory, a.argonThreads, DefaultArgonKeyLen)
|
||||||
|
|
||||||
|
// Construct PHC format
|
||||||
|
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
||||||
|
hashB64 := base64.RawStdEncoding.EncodeToString(hash)
|
||||||
|
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||||
|
argon2.Version, a.argonMemory, a.argonTime, a.argonThreads, saltB64, hashB64)
|
||||||
|
|
||||||
|
return phcHash, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyPassword checks password against PHC-format hash
|
||||||
|
func (a *Authenticator) VerifyPassword(password, phcHash string) error {
|
||||||
|
// Parse PHC format
|
||||||
|
parts := strings.Split(phcHash, "$")
|
||||||
|
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||||
|
return ErrPHCInvalidFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
var memory, time uint32
|
||||||
|
var threads uint8
|
||||||
|
fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
|
||||||
|
|
||||||
|
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %v", ErrPHCInvalidSalt, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute hash with same parameters
|
||||||
|
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
|
||||||
|
|
||||||
|
// Constant-time comparison
|
||||||
|
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
|
||||||
|
return ErrInvalidCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MigrateFromPHC converts existing Argon2 PHC hash to SCRAM credential
|
||||||
|
func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
|
||||||
|
// Parse PHC format
|
||||||
|
parts := strings.Split(phcHash, "$")
|
||||||
|
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||||
|
return nil, ErrPHCInvalidFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
var memory, time uint32
|
||||||
|
var threads uint8
|
||||||
|
fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
|
||||||
|
|
||||||
|
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrPHCInvalidSalt
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrPHCInvalidHash
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify password against hash
|
||||||
|
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
|
||||||
|
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
|
||||||
|
return nil, ErrInvalidCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive SCRAM credential with same parameters
|
||||||
|
return DeriveCredential(username, password, salt, time, memory, threads)
|
||||||
|
}
|
||||||
117
argon2_test.go
Normal file
117
argon2_test.go
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
// FILE: auth/argon2_test.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPasswordHashing(t *testing.T) {
|
||||||
|
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||||
|
require.NoError(t, err, "Failed to create authenticator")
|
||||||
|
|
||||||
|
password := "testPassword123"
|
||||||
|
|
||||||
|
// Test hashing
|
||||||
|
hash, err := auth.HashPassword(password)
|
||||||
|
require.NoError(t, err, "Failed to hash password")
|
||||||
|
|
||||||
|
// Verify PHC format
|
||||||
|
assert.True(t, strings.HasPrefix(hash, "$argon2id$"),
|
||||||
|
"Hash should have argon2id prefix, got: %s", hash)
|
||||||
|
|
||||||
|
// Test verification with correct password
|
||||||
|
err = auth.VerifyPassword(password, hash)
|
||||||
|
assert.NoError(t, err, "Failed to verify correct password")
|
||||||
|
|
||||||
|
// Test verification with incorrect password
|
||||||
|
err = auth.VerifyPassword("wrongPassword", hash)
|
||||||
|
assert.Error(t, err, "Verification should fail for incorrect password")
|
||||||
|
assert.Equal(t, ErrInvalidCredentials, err)
|
||||||
|
|
||||||
|
// Test weak password
|
||||||
|
_, err = auth.HashPassword("weak")
|
||||||
|
assert.Equal(t, ErrWeakPassword, err, "Should reject weak password")
|
||||||
|
|
||||||
|
// Test malformed PHC hash
|
||||||
|
err = auth.VerifyPassword(password, "$invalid$format")
|
||||||
|
assert.Error(t, err, "Should reject malformed hash")
|
||||||
|
|
||||||
|
// Test corrupted salt
|
||||||
|
corruptedHash := strings.Replace(hash, "$argon2id$", "$argon2id$", 1)
|
||||||
|
parts := strings.Split(corruptedHash, "$")
|
||||||
|
if len(parts) == 6 {
|
||||||
|
parts[4] = "invalid!base64"
|
||||||
|
corruptedHash = strings.Join(parts, "$")
|
||||||
|
err = auth.VerifyPassword(password, corruptedHash)
|
||||||
|
assert.Error(t, err, "Should reject corrupted salt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmptyPasswordAfterValidation(t *testing.T) {
|
||||||
|
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Empty password should be rejected by length check
|
||||||
|
_, err = auth.HashPassword("")
|
||||||
|
assert.Equal(t, ErrWeakPassword, err)
|
||||||
|
|
||||||
|
// 8-character password should pass
|
||||||
|
hash, err := auth.HashPassword("12345678")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = auth.VerifyPassword("12345678", hash)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentPasswordOperations(t *testing.T) {
|
||||||
|
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
password := "testPassword123"
|
||||||
|
hash, err := auth.HashPassword(password)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test concurrent verification
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err := auth.VerifyPassword(password, hash)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPHCMigration(t *testing.T) {
|
||||||
|
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
password := "testPassword123"
|
||||||
|
username := "migrationUser"
|
||||||
|
|
||||||
|
// Generate PHC hash
|
||||||
|
phcHash, err := auth.HashPassword(password)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Migrate to SCRAM credential
|
||||||
|
cred, err := MigrateFromPHC(username, password, phcHash)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, username, cred.Username)
|
||||||
|
assert.NotNil(t, cred.StoredKey)
|
||||||
|
assert.NotNil(t, cred.ServerKey)
|
||||||
|
|
||||||
|
// Test with wrong password
|
||||||
|
_, err = MigrateFromPHC(username, "wrongPassword", phcHash)
|
||||||
|
assert.Equal(t, ErrInvalidCredentials, err)
|
||||||
|
|
||||||
|
// Test with invalid PHC format
|
||||||
|
_, err = MigrateFromPHC(username, password, "$invalid$format")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
71
auth.go
Normal file
71
auth.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
// FILE: auth/auth.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rsa"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Authenticator provides password hashing and JWT operations
|
||||||
|
type Authenticator struct {
|
||||||
|
algorithm string
|
||||||
|
jwtSecret []byte // For HS256
|
||||||
|
privateKey *rsa.PrivateKey // For RS256
|
||||||
|
publicKey *rsa.PublicKey // For RS256
|
||||||
|
argonTime uint32
|
||||||
|
argonMemory uint32
|
||||||
|
argonThreads uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthenticator creates a new authenticator with specified algorithm
|
||||||
|
func NewAuthenticator(key any, algorithm ...string) (*Authenticator, error) {
|
||||||
|
alg := "HS256"
|
||||||
|
if len(algorithm) > 0 && algorithm[0] != "" {
|
||||||
|
alg = algorithm[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
auth := &Authenticator{
|
||||||
|
algorithm: alg,
|
||||||
|
argonTime: DefaultArgonTime,
|
||||||
|
argonMemory: DefaultArgonMemory,
|
||||||
|
argonThreads: DefaultArgonThreads,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch alg {
|
||||||
|
case "HS256":
|
||||||
|
secret, ok := key.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrInvalidKeyType
|
||||||
|
}
|
||||||
|
if len(secret) < 32 {
|
||||||
|
return nil, ErrSecretTooShort
|
||||||
|
}
|
||||||
|
auth.jwtSecret = secret
|
||||||
|
|
||||||
|
case "RS256":
|
||||||
|
switch k := key.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
auth.privateKey = k
|
||||||
|
auth.publicKey = &k.PublicKey
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
auth.publicKey = k
|
||||||
|
case []byte:
|
||||||
|
// Try parsing as PEM
|
||||||
|
if privKey, err := parseRSAPrivateKey(k); err == nil {
|
||||||
|
auth.privateKey = privKey
|
||||||
|
auth.publicKey = &privKey.PublicKey
|
||||||
|
} else if pubKey, err := parseRSAPublicKey(k); err == nil {
|
||||||
|
auth.publicKey = pubKey
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("failed to parse RSA key: %w", err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, ErrInvalidKeyType
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, ErrInvalidAlgorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
68
auth_test.go
Normal file
68
auth_test.go
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
// FILE: auth/auth_test.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewAuthenticator(t *testing.T) {
|
||||||
|
// Test HS256 creation
|
||||||
|
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||||
|
require.NoError(t, err, "Failed to create HS256 authenticator")
|
||||||
|
assert.Equal(t, "HS256", auth.algorithm)
|
||||||
|
|
||||||
|
// Test with short secret
|
||||||
|
_, err = NewAuthenticator([]byte("short"))
|
||||||
|
assert.Equal(t, ErrSecretTooShort, err)
|
||||||
|
|
||||||
|
// Test RS256 with private key
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authRS, err := NewAuthenticator(privateKey, "RS256")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "RS256", authRS.algorithm)
|
||||||
|
assert.NotNil(t, authRS.privateKey)
|
||||||
|
assert.NotNil(t, authRS.publicKey)
|
||||||
|
|
||||||
|
// Test RS256 with public key only
|
||||||
|
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "RS256", authPub.algorithm)
|
||||||
|
assert.Nil(t, authPub.privateKey)
|
||||||
|
assert.NotNil(t, authPub.publicKey)
|
||||||
|
|
||||||
|
// Test invalid algorithm
|
||||||
|
_, err = NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"), "INVALID")
|
||||||
|
assert.Equal(t, ErrInvalidAlgorithm, err)
|
||||||
|
|
||||||
|
// Test invalid key type for HS256
|
||||||
|
_, err = NewAuthenticator(privateKey, "HS256")
|
||||||
|
assert.Equal(t, ErrInvalidKeyType, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterfaceCompliance(t *testing.T) {
|
||||||
|
// Verify Authenticator implements AuthenticatorInterface
|
||||||
|
auth, _ := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||||
|
|
||||||
|
var _ AuthenticatorInterface = auth
|
||||||
|
|
||||||
|
// Test interface methods work
|
||||||
|
hash, err := auth.HashPassword("testpass123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = auth.VerifyPassword("testpass123", hash)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
token, err := auth.GenerateToken("user1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userID, _, err := auth.ValidateToken(token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "user1", userID)
|
||||||
|
}
|
||||||
100
error.go
Normal file
100
error.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
// FILE: auth/errors.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Base authentication errors
|
||||||
|
var (
|
||||||
|
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||||
|
ErrWeakPassword = errors.New("password must be at least 8 characters")
|
||||||
|
ErrInvalidAlgorithm = errors.New("invalid algorithm")
|
||||||
|
ErrInvalidKeyType = errors.New("invalid key type for algorithm")
|
||||||
|
)
|
||||||
|
|
||||||
|
// JWT-specific errors
|
||||||
|
var (
|
||||||
|
ErrTokenMalformed = errors.New("token: malformed structure")
|
||||||
|
ErrTokenExpired = errors.New("token: expired")
|
||||||
|
ErrTokenNotYetValid = errors.New("token: not yet valid")
|
||||||
|
ErrTokenInvalidSignature = errors.New("token: invalid signature")
|
||||||
|
ErrTokenAlgorithmMismatch = errors.New("token: algorithm mismatch")
|
||||||
|
ErrTokenMissingClaim = errors.New("token: missing required claim")
|
||||||
|
ErrTokenInvalidHeader = errors.New("token: invalid header encoding")
|
||||||
|
ErrTokenInvalidClaims = errors.New("token: invalid claims encoding")
|
||||||
|
ErrTokenInvalidJSON = errors.New("token: malformed JSON")
|
||||||
|
ErrTokenEmptyUserID = errors.New("token: empty user ID")
|
||||||
|
ErrTokenNoPrivateKey = errors.New("token: private key required for signing")
|
||||||
|
ErrTokenNoPublicKey = errors.New("token: public key required for verification")
|
||||||
|
)
|
||||||
|
|
||||||
|
// JWT secret errors
|
||||||
|
var (
|
||||||
|
ErrSecretTooShort = errors.New("JWT secret must be at least 32 bytes")
|
||||||
|
)
|
||||||
|
|
||||||
|
// RSA key parsing errors
|
||||||
|
var (
|
||||||
|
ErrRSAInvalidPEM = errors.New("rsa: failed to parse PEM block")
|
||||||
|
ErrRSAInvalidPrivateKey = errors.New("rsa: invalid private key format")
|
||||||
|
ErrRSAInvalidPublicKey = errors.New("rsa: invalid public key format")
|
||||||
|
ErrRSANotPublicKey = errors.New("rsa: not an RSA public key")
|
||||||
|
)
|
||||||
|
|
||||||
|
// PHC format errors
|
||||||
|
var (
|
||||||
|
ErrPHCInvalidFormat = errors.New("phc: invalid format")
|
||||||
|
ErrPHCInvalidSalt = errors.New("phc: invalid salt encoding")
|
||||||
|
ErrPHCInvalidHash = errors.New("phc: invalid hash encoding")
|
||||||
|
)
|
||||||
|
|
||||||
|
// SCRAM-specific errors
|
||||||
|
var (
|
||||||
|
ErrSCRAMInvalidNonce = errors.New("scram: invalid nonce or expired handshake")
|
||||||
|
ErrSCRAMTimeout = errors.New("scram: handshake timeout")
|
||||||
|
ErrSCRAMVerifyInProgress = errors.New("scram: verification already in progress")
|
||||||
|
ErrSCRAMInvalidProof = errors.New("scram: invalid proof encoding")
|
||||||
|
ErrSCRAMInvalidProofLen = errors.New("scram: invalid proof length")
|
||||||
|
ErrSCRAMServerAuthFailed = errors.New("scram: server authentication failed")
|
||||||
|
ErrSCRAMInvalidState = errors.New("scram: invalid handshake state")
|
||||||
|
ErrSCRAMInvalidSalt = errors.New("scram: invalid salt encoding")
|
||||||
|
ErrSCRAMZeroParams = errors.New("scram: invalid Argon2 parameters")
|
||||||
|
ErrSCRAMSaltTooShort = errors.New("scram: salt must be at least 16 bytes")
|
||||||
|
ErrSCRAMNonceGenFailed = errors.New("scram: failed to generate nonce")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Credential import/export errors
|
||||||
|
var (
|
||||||
|
ErrCredMissingUsername = errors.New("credential: missing username")
|
||||||
|
ErrCredMissingSalt = errors.New("credential: missing salt")
|
||||||
|
ErrCredInvalidSalt = errors.New("credential: invalid salt encoding")
|
||||||
|
ErrCredMissingTime = errors.New("credential: missing argon_time")
|
||||||
|
ErrCredMissingMemory = errors.New("credential: missing argon_memory")
|
||||||
|
ErrCredMissingThreads = errors.New("credential: missing argon_threads")
|
||||||
|
ErrCredMissingStoredKey = errors.New("credential: missing stored_key")
|
||||||
|
ErrCredInvalidStoredKey = errors.New("credential: invalid stored_key encoding")
|
||||||
|
ErrCredMissingServerKey = errors.New("credential: missing server_key")
|
||||||
|
ErrCredInvalidServerKey = errors.New("credential: invalid server_key encoding")
|
||||||
|
ErrCredInvalidType = fmt.Errorf("credential: invalid type for field")
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTP auth parsing errors
|
||||||
|
var (
|
||||||
|
ErrAuthInvalidBasicFormat = errors.New("auth: invalid Basic auth format")
|
||||||
|
ErrAuthInvalidBasicEncoding = errors.New("auth: invalid Basic auth base64 encoding")
|
||||||
|
ErrAuthInvalidBasicCreds = errors.New("auth: invalid Basic auth credentials format")
|
||||||
|
ErrAuthInvalidBearerFormat = errors.New("auth: invalid Bearer auth format")
|
||||||
|
ErrAuthEmptyBearerToken = errors.New("auth: empty Bearer token")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Salt generation errors
|
||||||
|
var (
|
||||||
|
ErrSaltGenerationFailed = errors.New("failed to generate salt")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Key generation errors
|
||||||
|
var (
|
||||||
|
ErrRSAKeyGenFailed = errors.New("failed to generate RSA key")
|
||||||
|
)
|
||||||
15
go.mod
Normal file
15
go.mod
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
module auth
|
||||||
|
|
||||||
|
go 1.25.3
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
|
golang.org/x/crypto v0.43.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
golang.org/x/sys v0.37.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
14
go.sum
Normal file
14
go.sum
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||||
|
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||||
|
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||||
|
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
61
http.go
Normal file
61
http.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
// FILE: auth/http.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseBasicAuth extracts username/password from Basic auth header
|
||||||
|
func ParseBasicAuth(header string) (username, password string, err error) {
|
||||||
|
const prefix = "Basic "
|
||||||
|
if !strings.HasPrefix(header, prefix) {
|
||||||
|
return "", "", ErrAuthInvalidBasicFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded := strings.TrimPrefix(header, prefix)
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", ErrAuthInvalidBasicEncoding
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials := string(decoded)
|
||||||
|
idx := strings.IndexByte(credentials, ':')
|
||||||
|
if idx < 0 {
|
||||||
|
return "", "", ErrAuthInvalidBasicCreds
|
||||||
|
}
|
||||||
|
|
||||||
|
return credentials[:idx], credentials[idx+1:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseBearerToken extracts token from Bearer auth header
|
||||||
|
func ParseBearerToken(header string) (token string, err error) {
|
||||||
|
const prefix = "Bearer "
|
||||||
|
if !strings.HasPrefix(header, prefix) {
|
||||||
|
return "", ErrAuthInvalidBearerFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
token = strings.TrimPrefix(header, prefix)
|
||||||
|
if token == "" {
|
||||||
|
return "", ErrAuthEmptyBearerToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractAuthType returns authentication type from header
|
||||||
|
func ExtractAuthType(header string) string {
|
||||||
|
if strings.HasPrefix(header, "Basic ") {
|
||||||
|
return "Basic"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(header, "Bearer ") {
|
||||||
|
return "Bearer"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract first word as auth type
|
||||||
|
idx := strings.IndexByte(header, ' ')
|
||||||
|
if idx > 0 {
|
||||||
|
return header[:idx]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
54
http_test.go
Normal file
54
http_test.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
// FILE: auth/http_test.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHTTPAuthParsing(t *testing.T) {
|
||||||
|
// Test Basic Auth
|
||||||
|
basicHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass"))
|
||||||
|
username, password, err := ParseBasicAuth(basicHeader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "user", username)
|
||||||
|
assert.Equal(t, "pass", password)
|
||||||
|
|
||||||
|
// Test Bearer Token
|
||||||
|
bearerHeader := "Bearer test-token-xyz"
|
||||||
|
token, err := ParseBearerToken(bearerHeader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "test-token-xyz", token)
|
||||||
|
|
||||||
|
// Test ExtractAuthType
|
||||||
|
assert.Equal(t, "Basic", ExtractAuthType(basicHeader))
|
||||||
|
assert.Equal(t, "Bearer", ExtractAuthType(bearerHeader))
|
||||||
|
assert.Equal(t, "Custom", ExtractAuthType("Custom somedata"))
|
||||||
|
assert.Equal(t, "", ExtractAuthType("InvalidHeader"))
|
||||||
|
|
||||||
|
// Test invalid formats
|
||||||
|
_, _, err = ParseBasicAuth("Invalid header")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, ErrAuthInvalidBasicFormat, err)
|
||||||
|
|
||||||
|
_, err = ParseBearerToken("Invalid header")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, ErrAuthInvalidBearerFormat, err)
|
||||||
|
|
||||||
|
// Test malformed Basic auth
|
||||||
|
_, _, err = ParseBasicAuth("Basic not-base64!")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, ErrAuthInvalidBasicEncoding, err)
|
||||||
|
|
||||||
|
_, _, err = ParseBasicAuth("Basic " + base64.StdEncoding.EncodeToString([]byte("no-colon")))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, ErrAuthInvalidBasicCreds, err)
|
||||||
|
|
||||||
|
// Test empty Bearer token
|
||||||
|
_, err = ParseBearerToken("Bearer ")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, ErrAuthEmptyBearerToken, err)
|
||||||
|
}
|
||||||
17
interface.go
Normal file
17
interface.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
// FILE: auth/interface.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
// AuthenticatorInterface defines the authentication operations
|
||||||
|
type AuthenticatorInterface interface {
|
||||||
|
HashPassword(password string) (hash string, err error)
|
||||||
|
VerifyPassword(password, hash string) (err error)
|
||||||
|
GenerateToken(userID string, claims map[string]any) (token string, err error)
|
||||||
|
ValidateToken(token string) (userID string, claims map[string]any, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenValidator validates bearer tokens
|
||||||
|
type TokenValidator interface {
|
||||||
|
ValidateToken(token string) (valid bool)
|
||||||
|
AddToken(token string)
|
||||||
|
RemoveToken(token string)
|
||||||
|
}
|
||||||
205
jwt.go
Normal file
205
jwt.go
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
// FILE: auth/jwt.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateToken creates a JWT token with user claims
|
||||||
|
func (a *Authenticator) GenerateToken(userID string, claims map[string]any) (string, error) {
|
||||||
|
if userID == "" {
|
||||||
|
return "", ErrTokenEmptyUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.algorithm == "RS256" && a.privateKey == nil {
|
||||||
|
return "", ErrTokenNoPrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build JWT claims
|
||||||
|
now := time.Now()
|
||||||
|
jwtClaims := map[string]any{
|
||||||
|
"sub": userID,
|
||||||
|
"iat": now.Unix(),
|
||||||
|
"exp": now.Add(7 * 24 * time.Hour).Unix(), // 7 days expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reserved claims that cannot be overridden
|
||||||
|
reservedClaims := map[string]bool{
|
||||||
|
"sub": true, "iat": true, "exp": true, "nbf": true,
|
||||||
|
"iss": true, "aud": true, "jti": true, "typ": true,
|
||||||
|
"alg": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge custom claims
|
||||||
|
for k, v := range claims {
|
||||||
|
if !reservedClaims[k] {
|
||||||
|
jwtClaims[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create JWT header
|
||||||
|
header := map[string]any{
|
||||||
|
"alg": a.algorithm,
|
||||||
|
"typ": "JWT",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode header and payload
|
||||||
|
headerJSON, _ := json.Marshal(header)
|
||||||
|
claimsJSON, _ := json.Marshal(jwtClaims)
|
||||||
|
|
||||||
|
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||||
|
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||||
|
|
||||||
|
// Create signature
|
||||||
|
signingInput := headerB64 + "." + claimsB64
|
||||||
|
var signature string
|
||||||
|
|
||||||
|
switch a.algorithm {
|
||||||
|
case "HS256":
|
||||||
|
h := hmac.New(sha256.New, a.jwtSecret)
|
||||||
|
h.Write([]byte(signingInput))
|
||||||
|
signature = base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
|
||||||
|
case "RS256":
|
||||||
|
hashed := sha256.Sum256([]byte(signingInput))
|
||||||
|
sig, err := rsa.SignPKCS1v15(rand.Reader, a.privateKey, crypto.SHA256, hashed[:])
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||||
|
}
|
||||||
|
signature = base64.RawURLEncoding.EncodeToString(sig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine to form JWT
|
||||||
|
token := signingInput + "." + signature
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken verifies JWT and returns userID and claims
|
||||||
|
func (a *Authenticator) ValidateToken(token string) (string, map[string]any, error) {
|
||||||
|
// Split token
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return "", nil, ErrTokenMalformed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode header to check algorithm
|
||||||
|
headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, ErrTokenInvalidHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
var header map[string]any
|
||||||
|
if err = json.Unmarshal(headerJSON, &header); err != nil {
|
||||||
|
return "", nil, ErrTokenInvalidJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify algorithm matches
|
||||||
|
if alg, ok := header["alg"].(string); !ok || alg != a.algorithm {
|
||||||
|
return "", nil, ErrTokenAlgorithmMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify signature
|
||||||
|
signingInput := parts[0] + "." + parts[1]
|
||||||
|
|
||||||
|
switch a.algorithm {
|
||||||
|
case "HS256":
|
||||||
|
h := hmac.New(sha256.New, a.jwtSecret)
|
||||||
|
h.Write([]byte(signingInput))
|
||||||
|
expectedSig := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
|
||||||
|
if subtle.ConstantTimeCompare([]byte(parts[2]), []byte(expectedSig)) != 1 {
|
||||||
|
return "", nil, ErrTokenInvalidSignature
|
||||||
|
}
|
||||||
|
|
||||||
|
case "RS256":
|
||||||
|
if a.publicKey == nil {
|
||||||
|
return "", nil, ErrTokenNoPublicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
sig, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, ErrTokenInvalidSignature
|
||||||
|
}
|
||||||
|
|
||||||
|
hashed := sha256.Sum256([]byte(signingInput))
|
||||||
|
if err := rsa.VerifyPKCS1v15(a.publicKey, crypto.SHA256, hashed[:], sig); err != nil {
|
||||||
|
return "", nil, ErrTokenInvalidSignature
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode claims
|
||||||
|
claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, ErrTokenInvalidClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims map[string]any
|
||||||
|
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
|
||||||
|
return "", nil, ErrTokenInvalidJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check expiration
|
||||||
|
if exp, ok := claims["exp"].(float64); ok {
|
||||||
|
if time.Now().Unix() > int64(exp) {
|
||||||
|
return "", nil, ErrTokenExpired
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check not before
|
||||||
|
if nbf, ok := claims["nbf"].(float64); ok {
|
||||||
|
if time.Now().Unix() < int64(nbf) {
|
||||||
|
return "", nil, ErrTokenNotYetValid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract userID
|
||||||
|
userID, ok := claims["sub"].(string)
|
||||||
|
if !ok {
|
||||||
|
return "", nil, ErrTokenMissingClaim
|
||||||
|
}
|
||||||
|
|
||||||
|
return userID, claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRSAPrivateKey parses PEM encoded RSA private key
|
||||||
|
func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
|
||||||
|
block, _ := pem.Decode(pemBytes)
|
||||||
|
if block == nil {
|
||||||
|
return nil, ErrRSAInvalidPEM
|
||||||
|
}
|
||||||
|
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrRSAInvalidPrivateKey
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRSAPublicKey parses PEM encoded RSA public key
|
||||||
|
func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) {
|
||||||
|
block, _ := pem.Decode(pemBytes)
|
||||||
|
if block == nil {
|
||||||
|
return nil, ErrRSAInvalidPEM
|
||||||
|
}
|
||||||
|
pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrRSAInvalidPublicKey
|
||||||
|
}
|
||||||
|
pubKey, ok := pubInterface.(*rsa.PublicKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrRSANotPublicKey
|
||||||
|
}
|
||||||
|
return pubKey, nil
|
||||||
|
}
|
||||||
176
jwt_test.go
Normal file
176
jwt_test.go
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
// FILE: auth/jwt_test.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJWTHS256(t *testing.T) {
|
||||||
|
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userID := "user123"
|
||||||
|
claims := map[string]any{
|
||||||
|
"email": "test@example.com",
|
||||||
|
"role": "admin",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate token
|
||||||
|
token, err := auth.GenerateToken(userID, claims)
|
||||||
|
require.NoError(t, err, "Failed to generate token")
|
||||||
|
assert.NotEmpty(t, token)
|
||||||
|
|
||||||
|
// Validate token
|
||||||
|
extractedUserID, extractedClaims, err := auth.ValidateToken(token)
|
||||||
|
require.NoError(t, err, "Failed to validate token")
|
||||||
|
|
||||||
|
assert.Equal(t, userID, extractedUserID)
|
||||||
|
assert.Equal(t, "test@example.com", extractedClaims["email"])
|
||||||
|
assert.Equal(t, "admin", extractedClaims["role"])
|
||||||
|
|
||||||
|
// Test invalid token
|
||||||
|
_, _, err = auth.ValidateToken("invalid.token.here")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, ErrTokenInvalidJSON))
|
||||||
|
|
||||||
|
// Test tampered token
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
require.Len(t, parts, 3, "JWT should have 3 parts")
|
||||||
|
|
||||||
|
tampered := parts[0] + "." + parts[1] + ".invalidsignature"
|
||||||
|
_, _, err = auth.ValidateToken(tampered)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, ErrTokenInvalidSignature))
|
||||||
|
|
||||||
|
// Test reserved claims cannot be overridden
|
||||||
|
overrideClaims := map[string]any{
|
||||||
|
"sub": "override",
|
||||||
|
"iat": 12345,
|
||||||
|
"exp": 67890,
|
||||||
|
"nbf": 11111,
|
||||||
|
"iss": "attacker",
|
||||||
|
"aud": "victim",
|
||||||
|
"jti": "fake",
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err = auth.GenerateToken(userID, overrideClaims)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
extractedUserID, extractedClaims, err = auth.ValidateToken(token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, userID, extractedUserID, "UserID should not be overridden")
|
||||||
|
assert.NotEqual(t, 12345, extractedClaims["iat"], "iat should not be overridden")
|
||||||
|
assert.NotEqual(t, 67890, extractedClaims["exp"], "exp should not be overridden")
|
||||||
|
assert.NotContains(t, extractedClaims, "nbf", "nbf should not be added from user claims")
|
||||||
|
assert.NotContains(t, extractedClaims, "iss", "iss should not be added from user claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTRS256(t *testing.T) {
|
||||||
|
// Generate RSA key pair
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test with private key (can sign and verify)
|
||||||
|
authPriv, err := NewAuthenticator(privateKey, "RS256")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userID := "user456"
|
||||||
|
claims := map[string]any{
|
||||||
|
"email": "rs256@example.com",
|
||||||
|
"scope": "read:all",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate token
|
||||||
|
token, err := authPriv.GenerateToken(userID, claims)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, token)
|
||||||
|
|
||||||
|
// Validate token with private key auth (has public key too)
|
||||||
|
extractedUserID, extractedClaims, err := authPriv.ValidateToken(token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, userID, extractedUserID)
|
||||||
|
assert.Equal(t, "rs256@example.com", extractedClaims["email"])
|
||||||
|
assert.Equal(t, "read:all", extractedClaims["scope"])
|
||||||
|
|
||||||
|
// Test with public key only (can only verify)
|
||||||
|
authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Should be able to validate token
|
||||||
|
extractedUserID, extractedClaims, err = authPub.ValidateToken(token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, userID, extractedUserID)
|
||||||
|
|
||||||
|
// Should not be able to generate token
|
||||||
|
_, err = authPub.GenerateToken(userID, claims)
|
||||||
|
assert.Error(t, err, "Public key only auth should not generate tokens")
|
||||||
|
assert.Equal(t, ErrTokenNoPrivateKey, err)
|
||||||
|
|
||||||
|
// Test algorithm mismatch
|
||||||
|
authHS256, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, _, err = authHS256.ValidateToken(token)
|
||||||
|
assert.Error(t, err, "HS256 auth should not validate RS256 token")
|
||||||
|
// assert.True(t, errors.Is(err, ErrInvalidToken))
|
||||||
|
assert.True(t, errors.Is(err, ErrTokenAlgorithmMismatch))
|
||||||
|
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpiredToken(t *testing.T) {
|
||||||
|
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userID := "user123"
|
||||||
|
|
||||||
|
// Generate normal token (should have 7 days expiry)
|
||||||
|
token, err := auth.GenerateToken(userID, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, extractedClaims, err := auth.ValidateToken(token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check expiry is in future (approximately 7 days)
|
||||||
|
expiry := extractedClaims["exp"].(float64)
|
||||||
|
now := time.Now().Unix()
|
||||||
|
|
||||||
|
assert.Greater(t, expiry, float64(now), "Token expiry should be in future")
|
||||||
|
assert.InDelta(t, expiry, float64(now+7*24*60*60), 10,
|
||||||
|
"Token expiry should be approximately 7 days from now")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCorruptJWTParts(t *testing.T) {
|
||||||
|
auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test with missing parts
|
||||||
|
_, _, err = auth.ValidateToken("only.two")
|
||||||
|
assert.True(t, errors.Is(err, ErrTokenMalformed))
|
||||||
|
|
||||||
|
// Test with invalid header encoding
|
||||||
|
_, _, err = auth.ValidateToken("not-base64!.valid.valid")
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// Test with invalid claims encoding
|
||||||
|
validHeader := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||||
|
_, _, err = auth.ValidateToken(validHeader + ".not-base64!.valid")
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// Test with invalid JSON in claims
|
||||||
|
invalidJSON := base64.RawURLEncoding.EncodeToString([]byte("{invalid json"))
|
||||||
|
_, _, err = auth.ValidateToken(validHeader + "." + invalidJSON + ".signature")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
498
scram.go
Normal file
498
scram.go
Normal file
@ -0,0 +1,498 @@
|
|||||||
|
// FILE: auth/scram.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/argon2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SCRAM-SHA256 implementation
|
||||||
|
|
||||||
|
// Credential stores SCRAM authentication data
|
||||||
|
type Credential struct {
|
||||||
|
Username string
|
||||||
|
Salt []byte
|
||||||
|
ArgonTime uint32
|
||||||
|
ArgonMemory uint32
|
||||||
|
ArgonThreads uint8
|
||||||
|
StoredKey []byte // SHA256(ClientKey)
|
||||||
|
ServerKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Export returns credential as config-friendly map
|
||||||
|
func (c *Credential) Export() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"username": c.Username,
|
||||||
|
"salt": base64.StdEncoding.EncodeToString(c.Salt),
|
||||||
|
"argon_time": c.ArgonTime,
|
||||||
|
"argon_memory": c.ArgonMemory,
|
||||||
|
"argon_threads": c.ArgonThreads,
|
||||||
|
"stored_key": base64.StdEncoding.EncodeToString(c.StoredKey),
|
||||||
|
"server_key": base64.StdEncoding.EncodeToString(c.ServerKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImportCredential creates credential from map
|
||||||
|
func ImportCredential(data map[string]any) (*Credential, error) {
|
||||||
|
username, ok := data["username"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrCredMissingUsername
|
||||||
|
}
|
||||||
|
|
||||||
|
saltStr, ok := data["salt"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrCredMissingSalt
|
||||||
|
}
|
||||||
|
salt, err := base64.StdEncoding.DecodeString(saltStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrCredInvalidSalt
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle both float64 (from JSON) and int types
|
||||||
|
getUint32 := func(key string) (uint32, error) {
|
||||||
|
val, ok := data[key]
|
||||||
|
if !ok {
|
||||||
|
switch key {
|
||||||
|
case "argon_time":
|
||||||
|
return 0, ErrCredMissingTime
|
||||||
|
case "argon_memory":
|
||||||
|
return 0, ErrCredMissingMemory
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("missing %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch v := val.(type) {
|
||||||
|
case float64:
|
||||||
|
return uint32(v), nil
|
||||||
|
case int:
|
||||||
|
return uint32(v), nil
|
||||||
|
case uint32:
|
||||||
|
return v, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("invalid type for %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
argonTime, err := getUint32("argon_time")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
argonMemory, err := getUint32("argon_memory")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
threadsVal, ok := data["argon_threads"]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrCredMissingThreads
|
||||||
|
}
|
||||||
|
var argonThreads uint8
|
||||||
|
switch v := threadsVal.(type) {
|
||||||
|
case float64:
|
||||||
|
argonThreads = uint8(v)
|
||||||
|
case int:
|
||||||
|
argonThreads = uint8(v)
|
||||||
|
case uint8:
|
||||||
|
argonThreads = v
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%w: argon_threads", ErrCredInvalidType)
|
||||||
|
}
|
||||||
|
|
||||||
|
storedKeyStr, ok := data["stored_key"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrCredMissingStoredKey
|
||||||
|
}
|
||||||
|
storedKey, err := base64.StdEncoding.DecodeString(storedKeyStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %v", ErrCredInvalidStoredKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverKeyStr, ok := data["server_key"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrCredMissingServerKey
|
||||||
|
}
|
||||||
|
serverKey, err := base64.StdEncoding.DecodeString(serverKeyStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %v", ErrCredInvalidServerKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Credential{
|
||||||
|
Username: username,
|
||||||
|
Salt: salt,
|
||||||
|
ArgonTime: argonTime,
|
||||||
|
ArgonMemory: argonMemory,
|
||||||
|
ArgonThreads: argonThreads,
|
||||||
|
StoredKey: storedKey,
|
||||||
|
ServerKey: serverKey,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveCredential creates SCRAM credential from password
|
||||||
|
func DeriveCredential(username, password string, salt []byte, time, memory uint32, threads uint8) (*Credential, error) {
|
||||||
|
if len(salt) < 16 {
|
||||||
|
return nil, ErrSCRAMSaltTooShort
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive salted password using Argon2id
|
||||||
|
saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, DefaultArgonKeyLen)
|
||||||
|
|
||||||
|
// Derive keys
|
||||||
|
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||||
|
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||||
|
storedKey := sha256.Sum256(clientKey)
|
||||||
|
|
||||||
|
return &Credential{
|
||||||
|
Username: username,
|
||||||
|
Salt: salt,
|
||||||
|
ArgonTime: time,
|
||||||
|
ArgonMemory: memory,
|
||||||
|
ArgonThreads: threads,
|
||||||
|
StoredKey: storedKey[:],
|
||||||
|
ServerKey: serverKey,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScramServer handles server-side SCRAM authentication
|
||||||
|
type ScramServer struct {
|
||||||
|
credentials map[string]*Credential
|
||||||
|
handshakes map[string]*HandshakeState
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandshakeState tracks ongoing authentication
|
||||||
|
type HandshakeState struct {
|
||||||
|
Username string
|
||||||
|
ClientNonce string
|
||||||
|
ServerNonce string
|
||||||
|
FullNonce string
|
||||||
|
Credential *Credential
|
||||||
|
CreatedAt time.Time
|
||||||
|
verifying int32 // Atomic flag to prevent race during verification
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewScramServer creates SCRAM server
|
||||||
|
func NewScramServer() *ScramServer {
|
||||||
|
return &ScramServer{
|
||||||
|
credentials: make(map[string]*Credential),
|
||||||
|
handshakes: make(map[string]*HandshakeState),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCredential registers user credential
|
||||||
|
func (s *ScramServer) AddCredential(cred *Credential) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.credentials[cred.Username] = cred
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessClientFirstMessage processes initial auth request
|
||||||
|
func (s *ScramServer) ProcessClientFirstMessage(username, clientNonce string) (ServerFirstMessage, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// Check if user exists
|
||||||
|
cred, exists := s.credentials[username]
|
||||||
|
if !exists {
|
||||||
|
// Prevent user enumeration - still generate response
|
||||||
|
salt := make([]byte, 16)
|
||||||
|
rand.Read(salt)
|
||||||
|
serverNonce := generateNonce()
|
||||||
|
|
||||||
|
return ServerFirstMessage{
|
||||||
|
FullNonce: clientNonce + serverNonce,
|
||||||
|
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||||
|
ArgonTime: DefaultArgonTime,
|
||||||
|
ArgonMemory: DefaultArgonMemory,
|
||||||
|
ArgonThreads: DefaultArgonThreads,
|
||||||
|
}, ErrInvalidCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate server nonce
|
||||||
|
serverNonce := generateNonce()
|
||||||
|
fullNonce := clientNonce + serverNonce
|
||||||
|
|
||||||
|
// Store handshake state
|
||||||
|
state := &HandshakeState{
|
||||||
|
Username: username,
|
||||||
|
ClientNonce: clientNonce,
|
||||||
|
ServerNonce: serverNonce,
|
||||||
|
FullNonce: fullNonce,
|
||||||
|
Credential: cred,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
verifying: 0,
|
||||||
|
}
|
||||||
|
s.handshakes[fullNonce] = state
|
||||||
|
|
||||||
|
// Cleanup old handshakes
|
||||||
|
s.cleanupHandshakes()
|
||||||
|
|
||||||
|
return ServerFirstMessage{
|
||||||
|
FullNonce: fullNonce,
|
||||||
|
Salt: base64.StdEncoding.EncodeToString(cred.Salt),
|
||||||
|
ArgonTime: cred.ArgonTime,
|
||||||
|
ArgonMemory: cred.ArgonMemory,
|
||||||
|
ArgonThreads: cred.ArgonThreads,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessClientFinalMessage verifies client proof
|
||||||
|
func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (ServerFinalMessage, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
state, exists := s.handshakes[fullNonce]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return ServerFinalMessage{}, ErrSCRAMInvalidNonce
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark as verifying to prevent deletion race
|
||||||
|
if !atomic.CompareAndSwapInt32(&state.verifying, 0, 1) {
|
||||||
|
return ServerFinalMessage{}, ErrSCRAMVerifyInProgress
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
atomic.StoreInt32(&state.verifying, 0)
|
||||||
|
// Safe to delete after verification completes
|
||||||
|
s.mu.Lock()
|
||||||
|
delete(s.handshakes, fullNonce)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Check timeout
|
||||||
|
if time.Since(state.CreatedAt) > 60*time.Second {
|
||||||
|
return ServerFinalMessage{}, ErrSCRAMTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode client proof
|
||||||
|
clientProofBytes, err := base64.StdEncoding.DecodeString(clientProof)
|
||||||
|
if err != nil {
|
||||||
|
return ServerFinalMessage{}, ErrSCRAMInvalidProof
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build auth message
|
||||||
|
clientFirstBare := fmt.Sprintf("u=%s,n=%s", state.Username, state.ClientNonce)
|
||||||
|
serverFirst := ServerFirstMessage{
|
||||||
|
FullNonce: state.FullNonce,
|
||||||
|
Salt: base64.StdEncoding.EncodeToString(state.Credential.Salt),
|
||||||
|
ArgonTime: state.Credential.ArgonTime,
|
||||||
|
ArgonMemory: state.Credential.ArgonMemory,
|
||||||
|
ArgonThreads: state.Credential.ArgonThreads,
|
||||||
|
}
|
||||||
|
clientFinalBare := fmt.Sprintf("r=%s", fullNonce)
|
||||||
|
authMessage := clientFirstBare + "," + serverFirst.Marshal() + "," + clientFinalBare
|
||||||
|
|
||||||
|
// Compute client signature
|
||||||
|
clientSignature := computeHMAC(state.Credential.StoredKey, []byte(authMessage))
|
||||||
|
|
||||||
|
// XOR to get ClientKey
|
||||||
|
if len(clientProofBytes) != len(clientSignature) {
|
||||||
|
return ServerFinalMessage{}, ErrSCRAMInvalidProofLen
|
||||||
|
}
|
||||||
|
clientKey := xorBytes(clientProofBytes, clientSignature)
|
||||||
|
|
||||||
|
// Verify by computing StoredKey
|
||||||
|
computedStoredKey := sha256.Sum256(clientKey)
|
||||||
|
if subtle.ConstantTimeCompare(computedStoredKey[:], state.Credential.StoredKey) != 1 {
|
||||||
|
return ServerFinalMessage{}, ErrInvalidCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate server signature for mutual auth
|
||||||
|
serverSignature := computeHMAC(state.Credential.ServerKey, []byte(authMessage))
|
||||||
|
|
||||||
|
return ServerFinalMessage{
|
||||||
|
ServerSignature: base64.StdEncoding.EncodeToString(serverSignature),
|
||||||
|
Username: state.Username,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScramServer) cleanupHandshakes() {
|
||||||
|
cutoff := time.Now().Add(-60 * time.Second)
|
||||||
|
for nonce, state := range s.handshakes {
|
||||||
|
if state.CreatedAt.Before(cutoff) && atomic.LoadInt32(&state.verifying) == 0 {
|
||||||
|
delete(s.handshakes, nonce)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScramClient handles client-side SCRAM authentication
|
||||||
|
type ScramClient struct {
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
clientNonce string
|
||||||
|
serverFirst *ServerFirstMessage
|
||||||
|
authMessage string
|
||||||
|
serverKey []byte
|
||||||
|
startTime time.Time // Track handshake start
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewScramClient creates SCRAM client
|
||||||
|
func NewScramClient(username, password string) *ScramClient {
|
||||||
|
return &ScramClient{
|
||||||
|
Username: username,
|
||||||
|
Password: password,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartAuthentication generates initial client message
|
||||||
|
func (c *ScramClient) StartAuthentication() (ClientFirstRequest, error) {
|
||||||
|
c.startTime = time.Now()
|
||||||
|
|
||||||
|
// Generate client nonce
|
||||||
|
nonce := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
return ClientFirstRequest{}, ErrSCRAMNonceGenFailed
|
||||||
|
}
|
||||||
|
c.clientNonce = base64.StdEncoding.EncodeToString(nonce)
|
||||||
|
|
||||||
|
return ClientFirstRequest{
|
||||||
|
Username: c.Username,
|
||||||
|
ClientNonce: c.clientNonce,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessServerFirstMessage handles server challenge
|
||||||
|
func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientFinalRequest, error) {
|
||||||
|
// Check timeout (30 seconds)
|
||||||
|
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second {
|
||||||
|
return ClientFinalRequest{}, ErrSCRAMTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
c.serverFirst = &msg
|
||||||
|
|
||||||
|
// Handle enumeration prevention - server may send fake response
|
||||||
|
// We still process it normally and let verification fail later
|
||||||
|
|
||||||
|
// Decode salt
|
||||||
|
salt, err := base64.StdEncoding.DecodeString(msg.Salt)
|
||||||
|
if err != nil {
|
||||||
|
return ClientFinalRequest{}, ErrSCRAMInvalidSalt
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate parameters
|
||||||
|
if msg.ArgonTime == 0 || msg.ArgonMemory == 0 || msg.ArgonThreads == 0 {
|
||||||
|
return ClientFinalRequest{}, ErrSCRAMZeroParams
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive keys using Argon2id
|
||||||
|
saltedPassword := argon2.IDKey([]byte(c.Password), salt, msg.ArgonTime, msg.ArgonMemory, msg.ArgonThreads, 32)
|
||||||
|
|
||||||
|
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||||
|
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||||
|
storedKey := sha256.Sum256(clientKey)
|
||||||
|
|
||||||
|
// Build auth message
|
||||||
|
clientFirstBare := fmt.Sprintf("u=%s,n=%s", c.Username, c.clientNonce)
|
||||||
|
clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce)
|
||||||
|
c.authMessage = clientFirstBare + "," + msg.Marshal() + "," + clientFinalBare
|
||||||
|
|
||||||
|
// Compute client proof
|
||||||
|
clientSignature := computeHMAC(storedKey[:], []byte(c.authMessage))
|
||||||
|
clientProof := xorBytes(clientKey, clientSignature)
|
||||||
|
|
||||||
|
// Store server key for verification
|
||||||
|
c.serverKey = serverKey
|
||||||
|
|
||||||
|
return ClientFinalRequest{
|
||||||
|
FullNonce: msg.FullNonce,
|
||||||
|
ClientProof: base64.StdEncoding.EncodeToString(clientProof),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyServerFinalMessage validates server signature
|
||||||
|
func (c *ScramClient) VerifyServerFinalMessage(msg ServerFinalMessage) error {
|
||||||
|
// Check timeout
|
||||||
|
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second {
|
||||||
|
return ErrSCRAMTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.authMessage == "" || c.serverKey == nil {
|
||||||
|
return ErrSCRAMInvalidState
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute expected server signature
|
||||||
|
expectedSig := computeHMAC(c.serverKey, []byte(c.authMessage))
|
||||||
|
|
||||||
|
// Decode received signature
|
||||||
|
receivedSig, err := base64.StdEncoding.DecodeString(msg.ServerSignature)
|
||||||
|
if err != nil {
|
||||||
|
return ErrSCRAMServerAuthFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constant-time comparison
|
||||||
|
if subtle.ConstantTimeCompare(expectedSig, receivedSig) != 1 {
|
||||||
|
return ErrSCRAMServerAuthFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset clears client state for retry
|
||||||
|
func (c *ScramClient) Reset() {
|
||||||
|
c.clientNonce = ""
|
||||||
|
c.serverFirst = nil
|
||||||
|
c.authMessage = ""
|
||||||
|
c.serverKey = nil
|
||||||
|
c.startTime = time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SCRAM message types
|
||||||
|
type ClientFirstRequest struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
ClientNonce string `json:"client_nonce"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerFirstMessage struct {
|
||||||
|
FullNonce string `json:"full_nonce"`
|
||||||
|
Salt string `json:"salt"`
|
||||||
|
ArgonTime uint32 `json:"argon_time"`
|
||||||
|
ArgonMemory uint32 `json:"argon_memory"`
|
||||||
|
ArgonThreads uint8 `json:"argon_threads"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s ServerFirstMessage) Marshal() string {
|
||||||
|
return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d",
|
||||||
|
s.FullNonce, s.Salt, s.ArgonTime, s.ArgonMemory, s.ArgonThreads)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientFinalRequest struct {
|
||||||
|
FullNonce string `json:"full_nonce"`
|
||||||
|
ClientProof string `json:"client_proof"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerFinalMessage struct {
|
||||||
|
ServerSignature string `json:"server_signature"`
|
||||||
|
Username string `json:"username,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
func computeHMAC(key, message []byte) []byte {
|
||||||
|
mac := hmac.New(sha256.New, key)
|
||||||
|
mac.Write(message)
|
||||||
|
return mac.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func xorBytes(a, b []byte) []byte {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
panic("xor length mismatch")
|
||||||
|
}
|
||||||
|
result := make([]byte, len(a))
|
||||||
|
for i := range a {
|
||||||
|
result[i] = a[i] ^ b[i]
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateNonce() string {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
rand.Read(b)
|
||||||
|
return base64.StdEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
50
token.go
Normal file
50
token.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
// FILE: auth/token.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/subtle"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SimpleTokenValidator implements in-memory token validation
|
||||||
|
type SimpleTokenValidator struct {
|
||||||
|
tokens map[string]struct{}
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSimpleTokenValidator creates token validator
|
||||||
|
func NewSimpleTokenValidator() *SimpleTokenValidator {
|
||||||
|
return &SimpleTokenValidator{
|
||||||
|
tokens: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken checks if token is valid
|
||||||
|
func (v *SimpleTokenValidator) ValidateToken(token string) bool {
|
||||||
|
v.mu.RLock()
|
||||||
|
defer v.mu.RUnlock()
|
||||||
|
|
||||||
|
// Constant-time comparison for each stored token
|
||||||
|
for storedToken := range v.tokens {
|
||||||
|
if subtle.ConstantTimeEq(int32(len(token)), int32(len(storedToken))) == 1 {
|
||||||
|
if subtle.ConstantTimeCompare([]byte(token), []byte(storedToken)) == 1 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddToken adds token to validator
|
||||||
|
func (v *SimpleTokenValidator) AddToken(token string) {
|
||||||
|
v.mu.Lock()
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
v.tokens[token] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveToken removes token from validator
|
||||||
|
func (v *SimpleTokenValidator) RemoveToken(token string) {
|
||||||
|
v.mu.Lock()
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
delete(v.tokens, token)
|
||||||
|
}
|
||||||
81
token_test.go
Normal file
81
token_test.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
// FILE: auth/token_test.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSimpleTokenValidator(t *testing.T) {
|
||||||
|
validator := NewSimpleTokenValidator()
|
||||||
|
|
||||||
|
token1 := "test-token-123"
|
||||||
|
token2 := "test-token-456"
|
||||||
|
|
||||||
|
// Add tokens
|
||||||
|
validator.AddToken(token1)
|
||||||
|
validator.AddToken(token2)
|
||||||
|
|
||||||
|
// Validate existing tokens
|
||||||
|
assert.True(t, validator.ValidateToken(token1))
|
||||||
|
assert.True(t, validator.ValidateToken(token2))
|
||||||
|
|
||||||
|
// Invalid token
|
||||||
|
assert.False(t, validator.ValidateToken("invalid-token"))
|
||||||
|
|
||||||
|
// Remove token
|
||||||
|
validator.RemoveToken(token1)
|
||||||
|
assert.False(t, validator.ValidateToken(token1))
|
||||||
|
assert.True(t, validator.ValidateToken(token2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentTokenValidator(t *testing.T) {
|
||||||
|
validator := NewSimpleTokenValidator()
|
||||||
|
|
||||||
|
// Add tokens concurrently
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
token := fmt.Sprintf("token-%d", idx)
|
||||||
|
validator.AddToken(token)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Validate concurrently
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
token := fmt.Sprintf("token-%d", idx)
|
||||||
|
assert.True(t, validator.ValidateToken(token))
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Remove concurrently
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
token := fmt.Sprintf("token-%d", idx)
|
||||||
|
validator.RemoveToken(token)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Verify removal
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
token := fmt.Sprintf("token-%d", i)
|
||||||
|
assert.False(t, validator.ValidateToken(token))
|
||||||
|
}
|
||||||
|
for i := 50; i < 100; i++ {
|
||||||
|
token := fmt.Sprintf("token-%d", i)
|
||||||
|
assert.True(t, validator.ValidateToken(token))
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user