From bc1a7603973458dff6cdd31d8c7a6cf2c74d876cae5a8cd09fee7b9a56826dfe Mon Sep 17 00:00:00 2001 From: Lixen Wraith Date: Sun, 2 Nov 2025 13:05:37 -0500 Subject: [PATCH] v0.1.0 initial commit, auth features exatracted from logwisp to be a standalone utility package --- .gitignore | 12 ++ LICENSE | 28 +++ README.md | 38 ++++ argon2.go | 110 +++++++++++ argon2_test.go | 117 ++++++++++++ auth.go | 71 +++++++ auth_test.go | 68 +++++++ error.go | 100 ++++++++++ go.mod | 15 ++ go.sum | 14 ++ http.go | 61 ++++++ http_test.go | 54 ++++++ interface.go | 17 ++ jwt.go | 205 ++++++++++++++++++++ jwt_test.go | 176 +++++++++++++++++ scram.go | 498 +++++++++++++++++++++++++++++++++++++++++++++++++ token.go | 50 +++++ token_test.go | 81 ++++++++ 18 files changed, 1715 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 argon2.go create mode 100644 argon2_test.go create mode 100644 auth.go create mode 100644 auth_test.go create mode 100644 error.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 http.go create mode 100644 http_test.go create mode 100644 interface.go create mode 100644 jwt.go create mode 100644 jwt_test.go create mode 100644 scram.go create mode 100644 token.go create mode 100644 token_test.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4facf1e --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +.idea +data +dev +log +logs +cert +bin +script +build +*.log +*.toml +build.sh diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..c71f04c --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright (c) 2025, Lixen Wraith + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..c82d0e1 --- /dev/null +++ b/README.md @@ -0,0 +1,38 @@ +# Auth Package + +Pluggable authentication utilities for Go applications. + +## Features + +- **Password Hashing**: Argon2id with PHC format +- **JWT**: HS256/RS256 token generation and validation +- **SCRAM-SHA256**: Client/server implementation with Argon2id KDF +- **HTTP Auth**: Basic/Bearer header parsing + +## Usage +```go +// JWT with HS256 +auth, _ := auth.NewAuthenticator([]byte("32-byte-secret-key...")) +token, _ := auth.GenerateToken("user123", map[string]interface{}{"role": "admin"}) +userID, claims, _ := auth.ValidateToken(token) + +// SCRAM authentication +server := auth.NewScramServer() +cred, _ := auth.DeriveCredential("user", "password", salt, 1, 65536, 4) +server.AddCredential(cred) +``` + +## Package Structure + +- `interfaces.go` - Core interfaces +- `jwt.go` - JWT token operations +- `argon2.go` - Password hashing +- `scram.go` - SCRAM-SHA256 protocol +- `token.go` - Token validation utilities +- `http.go` - HTTP header parsing +- `errors.go` - Error definitions + +## Testing +```bash +go test -v ./auth +``` \ No newline at end of file diff --git a/argon2.go b/argon2.go new file mode 100644 index 0000000..533d5ce --- /dev/null +++ b/argon2.go @@ -0,0 +1,110 @@ +// FILE: auth/argon2.go +package auth + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "fmt" + "strings" + + "golang.org/x/crypto/argon2" +) + +// Default Argon2id parameters +const ( + DefaultArgonTime = 3 // iterations (reduce for faster but less secure auth) + DefaultArgonMemory = 64 * 1024 // 64 MB + DefaultArgonThreads = 4 + DefaultArgonSaltLen = 16 + DefaultArgonKeyLen = 32 +) + +// HashPassword creates an Argon2id PHC-format hash +func (a *Authenticator) HashPassword(password string) (string, error) { + if len(password) < 8 { + return "", ErrWeakPassword + } + + // Generate salt + salt := make([]byte, DefaultArgonSaltLen) + if _, err := rand.Read(salt); err != nil { + return "", fmt.Errorf("%w: %v", ErrSaltGenerationFailed, err) + } + + // Derive key using Argon2id + hash := argon2.IDKey([]byte(password), salt, a.argonTime, a.argonMemory, a.argonThreads, DefaultArgonKeyLen) + + // Construct PHC format + saltB64 := base64.RawStdEncoding.EncodeToString(salt) + hashB64 := base64.RawStdEncoding.EncodeToString(hash) + phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", + argon2.Version, a.argonMemory, a.argonTime, a.argonThreads, saltB64, hashB64) + + return phcHash, nil +} + +// VerifyPassword checks password against PHC-format hash +func (a *Authenticator) VerifyPassword(password, phcHash string) error { + // Parse PHC format + parts := strings.Split(phcHash, "$") + if len(parts) != 6 || parts[1] != "argon2id" { + return ErrPHCInvalidFormat + } + + var memory, time uint32 + var threads uint8 + fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads) + + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return fmt.Errorf("%w: %v", ErrPHCInvalidSalt, err) + } + + expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err) + } + + // Compute hash with same parameters + computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash))) + + // Constant-time comparison + if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 { + return ErrInvalidCredentials + } + + return nil +} + +// MigrateFromPHC converts existing Argon2 PHC hash to SCRAM credential +func MigrateFromPHC(username, password, phcHash string) (*Credential, error) { + // Parse PHC format + parts := strings.Split(phcHash, "$") + if len(parts) != 6 || parts[1] != "argon2id" { + return nil, ErrPHCInvalidFormat + } + + var memory, time uint32 + var threads uint8 + fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads) + + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return nil, ErrPHCInvalidSalt + } + + expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return nil, ErrPHCInvalidHash + } + + // Verify password against hash + computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash))) + if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 { + return nil, ErrInvalidCredentials + } + + // Derive SCRAM credential with same parameters + return DeriveCredential(username, password, salt, time, memory, threads) +} \ No newline at end of file diff --git a/argon2_test.go b/argon2_test.go new file mode 100644 index 0000000..f58f066 --- /dev/null +++ b/argon2_test.go @@ -0,0 +1,117 @@ +// FILE: auth/argon2_test.go +package auth + +import ( + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPasswordHashing(t *testing.T) { + auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) + require.NoError(t, err, "Failed to create authenticator") + + password := "testPassword123" + + // Test hashing + hash, err := auth.HashPassword(password) + require.NoError(t, err, "Failed to hash password") + + // Verify PHC format + assert.True(t, strings.HasPrefix(hash, "$argon2id$"), + "Hash should have argon2id prefix, got: %s", hash) + + // Test verification with correct password + err = auth.VerifyPassword(password, hash) + assert.NoError(t, err, "Failed to verify correct password") + + // Test verification with incorrect password + err = auth.VerifyPassword("wrongPassword", hash) + assert.Error(t, err, "Verification should fail for incorrect password") + assert.Equal(t, ErrInvalidCredentials, err) + + // Test weak password + _, err = auth.HashPassword("weak") + assert.Equal(t, ErrWeakPassword, err, "Should reject weak password") + + // Test malformed PHC hash + err = auth.VerifyPassword(password, "$invalid$format") + assert.Error(t, err, "Should reject malformed hash") + + // Test corrupted salt + corruptedHash := strings.Replace(hash, "$argon2id$", "$argon2id$", 1) + parts := strings.Split(corruptedHash, "$") + if len(parts) == 6 { + parts[4] = "invalid!base64" + corruptedHash = strings.Join(parts, "$") + err = auth.VerifyPassword(password, corruptedHash) + assert.Error(t, err, "Should reject corrupted salt") + } +} + +func TestEmptyPasswordAfterValidation(t *testing.T) { + auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) + require.NoError(t, err) + + // Empty password should be rejected by length check + _, err = auth.HashPassword("") + assert.Equal(t, ErrWeakPassword, err) + + // 8-character password should pass + hash, err := auth.HashPassword("12345678") + require.NoError(t, err) + + err = auth.VerifyPassword("12345678", hash) + assert.NoError(t, err) +} + +func TestConcurrentPasswordOperations(t *testing.T) { + auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) + require.NoError(t, err) + + password := "testPassword123" + hash, err := auth.HashPassword(password) + require.NoError(t, err) + + // Test concurrent verification + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := auth.VerifyPassword(password, hash) + assert.NoError(t, err) + }() + } + wg.Wait() +} + +func TestPHCMigration(t *testing.T) { + auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) + require.NoError(t, err) + + password := "testPassword123" + username := "migrationUser" + + // Generate PHC hash + phcHash, err := auth.HashPassword(password) + require.NoError(t, err) + + // Migrate to SCRAM credential + cred, err := MigrateFromPHC(username, password, phcHash) + require.NoError(t, err) + assert.Equal(t, username, cred.Username) + assert.NotNil(t, cred.StoredKey) + assert.NotNil(t, cred.ServerKey) + + // Test with wrong password + _, err = MigrateFromPHC(username, "wrongPassword", phcHash) + assert.Equal(t, ErrInvalidCredentials, err) + + // Test with invalid PHC format + _, err = MigrateFromPHC(username, password, "$invalid$format") + assert.Error(t, err) +} \ No newline at end of file diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..3a80ab3 --- /dev/null +++ b/auth.go @@ -0,0 +1,71 @@ +// FILE: auth/auth.go +package auth + +import ( + "crypto/rsa" + "fmt" +) + +// Authenticator provides password hashing and JWT operations +type Authenticator struct { + algorithm string + jwtSecret []byte // For HS256 + privateKey *rsa.PrivateKey // For RS256 + publicKey *rsa.PublicKey // For RS256 + argonTime uint32 + argonMemory uint32 + argonThreads uint8 +} + +// NewAuthenticator creates a new authenticator with specified algorithm +func NewAuthenticator(key any, algorithm ...string) (*Authenticator, error) { + alg := "HS256" + if len(algorithm) > 0 && algorithm[0] != "" { + alg = algorithm[0] + } + + auth := &Authenticator{ + algorithm: alg, + argonTime: DefaultArgonTime, + argonMemory: DefaultArgonMemory, + argonThreads: DefaultArgonThreads, + } + + switch alg { + case "HS256": + secret, ok := key.([]byte) + if !ok { + return nil, ErrInvalidKeyType + } + if len(secret) < 32 { + return nil, ErrSecretTooShort + } + auth.jwtSecret = secret + + case "RS256": + switch k := key.(type) { + case *rsa.PrivateKey: + auth.privateKey = k + auth.publicKey = &k.PublicKey + case *rsa.PublicKey: + auth.publicKey = k + case []byte: + // Try parsing as PEM + if privKey, err := parseRSAPrivateKey(k); err == nil { + auth.privateKey = privKey + auth.publicKey = &privKey.PublicKey + } else if pubKey, err := parseRSAPublicKey(k); err == nil { + auth.publicKey = pubKey + } else { + return nil, fmt.Errorf("failed to parse RSA key: %w", err) + } + default: + return nil, ErrInvalidKeyType + } + + default: + return nil, ErrInvalidAlgorithm + } + + return auth, nil +} \ No newline at end of file diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..5176f37 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,68 @@ +// FILE: auth/auth_test.go +package auth + +import ( + "crypto/rand" + "crypto/rsa" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAuthenticator(t *testing.T) { + // Test HS256 creation + auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) + require.NoError(t, err, "Failed to create HS256 authenticator") + assert.Equal(t, "HS256", auth.algorithm) + + // Test with short secret + _, err = NewAuthenticator([]byte("short")) + assert.Equal(t, ErrSecretTooShort, err) + + // Test RS256 with private key + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + authRS, err := NewAuthenticator(privateKey, "RS256") + require.NoError(t, err) + assert.Equal(t, "RS256", authRS.algorithm) + assert.NotNil(t, authRS.privateKey) + assert.NotNil(t, authRS.publicKey) + + // Test RS256 with public key only + authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256") + require.NoError(t, err) + assert.Equal(t, "RS256", authPub.algorithm) + assert.Nil(t, authPub.privateKey) + assert.NotNil(t, authPub.publicKey) + + // Test invalid algorithm + _, err = NewAuthenticator([]byte("test-secret-key-must-be-32-bytes"), "INVALID") + assert.Equal(t, ErrInvalidAlgorithm, err) + + // Test invalid key type for HS256 + _, err = NewAuthenticator(privateKey, "HS256") + assert.Equal(t, ErrInvalidKeyType, err) +} + +func TestInterfaceCompliance(t *testing.T) { + // Verify Authenticator implements AuthenticatorInterface + auth, _ := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) + + var _ AuthenticatorInterface = auth + + // Test interface methods work + hash, err := auth.HashPassword("testpass123") + require.NoError(t, err) + + err = auth.VerifyPassword("testpass123", hash) + assert.NoError(t, err) + + token, err := auth.GenerateToken("user1", nil) + require.NoError(t, err) + + userID, _, err := auth.ValidateToken(token) + require.NoError(t, err) + assert.Equal(t, "user1", userID) +} \ No newline at end of file diff --git a/error.go b/error.go new file mode 100644 index 0000000..824a494 --- /dev/null +++ b/error.go @@ -0,0 +1,100 @@ +// FILE: auth/errors.go +package auth + +import ( + "errors" + "fmt" +) + +// Base authentication errors +var ( + ErrInvalidCredentials = errors.New("invalid credentials") + ErrWeakPassword = errors.New("password must be at least 8 characters") + ErrInvalidAlgorithm = errors.New("invalid algorithm") + ErrInvalidKeyType = errors.New("invalid key type for algorithm") +) + +// JWT-specific errors +var ( + ErrTokenMalformed = errors.New("token: malformed structure") + ErrTokenExpired = errors.New("token: expired") + ErrTokenNotYetValid = errors.New("token: not yet valid") + ErrTokenInvalidSignature = errors.New("token: invalid signature") + ErrTokenAlgorithmMismatch = errors.New("token: algorithm mismatch") + ErrTokenMissingClaim = errors.New("token: missing required claim") + ErrTokenInvalidHeader = errors.New("token: invalid header encoding") + ErrTokenInvalidClaims = errors.New("token: invalid claims encoding") + ErrTokenInvalidJSON = errors.New("token: malformed JSON") + ErrTokenEmptyUserID = errors.New("token: empty user ID") + ErrTokenNoPrivateKey = errors.New("token: private key required for signing") + ErrTokenNoPublicKey = errors.New("token: public key required for verification") +) + +// JWT secret errors +var ( + ErrSecretTooShort = errors.New("JWT secret must be at least 32 bytes") +) + +// RSA key parsing errors +var ( + ErrRSAInvalidPEM = errors.New("rsa: failed to parse PEM block") + ErrRSAInvalidPrivateKey = errors.New("rsa: invalid private key format") + ErrRSAInvalidPublicKey = errors.New("rsa: invalid public key format") + ErrRSANotPublicKey = errors.New("rsa: not an RSA public key") +) + +// PHC format errors +var ( + ErrPHCInvalidFormat = errors.New("phc: invalid format") + ErrPHCInvalidSalt = errors.New("phc: invalid salt encoding") + ErrPHCInvalidHash = errors.New("phc: invalid hash encoding") +) + +// SCRAM-specific errors +var ( + ErrSCRAMInvalidNonce = errors.New("scram: invalid nonce or expired handshake") + ErrSCRAMTimeout = errors.New("scram: handshake timeout") + ErrSCRAMVerifyInProgress = errors.New("scram: verification already in progress") + ErrSCRAMInvalidProof = errors.New("scram: invalid proof encoding") + ErrSCRAMInvalidProofLen = errors.New("scram: invalid proof length") + ErrSCRAMServerAuthFailed = errors.New("scram: server authentication failed") + ErrSCRAMInvalidState = errors.New("scram: invalid handshake state") + ErrSCRAMInvalidSalt = errors.New("scram: invalid salt encoding") + ErrSCRAMZeroParams = errors.New("scram: invalid Argon2 parameters") + ErrSCRAMSaltTooShort = errors.New("scram: salt must be at least 16 bytes") + ErrSCRAMNonceGenFailed = errors.New("scram: failed to generate nonce") +) + +// Credential import/export errors +var ( + ErrCredMissingUsername = errors.New("credential: missing username") + ErrCredMissingSalt = errors.New("credential: missing salt") + ErrCredInvalidSalt = errors.New("credential: invalid salt encoding") + ErrCredMissingTime = errors.New("credential: missing argon_time") + ErrCredMissingMemory = errors.New("credential: missing argon_memory") + ErrCredMissingThreads = errors.New("credential: missing argon_threads") + ErrCredMissingStoredKey = errors.New("credential: missing stored_key") + ErrCredInvalidStoredKey = errors.New("credential: invalid stored_key encoding") + ErrCredMissingServerKey = errors.New("credential: missing server_key") + ErrCredInvalidServerKey = errors.New("credential: invalid server_key encoding") + ErrCredInvalidType = fmt.Errorf("credential: invalid type for field") +) + +// HTTP auth parsing errors +var ( + ErrAuthInvalidBasicFormat = errors.New("auth: invalid Basic auth format") + ErrAuthInvalidBasicEncoding = errors.New("auth: invalid Basic auth base64 encoding") + ErrAuthInvalidBasicCreds = errors.New("auth: invalid Basic auth credentials format") + ErrAuthInvalidBearerFormat = errors.New("auth: invalid Bearer auth format") + ErrAuthEmptyBearerToken = errors.New("auth: empty Bearer token") +) + +// Salt generation errors +var ( + ErrSaltGenerationFailed = errors.New("failed to generate salt") +) + +// Key generation errors +var ( + ErrRSAKeyGenFailed = errors.New("failed to generate RSA key") +) \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b6154ca --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module auth + +go 1.25.3 + +require ( + github.com/stretchr/testify v1.11.1 + golang.org/x/crypto v0.43.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.37.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4a5adfe --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= +golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/http.go b/http.go new file mode 100644 index 0000000..c8b7a16 --- /dev/null +++ b/http.go @@ -0,0 +1,61 @@ +// FILE: auth/http.go +package auth + +import ( + "encoding/base64" + "strings" +) + +// ParseBasicAuth extracts username/password from Basic auth header +func ParseBasicAuth(header string) (username, password string, err error) { + const prefix = "Basic " + if !strings.HasPrefix(header, prefix) { + return "", "", ErrAuthInvalidBasicFormat + } + + encoded := strings.TrimPrefix(header, prefix) + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", "", ErrAuthInvalidBasicEncoding + } + + credentials := string(decoded) + idx := strings.IndexByte(credentials, ':') + if idx < 0 { + return "", "", ErrAuthInvalidBasicCreds + } + + return credentials[:idx], credentials[idx+1:], nil +} + +// ParseBearerToken extracts token from Bearer auth header +func ParseBearerToken(header string) (token string, err error) { + const prefix = "Bearer " + if !strings.HasPrefix(header, prefix) { + return "", ErrAuthInvalidBearerFormat + } + + token = strings.TrimPrefix(header, prefix) + if token == "" { + return "", ErrAuthEmptyBearerToken + } + + return token, nil +} + +// ExtractAuthType returns authentication type from header +func ExtractAuthType(header string) string { + if strings.HasPrefix(header, "Basic ") { + return "Basic" + } + if strings.HasPrefix(header, "Bearer ") { + return "Bearer" + } + + // Extract first word as auth type + idx := strings.IndexByte(header, ' ') + if idx > 0 { + return header[:idx] + } + return "" +} \ No newline at end of file diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..984e590 --- /dev/null +++ b/http_test.go @@ -0,0 +1,54 @@ +// FILE: auth/http_test.go +package auth + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHTTPAuthParsing(t *testing.T) { + // Test Basic Auth + basicHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass")) + username, password, err := ParseBasicAuth(basicHeader) + require.NoError(t, err) + assert.Equal(t, "user", username) + assert.Equal(t, "pass", password) + + // Test Bearer Token + bearerHeader := "Bearer test-token-xyz" + token, err := ParseBearerToken(bearerHeader) + require.NoError(t, err) + assert.Equal(t, "test-token-xyz", token) + + // Test ExtractAuthType + assert.Equal(t, "Basic", ExtractAuthType(basicHeader)) + assert.Equal(t, "Bearer", ExtractAuthType(bearerHeader)) + assert.Equal(t, "Custom", ExtractAuthType("Custom somedata")) + assert.Equal(t, "", ExtractAuthType("InvalidHeader")) + + // Test invalid formats + _, _, err = ParseBasicAuth("Invalid header") + assert.Error(t, err) + assert.Equal(t, ErrAuthInvalidBasicFormat, err) + + _, err = ParseBearerToken("Invalid header") + assert.Error(t, err) + assert.Equal(t, ErrAuthInvalidBearerFormat, err) + + // Test malformed Basic auth + _, _, err = ParseBasicAuth("Basic not-base64!") + assert.Error(t, err) + assert.Equal(t, ErrAuthInvalidBasicEncoding, err) + + _, _, err = ParseBasicAuth("Basic " + base64.StdEncoding.EncodeToString([]byte("no-colon"))) + assert.Error(t, err) + assert.Equal(t, ErrAuthInvalidBasicCreds, err) + + // Test empty Bearer token + _, err = ParseBearerToken("Bearer ") + assert.Error(t, err) + assert.Equal(t, ErrAuthEmptyBearerToken, err) +} \ No newline at end of file diff --git a/interface.go b/interface.go new file mode 100644 index 0000000..e84d390 --- /dev/null +++ b/interface.go @@ -0,0 +1,17 @@ +// FILE: auth/interface.go +package auth + +// AuthenticatorInterface defines the authentication operations +type AuthenticatorInterface interface { + HashPassword(password string) (hash string, err error) + VerifyPassword(password, hash string) (err error) + GenerateToken(userID string, claims map[string]any) (token string, err error) + ValidateToken(token string) (userID string, claims map[string]any, err error) +} + +// TokenValidator validates bearer tokens +type TokenValidator interface { + ValidateToken(token string) (valid bool) + AddToken(token string) + RemoveToken(token string) +} \ No newline at end of file diff --git a/jwt.go b/jwt.go new file mode 100644 index 0000000..e5d4f53 --- /dev/null +++ b/jwt.go @@ -0,0 +1,205 @@ +// FILE: auth/jwt.go +package auth + +import ( + "crypto" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/subtle" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "strings" + "time" +) + +// GenerateToken creates a JWT token with user claims +func (a *Authenticator) GenerateToken(userID string, claims map[string]any) (string, error) { + if userID == "" { + return "", ErrTokenEmptyUserID + } + + if a.algorithm == "RS256" && a.privateKey == nil { + return "", ErrTokenNoPrivateKey + } + + // Build JWT claims + now := time.Now() + jwtClaims := map[string]any{ + "sub": userID, + "iat": now.Unix(), + "exp": now.Add(7 * 24 * time.Hour).Unix(), // 7 days expiry + } + + // Reserved claims that cannot be overridden + reservedClaims := map[string]bool{ + "sub": true, "iat": true, "exp": true, "nbf": true, + "iss": true, "aud": true, "jti": true, "typ": true, + "alg": true, + } + + // Merge custom claims + for k, v := range claims { + if !reservedClaims[k] { + jwtClaims[k] = v + } + } + + // Create JWT header + header := map[string]any{ + "alg": a.algorithm, + "typ": "JWT", + } + + // Encode header and payload + headerJSON, _ := json.Marshal(header) + claimsJSON, _ := json.Marshal(jwtClaims) + + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) + + // Create signature + signingInput := headerB64 + "." + claimsB64 + var signature string + + switch a.algorithm { + case "HS256": + h := hmac.New(sha256.New, a.jwtSecret) + h.Write([]byte(signingInput)) + signature = base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + + case "RS256": + hashed := sha256.Sum256([]byte(signingInput)) + sig, err := rsa.SignPKCS1v15(rand.Reader, a.privateKey, crypto.SHA256, hashed[:]) + if err != nil { + return "", fmt.Errorf("failed to sign token: %w", err) + } + signature = base64.RawURLEncoding.EncodeToString(sig) + } + + // Combine to form JWT + token := signingInput + "." + signature + + return token, nil +} + +// ValidateToken verifies JWT and returns userID and claims +func (a *Authenticator) ValidateToken(token string) (string, map[string]any, error) { + // Split token + parts := strings.Split(token, ".") + if len(parts) != 3 { + return "", nil, ErrTokenMalformed + } + + // Decode header to check algorithm + headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return "", nil, ErrTokenInvalidHeader + } + + var header map[string]any + if err = json.Unmarshal(headerJSON, &header); err != nil { + return "", nil, ErrTokenInvalidJSON + } + + // Verify algorithm matches + if alg, ok := header["alg"].(string); !ok || alg != a.algorithm { + return "", nil, ErrTokenAlgorithmMismatch + } + + // Verify signature + signingInput := parts[0] + "." + parts[1] + + switch a.algorithm { + case "HS256": + h := hmac.New(sha256.New, a.jwtSecret) + h.Write([]byte(signingInput)) + expectedSig := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + + if subtle.ConstantTimeCompare([]byte(parts[2]), []byte(expectedSig)) != 1 { + return "", nil, ErrTokenInvalidSignature + } + + case "RS256": + if a.publicKey == nil { + return "", nil, ErrTokenNoPublicKey + } + + sig, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return "", nil, ErrTokenInvalidSignature + } + + hashed := sha256.Sum256([]byte(signingInput)) + if err := rsa.VerifyPKCS1v15(a.publicKey, crypto.SHA256, hashed[:], sig); err != nil { + return "", nil, ErrTokenInvalidSignature + } + } + + // Decode claims + claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", nil, ErrTokenInvalidClaims + } + + var claims map[string]any + if err := json.Unmarshal(claimsJSON, &claims); err != nil { + return "", nil, ErrTokenInvalidJSON + } + + // Check expiration + if exp, ok := claims["exp"].(float64); ok { + if time.Now().Unix() > int64(exp) { + return "", nil, ErrTokenExpired + } + } + + // Check not before + if nbf, ok := claims["nbf"].(float64); ok { + if time.Now().Unix() < int64(nbf) { + return "", nil, ErrTokenNotYetValid + } + } + + // Extract userID + userID, ok := claims["sub"].(string) + if !ok { + return "", nil, ErrTokenMissingClaim + } + + return userID, claims, nil +} + +// parseRSAPrivateKey parses PEM encoded RSA private key +func parseRSAPrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, ErrRSAInvalidPEM + } + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, ErrRSAInvalidPrivateKey + } + return key, nil +} + +// parseRSAPublicKey parses PEM encoded RSA public key +func parseRSAPublicKey(pemBytes []byte) (*rsa.PublicKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, ErrRSAInvalidPEM + } + pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, ErrRSAInvalidPublicKey + } + pubKey, ok := pubInterface.(*rsa.PublicKey) + if !ok { + return nil, ErrRSANotPublicKey + } + return pubKey, nil +} \ No newline at end of file diff --git a/jwt_test.go b/jwt_test.go new file mode 100644 index 0000000..56ec30d --- /dev/null +++ b/jwt_test.go @@ -0,0 +1,176 @@ +// FILE: auth/jwt_test.go +package auth + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "errors" + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJWTHS256(t *testing.T) { + auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) + require.NoError(t, err) + + userID := "user123" + claims := map[string]any{ + "email": "test@example.com", + "role": "admin", + } + + // Generate token + token, err := auth.GenerateToken(userID, claims) + require.NoError(t, err, "Failed to generate token") + assert.NotEmpty(t, token) + + // Validate token + extractedUserID, extractedClaims, err := auth.ValidateToken(token) + require.NoError(t, err, "Failed to validate token") + + assert.Equal(t, userID, extractedUserID) + assert.Equal(t, "test@example.com", extractedClaims["email"]) + assert.Equal(t, "admin", extractedClaims["role"]) + + // Test invalid token + _, _, err = auth.ValidateToken("invalid.token.here") + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrTokenInvalidJSON)) + + // Test tampered token + parts := strings.Split(token, ".") + require.Len(t, parts, 3, "JWT should have 3 parts") + + tampered := parts[0] + "." + parts[1] + ".invalidsignature" + _, _, err = auth.ValidateToken(tampered) + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrTokenInvalidSignature)) + + // Test reserved claims cannot be overridden + overrideClaims := map[string]any{ + "sub": "override", + "iat": 12345, + "exp": 67890, + "nbf": 11111, + "iss": "attacker", + "aud": "victim", + "jti": "fake", + } + + token, err = auth.GenerateToken(userID, overrideClaims) + require.NoError(t, err) + + extractedUserID, extractedClaims, err = auth.ValidateToken(token) + require.NoError(t, err) + + assert.Equal(t, userID, extractedUserID, "UserID should not be overridden") + assert.NotEqual(t, 12345, extractedClaims["iat"], "iat should not be overridden") + assert.NotEqual(t, 67890, extractedClaims["exp"], "exp should not be overridden") + assert.NotContains(t, extractedClaims, "nbf", "nbf should not be added from user claims") + assert.NotContains(t, extractedClaims, "iss", "iss should not be added from user claims") +} + +func TestJWTRS256(t *testing.T) { + // Generate RSA key pair + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + // Test with private key (can sign and verify) + authPriv, err := NewAuthenticator(privateKey, "RS256") + require.NoError(t, err) + + userID := "user456" + claims := map[string]any{ + "email": "rs256@example.com", + "scope": "read:all", + } + + // Generate token + token, err := authPriv.GenerateToken(userID, claims) + require.NoError(t, err) + assert.NotEmpty(t, token) + + // Validate token with private key auth (has public key too) + extractedUserID, extractedClaims, err := authPriv.ValidateToken(token) + require.NoError(t, err) + + assert.Equal(t, userID, extractedUserID) + assert.Equal(t, "rs256@example.com", extractedClaims["email"]) + assert.Equal(t, "read:all", extractedClaims["scope"]) + + // Test with public key only (can only verify) + authPub, err := NewAuthenticator(&privateKey.PublicKey, "RS256") + require.NoError(t, err) + + // Should be able to validate token + extractedUserID, extractedClaims, err = authPub.ValidateToken(token) + require.NoError(t, err) + assert.Equal(t, userID, extractedUserID) + + // Should not be able to generate token + _, err = authPub.GenerateToken(userID, claims) + assert.Error(t, err, "Public key only auth should not generate tokens") + assert.Equal(t, ErrTokenNoPrivateKey, err) + + // Test algorithm mismatch + authHS256, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) + require.NoError(t, err) + + _, _, err = authHS256.ValidateToken(token) + assert.Error(t, err, "HS256 auth should not validate RS256 token") + // assert.True(t, errors.Is(err, ErrInvalidToken)) + assert.True(t, errors.Is(err, ErrTokenAlgorithmMismatch)) + + fmt.Println(err) +} + +func TestExpiredToken(t *testing.T) { + auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) + require.NoError(t, err) + + userID := "user123" + + // Generate normal token (should have 7 days expiry) + token, err := auth.GenerateToken(userID, nil) + require.NoError(t, err) + + _, extractedClaims, err := auth.ValidateToken(token) + require.NoError(t, err) + + // Check expiry is in future (approximately 7 days) + expiry := extractedClaims["exp"].(float64) + now := time.Now().Unix() + + assert.Greater(t, expiry, float64(now), "Token expiry should be in future") + assert.InDelta(t, expiry, float64(now+7*24*60*60), 10, + "Token expiry should be approximately 7 days from now") +} + +func TestCorruptJWTParts(t *testing.T) { + auth, err := NewAuthenticator([]byte("test-secret-key-must-be-32-bytes")) + require.NoError(t, err) + + // Test with missing parts + _, _, err = auth.ValidateToken("only.two") + assert.True(t, errors.Is(err, ErrTokenMalformed)) + + // Test with invalid header encoding + _, _, err = auth.ValidateToken("not-base64!.valid.valid") + assert.Error(t, err) + + // Test with invalid claims encoding + validHeader := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" + _, _, err = auth.ValidateToken(validHeader + ".not-base64!.valid") + assert.Error(t, err) + + // Test with invalid JSON in claims + invalidJSON := base64.RawURLEncoding.EncodeToString([]byte("{invalid json")) + _, _, err = auth.ValidateToken(validHeader + "." + invalidJSON + ".signature") + assert.Error(t, err) +} \ No newline at end of file diff --git a/scram.go b/scram.go new file mode 100644 index 0000000..51e7266 --- /dev/null +++ b/scram.go @@ -0,0 +1,498 @@ +// FILE: auth/scram.go +package auth + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "fmt" + "sync" + "sync/atomic" + "time" + + "golang.org/x/crypto/argon2" +) + +// SCRAM-SHA256 implementation + +// Credential stores SCRAM authentication data +type Credential struct { + Username string + Salt []byte + ArgonTime uint32 + ArgonMemory uint32 + ArgonThreads uint8 + StoredKey []byte // SHA256(ClientKey) + ServerKey []byte +} + +// Export returns credential as config-friendly map +func (c *Credential) Export() map[string]any { + return map[string]any{ + "username": c.Username, + "salt": base64.StdEncoding.EncodeToString(c.Salt), + "argon_time": c.ArgonTime, + "argon_memory": c.ArgonMemory, + "argon_threads": c.ArgonThreads, + "stored_key": base64.StdEncoding.EncodeToString(c.StoredKey), + "server_key": base64.StdEncoding.EncodeToString(c.ServerKey), + } +} + +// ImportCredential creates credential from map +func ImportCredential(data map[string]any) (*Credential, error) { + username, ok := data["username"].(string) + if !ok { + return nil, ErrCredMissingUsername + } + + saltStr, ok := data["salt"].(string) + if !ok { + return nil, ErrCredMissingSalt + } + salt, err := base64.StdEncoding.DecodeString(saltStr) + if err != nil { + return nil, ErrCredInvalidSalt + } + + // Handle both float64 (from JSON) and int types + getUint32 := func(key string) (uint32, error) { + val, ok := data[key] + if !ok { + switch key { + case "argon_time": + return 0, ErrCredMissingTime + case "argon_memory": + return 0, ErrCredMissingMemory + default: + return 0, fmt.Errorf("missing %s", key) + } + } + switch v := val.(type) { + case float64: + return uint32(v), nil + case int: + return uint32(v), nil + case uint32: + return v, nil + default: + return 0, fmt.Errorf("invalid type for %s", key) + } + } + + argonTime, err := getUint32("argon_time") + if err != nil { + return nil, err + } + + argonMemory, err := getUint32("argon_memory") + if err != nil { + return nil, err + } + + threadsVal, ok := data["argon_threads"] + if !ok { + return nil, ErrCredMissingThreads + } + var argonThreads uint8 + switch v := threadsVal.(type) { + case float64: + argonThreads = uint8(v) + case int: + argonThreads = uint8(v) + case uint8: + argonThreads = v + default: + return nil, fmt.Errorf("%w: argon_threads", ErrCredInvalidType) + } + + storedKeyStr, ok := data["stored_key"].(string) + if !ok { + return nil, ErrCredMissingStoredKey + } + storedKey, err := base64.StdEncoding.DecodeString(storedKeyStr) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrCredInvalidStoredKey, err) + } + + serverKeyStr, ok := data["server_key"].(string) + if !ok { + return nil, ErrCredMissingServerKey + } + serverKey, err := base64.StdEncoding.DecodeString(serverKeyStr) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrCredInvalidServerKey, err) + } + + return &Credential{ + Username: username, + Salt: salt, + ArgonTime: argonTime, + ArgonMemory: argonMemory, + ArgonThreads: argonThreads, + StoredKey: storedKey, + ServerKey: serverKey, + }, nil +} + +// DeriveCredential creates SCRAM credential from password +func DeriveCredential(username, password string, salt []byte, time, memory uint32, threads uint8) (*Credential, error) { + if len(salt) < 16 { + return nil, ErrSCRAMSaltTooShort + } + + // Derive salted password using Argon2id + saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, DefaultArgonKeyLen) + + // Derive keys + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + storedKey := sha256.Sum256(clientKey) + + return &Credential{ + Username: username, + Salt: salt, + ArgonTime: time, + ArgonMemory: memory, + ArgonThreads: threads, + StoredKey: storedKey[:], + ServerKey: serverKey, + }, nil +} + +// ScramServer handles server-side SCRAM authentication +type ScramServer struct { + credentials map[string]*Credential + handshakes map[string]*HandshakeState + mu sync.RWMutex +} + +// HandshakeState tracks ongoing authentication +type HandshakeState struct { + Username string + ClientNonce string + ServerNonce string + FullNonce string + Credential *Credential + CreatedAt time.Time + verifying int32 // Atomic flag to prevent race during verification +} + +// NewScramServer creates SCRAM server +func NewScramServer() *ScramServer { + return &ScramServer{ + credentials: make(map[string]*Credential), + handshakes: make(map[string]*HandshakeState), + } +} + +// AddCredential registers user credential +func (s *ScramServer) AddCredential(cred *Credential) { + s.mu.Lock() + defer s.mu.Unlock() + s.credentials[cred.Username] = cred +} + +// ProcessClientFirstMessage processes initial auth request +func (s *ScramServer) ProcessClientFirstMessage(username, clientNonce string) (ServerFirstMessage, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Check if user exists + cred, exists := s.credentials[username] + if !exists { + // Prevent user enumeration - still generate response + salt := make([]byte, 16) + rand.Read(salt) + serverNonce := generateNonce() + + return ServerFirstMessage{ + FullNonce: clientNonce + serverNonce, + Salt: base64.StdEncoding.EncodeToString(salt), + ArgonTime: DefaultArgonTime, + ArgonMemory: DefaultArgonMemory, + ArgonThreads: DefaultArgonThreads, + }, ErrInvalidCredentials + } + + // Generate server nonce + serverNonce := generateNonce() + fullNonce := clientNonce + serverNonce + + // Store handshake state + state := &HandshakeState{ + Username: username, + ClientNonce: clientNonce, + ServerNonce: serverNonce, + FullNonce: fullNonce, + Credential: cred, + CreatedAt: time.Now(), + verifying: 0, + } + s.handshakes[fullNonce] = state + + // Cleanup old handshakes + s.cleanupHandshakes() + + return ServerFirstMessage{ + FullNonce: fullNonce, + Salt: base64.StdEncoding.EncodeToString(cred.Salt), + ArgonTime: cred.ArgonTime, + ArgonMemory: cred.ArgonMemory, + ArgonThreads: cred.ArgonThreads, + }, nil +} + +// ProcessClientFinalMessage verifies client proof +func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (ServerFinalMessage, error) { + s.mu.RLock() + state, exists := s.handshakes[fullNonce] + s.mu.RUnlock() + + if !exists { + return ServerFinalMessage{}, ErrSCRAMInvalidNonce + } + + // Mark as verifying to prevent deletion race + if !atomic.CompareAndSwapInt32(&state.verifying, 0, 1) { + return ServerFinalMessage{}, ErrSCRAMVerifyInProgress + } + defer func() { + atomic.StoreInt32(&state.verifying, 0) + // Safe to delete after verification completes + s.mu.Lock() + delete(s.handshakes, fullNonce) + s.mu.Unlock() + }() + + // Check timeout + if time.Since(state.CreatedAt) > 60*time.Second { + return ServerFinalMessage{}, ErrSCRAMTimeout + } + + // Decode client proof + clientProofBytes, err := base64.StdEncoding.DecodeString(clientProof) + if err != nil { + return ServerFinalMessage{}, ErrSCRAMInvalidProof + } + + // Build auth message + clientFirstBare := fmt.Sprintf("u=%s,n=%s", state.Username, state.ClientNonce) + serverFirst := ServerFirstMessage{ + FullNonce: state.FullNonce, + Salt: base64.StdEncoding.EncodeToString(state.Credential.Salt), + ArgonTime: state.Credential.ArgonTime, + ArgonMemory: state.Credential.ArgonMemory, + ArgonThreads: state.Credential.ArgonThreads, + } + clientFinalBare := fmt.Sprintf("r=%s", fullNonce) + authMessage := clientFirstBare + "," + serverFirst.Marshal() + "," + clientFinalBare + + // Compute client signature + clientSignature := computeHMAC(state.Credential.StoredKey, []byte(authMessage)) + + // XOR to get ClientKey + if len(clientProofBytes) != len(clientSignature) { + return ServerFinalMessage{}, ErrSCRAMInvalidProofLen + } + clientKey := xorBytes(clientProofBytes, clientSignature) + + // Verify by computing StoredKey + computedStoredKey := sha256.Sum256(clientKey) + if subtle.ConstantTimeCompare(computedStoredKey[:], state.Credential.StoredKey) != 1 { + return ServerFinalMessage{}, ErrInvalidCredentials + } + + // Generate server signature for mutual auth + serverSignature := computeHMAC(state.Credential.ServerKey, []byte(authMessage)) + + return ServerFinalMessage{ + ServerSignature: base64.StdEncoding.EncodeToString(serverSignature), + Username: state.Username, + }, nil +} + +func (s *ScramServer) cleanupHandshakes() { + cutoff := time.Now().Add(-60 * time.Second) + for nonce, state := range s.handshakes { + if state.CreatedAt.Before(cutoff) && atomic.LoadInt32(&state.verifying) == 0 { + delete(s.handshakes, nonce) + } + } +} + +// ScramClient handles client-side SCRAM authentication +type ScramClient struct { + Username string + Password string + clientNonce string + serverFirst *ServerFirstMessage + authMessage string + serverKey []byte + startTime time.Time // Track handshake start +} + +// NewScramClient creates SCRAM client +func NewScramClient(username, password string) *ScramClient { + return &ScramClient{ + Username: username, + Password: password, + } +} + +// StartAuthentication generates initial client message +func (c *ScramClient) StartAuthentication() (ClientFirstRequest, error) { + c.startTime = time.Now() + + // Generate client nonce + nonce := make([]byte, 32) + if _, err := rand.Read(nonce); err != nil { + return ClientFirstRequest{}, ErrSCRAMNonceGenFailed + } + c.clientNonce = base64.StdEncoding.EncodeToString(nonce) + + return ClientFirstRequest{ + Username: c.Username, + ClientNonce: c.clientNonce, + }, nil +} + +// ProcessServerFirstMessage handles server challenge +func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientFinalRequest, error) { + // Check timeout (30 seconds) + if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second { + return ClientFinalRequest{}, ErrSCRAMTimeout + } + + c.serverFirst = &msg + + // Handle enumeration prevention - server may send fake response + // We still process it normally and let verification fail later + + // Decode salt + salt, err := base64.StdEncoding.DecodeString(msg.Salt) + if err != nil { + return ClientFinalRequest{}, ErrSCRAMInvalidSalt + } + + // Validate parameters + if msg.ArgonTime == 0 || msg.ArgonMemory == 0 || msg.ArgonThreads == 0 { + return ClientFinalRequest{}, ErrSCRAMZeroParams + } + + // Derive keys using Argon2id + saltedPassword := argon2.IDKey([]byte(c.Password), salt, msg.ArgonTime, msg.ArgonMemory, msg.ArgonThreads, 32) + + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + storedKey := sha256.Sum256(clientKey) + + // Build auth message + clientFirstBare := fmt.Sprintf("u=%s,n=%s", c.Username, c.clientNonce) + clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce) + c.authMessage = clientFirstBare + "," + msg.Marshal() + "," + clientFinalBare + + // Compute client proof + clientSignature := computeHMAC(storedKey[:], []byte(c.authMessage)) + clientProof := xorBytes(clientKey, clientSignature) + + // Store server key for verification + c.serverKey = serverKey + + return ClientFinalRequest{ + FullNonce: msg.FullNonce, + ClientProof: base64.StdEncoding.EncodeToString(clientProof), + }, nil +} + +// VerifyServerFinalMessage validates server signature +func (c *ScramClient) VerifyServerFinalMessage(msg ServerFinalMessage) error { + // Check timeout + if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second { + return ErrSCRAMTimeout + } + + if c.authMessage == "" || c.serverKey == nil { + return ErrSCRAMInvalidState + } + + // Compute expected server signature + expectedSig := computeHMAC(c.serverKey, []byte(c.authMessage)) + + // Decode received signature + receivedSig, err := base64.StdEncoding.DecodeString(msg.ServerSignature) + if err != nil { + return ErrSCRAMServerAuthFailed + } + + // Constant-time comparison + if subtle.ConstantTimeCompare(expectedSig, receivedSig) != 1 { + return ErrSCRAMServerAuthFailed + } + + return nil +} + +// Reset clears client state for retry +func (c *ScramClient) Reset() { + c.clientNonce = "" + c.serverFirst = nil + c.authMessage = "" + c.serverKey = nil + c.startTime = time.Time{} +} + +// SCRAM message types +type ClientFirstRequest struct { + Username string `json:"username"` + ClientNonce string `json:"client_nonce"` +} + +type ServerFirstMessage struct { + FullNonce string `json:"full_nonce"` + Salt string `json:"salt"` + ArgonTime uint32 `json:"argon_time"` + ArgonMemory uint32 `json:"argon_memory"` + ArgonThreads uint8 `json:"argon_threads"` +} + +func (s ServerFirstMessage) Marshal() string { + return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d", + s.FullNonce, s.Salt, s.ArgonTime, s.ArgonMemory, s.ArgonThreads) +} + +type ClientFinalRequest struct { + FullNonce string `json:"full_nonce"` + ClientProof string `json:"client_proof"` +} + +type ServerFinalMessage struct { + ServerSignature string `json:"server_signature"` + Username string `json:"username,omitempty"` +} + +// Helper functions +func computeHMAC(key, message []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(message) + return mac.Sum(nil) +} + +func xorBytes(a, b []byte) []byte { + if len(a) != len(b) { + panic("xor length mismatch") + } + result := make([]byte, len(a)) + for i := range a { + result[i] = a[i] ^ b[i] + } + return result +} + +func generateNonce() string { + b := make([]byte, 32) + rand.Read(b) + return base64.StdEncoding.EncodeToString(b) +} \ No newline at end of file diff --git a/token.go b/token.go new file mode 100644 index 0000000..b2c23cf --- /dev/null +++ b/token.go @@ -0,0 +1,50 @@ +// FILE: auth/token.go +package auth + +import ( + "crypto/subtle" + "sync" +) + +// SimpleTokenValidator implements in-memory token validation +type SimpleTokenValidator struct { + tokens map[string]struct{} + mu sync.RWMutex +} + +// NewSimpleTokenValidator creates token validator +func NewSimpleTokenValidator() *SimpleTokenValidator { + return &SimpleTokenValidator{ + tokens: make(map[string]struct{}), + } +} + +// ValidateToken checks if token is valid +func (v *SimpleTokenValidator) ValidateToken(token string) bool { + v.mu.RLock() + defer v.mu.RUnlock() + + // Constant-time comparison for each stored token + for storedToken := range v.tokens { + if subtle.ConstantTimeEq(int32(len(token)), int32(len(storedToken))) == 1 { + if subtle.ConstantTimeCompare([]byte(token), []byte(storedToken)) == 1 { + return true + } + } + } + return false +} + +// AddToken adds token to validator +func (v *SimpleTokenValidator) AddToken(token string) { + v.mu.Lock() + defer v.mu.Unlock() + v.tokens[token] = struct{}{} +} + +// RemoveToken removes token from validator +func (v *SimpleTokenValidator) RemoveToken(token string) { + v.mu.Lock() + defer v.mu.Unlock() + delete(v.tokens, token) +} \ No newline at end of file diff --git a/token_test.go b/token_test.go new file mode 100644 index 0000000..6d25017 --- /dev/null +++ b/token_test.go @@ -0,0 +1,81 @@ +// FILE: auth/token_test.go +package auth + +import ( + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSimpleTokenValidator(t *testing.T) { + validator := NewSimpleTokenValidator() + + token1 := "test-token-123" + token2 := "test-token-456" + + // Add tokens + validator.AddToken(token1) + validator.AddToken(token2) + + // Validate existing tokens + assert.True(t, validator.ValidateToken(token1)) + assert.True(t, validator.ValidateToken(token2)) + + // Invalid token + assert.False(t, validator.ValidateToken("invalid-token")) + + // Remove token + validator.RemoveToken(token1) + assert.False(t, validator.ValidateToken(token1)) + assert.True(t, validator.ValidateToken(token2)) +} + +func TestConcurrentTokenValidator(t *testing.T) { + validator := NewSimpleTokenValidator() + + // Add tokens concurrently + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + token := fmt.Sprintf("token-%d", idx) + validator.AddToken(token) + }(i) + } + wg.Wait() + + // Validate concurrently + for i := 0; i < 100; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + token := fmt.Sprintf("token-%d", idx) + assert.True(t, validator.ValidateToken(token)) + }(i) + } + wg.Wait() + + // Remove concurrently + for i := 0; i < 50; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + token := fmt.Sprintf("token-%d", idx) + validator.RemoveToken(token) + }(i) + } + wg.Wait() + + // Verify removal + for i := 0; i < 50; i++ { + token := fmt.Sprintf("token-%d", i) + assert.False(t, validator.ValidateToken(token)) + } + for i := 50; i < 100; i++ { + token := fmt.Sprintf("token-%d", i) + assert.True(t, validator.ValidateToken(token)) + } +} \ No newline at end of file