From aafa680a35cf429d3e2369f5d4ccd31b015ce8a031621f8986c5fe4985af2cfc Mon Sep 17 00:00:00 2001 From: Lixen Wraith Date: Tue, 4 Nov 2025 08:10:16 -0500 Subject: [PATCH] v0.2.1 argon2 and scram imporoved --- argon2.go | 72 ++++++++++++++++++++ argon2_test.go | 60 +++++++++++++++++ scram.go | 89 ++++++++++++++++++------- scram_test.go | 173 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 371 insertions(+), 23 deletions(-) diff --git a/argon2.go b/argon2.go index 5e0e9de..9670587 100644 --- a/argon2.go +++ b/argon2.go @@ -142,4 +142,76 @@ func MigrateFromPHC(username, password, phcHash string) (*Credential, error) { } return DeriveCredential(username, password, salt, time, memory, threads) +} + +// ValidatePHCHashFormat checks if a hash string has a valid and complete +// PHC format for Argon2id. It validates structure, parameters, and encoding, +// but does not verify a password against the hash. +func ValidatePHCHashFormat(phcHash string) error { + parts := strings.Split(phcHash, "$") + if len(parts) != 6 { + return fmt.Errorf("%w: expected 6 parts, got %d", ErrPHCInvalidFormat, len(parts)) + } + + // Validate empty parts[0] (PHC format starts with $) + if parts[0] != "" { + return fmt.Errorf("%w: hash must start with $", ErrPHCInvalidFormat) + } + + // Validate algorithm identifier + if parts[1] != "argon2id" { + return fmt.Errorf("%w: unsupported algorithm %q, expected argon2id", ErrPHCInvalidFormat, parts[1]) + } + + // Validate version + var version int + n, err := fmt.Sscanf(parts[2], "v=%d", &version) + if err != nil || n != 1 { + return fmt.Errorf("%w: invalid version format", ErrPHCInvalidFormat) + } + if version != argon2.Version { + return fmt.Errorf("%w: unsupported version %d, expected %d", ErrPHCInvalidFormat, version, argon2.Version) + } + + // Validate parameters + var memory, time uint32 + var threads uint8 + n, err = fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads) + if err != nil || n != 3 { + return fmt.Errorf("%w: failed to parse parameters", ErrPHCInvalidFormat) + } + + // Validate parameter ranges + if time == 0 || memory == 0 || threads == 0 { + return fmt.Errorf("%w: parameters must be non-zero", ErrPHCInvalidFormat) + } + if memory > 4*1024*1024 { // 4GB limit + return fmt.Errorf("%w: memory parameter exceeds maximum (4GB)", ErrPHCInvalidFormat) + } + if time > 1000 { // Reasonable upper bound + return fmt.Errorf("%w: time parameter exceeds maximum (1000)", ErrPHCInvalidFormat) + } + if threads > 255 { // uint8 max, but practically much lower + return fmt.Errorf("%w: threads parameter exceeds maximum (255)", ErrPHCInvalidFormat) + } + + // Validate salt encoding + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return fmt.Errorf("%w: %v", ErrPHCInvalidSalt, err) + } + if len(salt) < 8 { // Minimum safe salt length + return fmt.Errorf("%w: salt too short (%d bytes)", ErrPHCInvalidSalt, len(salt)) + } + + // Validate hash encoding + hash, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err) + } + if len(hash) < 16 { // Minimum hash length + return fmt.Errorf("%w: hash too short (%d bytes)", ErrPHCInvalidHash, len(hash)) + } + + return nil } \ No newline at end of file diff --git a/argon2_test.go b/argon2_test.go index af78978..b14c10b 100644 --- a/argon2_test.go +++ b/argon2_test.go @@ -2,6 +2,7 @@ package auth import ( + "encoding/base64" "strings" "sync" "testing" @@ -112,4 +113,63 @@ func TestPHCMigration(t *testing.T) { // Test with invalid PHC format _, err = MigrateFromPHC(username, password, "$invalid$format") assert.Error(t, err) +} + +func TestValidatePHCHashFormat(t *testing.T) { + // Generate valid hash for testing + validHash, err := HashPassword("testPassword123") + require.NoError(t, err) + + // Test valid hash + err = ValidatePHCHashFormat(validHash) + assert.NoError(t, err, "Valid hash should pass validation") + + // Test malformed formats + testCases := []struct { + name string + hash string + wantErr error + }{ + {"empty", "", ErrPHCInvalidFormat}, + {"not PHC format", "plaintext", ErrPHCInvalidFormat}, + {"wrong prefix", "argon2id$v=19$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat}, + {"wrong algorithm", "$bcrypt$v=19$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat}, + {"missing version", "$argon2id$$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat}, + {"wrong version", "$argon2id$v=1$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat}, + {"missing params", "$argon2id$v=19$$salt$hash", ErrPHCInvalidFormat}, + {"invalid params format", "$argon2id$v=19$invalid$salt$hash", ErrPHCInvalidFormat}, + {"zero time", "$argon2id$v=19$m=65536,t=0,p=4$salt$hash", ErrPHCInvalidFormat}, + {"zero memory", "$argon2id$v=19$m=0,t=3,p=4$salt$hash", ErrPHCInvalidFormat}, + {"zero threads", "$argon2id$v=19$m=65536,t=3,p=0$salt$hash", ErrPHCInvalidFormat}, + {"excessive memory", "$argon2id$v=19$m=5000000,t=3,p=4$salt$hash", ErrPHCInvalidFormat}, + {"excessive time", "$argon2id$v=19$m=65536,t=2000,p=4$salt$hash", ErrPHCInvalidFormat}, + {"invalid salt encoding", "$argon2id$v=19$m=65536,t=3,p=4$!!!invalid!!!$hash", ErrPHCInvalidSalt}, + {"invalid hash encoding", "$argon2id$v=19$m=65536,t=3,p=4$" + + base64.RawStdEncoding.EncodeToString([]byte("salt12345678")) + "$!!!invalid!!!", ErrPHCInvalidHash}, + {"short salt", "$argon2id$v=19$m=65536,t=3,p=4$" + + base64.RawStdEncoding.EncodeToString([]byte("short")) + "$" + + base64.RawStdEncoding.EncodeToString([]byte("hash1234567890123456")), ErrPHCInvalidSalt}, + {"short hash", "$argon2id$v=19$m=65536,t=3,p=4$" + + base64.RawStdEncoding.EncodeToString([]byte("salt12345678")) + "$" + + base64.RawStdEncoding.EncodeToString([]byte("short")), ErrPHCInvalidHash}, + {"too few parts", "$argon2id$v=19$m=65536,t=3,p=4", ErrPHCInvalidFormat}, + {"too many parts", "$argon2id$v=19$m=65536,t=3,p=4$salt$hash$extra", ErrPHCInvalidFormat}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidatePHCHashFormat(tc.hash) + assert.ErrorIs(t, err, tc.wantErr, "Test case: %s", tc.name) + }) + } + + // Test that validation doesn't require password + err = ValidatePHCHashFormat(validHash) + assert.NoError(t, err, "Should validate format without password") + + // Verify that a validated hash can still be used for verification + err = ValidatePHCHashFormat(validHash) + require.NoError(t, err) + err = VerifyPassword("testPassword123", validHash) + assert.NoError(t, err, "Validated hash should still work for password verification") } \ No newline at end of file diff --git a/scram.go b/scram.go index d9840e5..d852e50 100644 --- a/scram.go +++ b/scram.go @@ -17,6 +17,13 @@ import ( // SCRAM-SHA256 implementation +const ( + // ScramHandshakeTimeout defines maximum time for completing SCRAM handshake + ScramHandshakeTimeout = 30 * time.Second + // ScramCleanupInterval defines how often expired handshakes are cleaned + ScramCleanupInterval = 60 * time.Second +) + // Credential stores SCRAM authentication data type Credential struct { Username string @@ -166,13 +173,6 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3 }, nil } -// ScramServer handles server-side SCRAM authentication -type ScramServer struct { - credentials map[string]*Credential - handshakes map[string]*HandshakeState - mu sync.RWMutex -} - // HandshakeState tracks ongoing authentication type HandshakeState struct { Username string @@ -184,19 +184,59 @@ type HandshakeState struct { verifying int32 // Atomic flag to prevent race during verification } +// ScramServer handles server-side SCRAM authentication +type ScramServer struct { + credentials map[string]*Credential + handshakes map[string]*HandshakeState + mu sync.RWMutex + cleanupTicker *time.Ticker + cleanupStop chan struct{} +} + // NewScramServer creates SCRAM server func NewScramServer() *ScramServer { - return &ScramServer{ - credentials: make(map[string]*Credential), - handshakes: make(map[string]*HandshakeState), + s := &ScramServer{ + credentials: make(map[string]*Credential), + handshakes: make(map[string]*HandshakeState), + cleanupTicker: time.NewTicker(ScramCleanupInterval), + cleanupStop: make(chan struct{}), + } + + // Start background cleanup goroutine + go s.cleanupLoop() + + return s +} + +// Stop gracefully shuts down the server and cleanup goroutine +func (s *ScramServer) Stop() { + close(s.cleanupStop) + s.cleanupTicker.Stop() +} + +// cleanupLoop runs periodic cleanup of expired handshakes +func (s *ScramServer) cleanupLoop() { + for { + select { + case <-s.cleanupTicker.C: + s.cleanupExpiredHandshakes() + case <-s.cleanupStop: + return + } } } -// AddCredential registers user credential -func (s *ScramServer) AddCredential(cred *Credential) { +// cleanupExpiredHandshakes removes handshakes older than timeout +func (s *ScramServer) cleanupExpiredHandshakes() { s.mu.Lock() defer s.mu.Unlock() - s.credentials[cred.Username] = cred + + cutoff := time.Now().Add(-ScramHandshakeTimeout) + for nonce, state := range s.handshakes { + if state.CreatedAt.Before(cutoff) && atomic.LoadInt32(&state.verifying) == 0 { + delete(s.handshakes, nonce) + } + } } // ProcessClientFirstMessage processes initial auth request @@ -237,9 +277,6 @@ func (s *ScramServer) ProcessClientFirstMessage(username, clientNonce string) (S } s.handshakes[fullNonce] = state - // Cleanup old handshakes - s.cleanupHandshakes() - return ServerFirstMessage{ FullNonce: fullNonce, Salt: base64.StdEncoding.EncodeToString(cred.Salt), @@ -272,10 +309,11 @@ func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) ( }() // Check timeout - if time.Since(state.CreatedAt) > 60*time.Second { + if time.Since(state.CreatedAt) > ScramHandshakeTimeout { return ServerFinalMessage{}, ErrSCRAMTimeout } + // [rest of verification logic unchanged] // Decode client proof clientProofBytes, err := base64.StdEncoding.DecodeString(clientProof) if err != nil { @@ -318,6 +356,13 @@ func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) ( }, nil } +// AddCredential registers user credential +func (s *ScramServer) AddCredential(cred *Credential) { + s.mu.Lock() + defer s.mu.Unlock() + s.credentials[cred.Username] = cred +} + func (s *ScramServer) cleanupHandshakes() { cutoff := time.Now().Add(-60 * time.Second) for nonce, state := range s.handshakes { @@ -365,16 +410,13 @@ func (c *ScramClient) StartAuthentication() (ClientFirstRequest, error) { // ProcessServerFirstMessage handles server challenge func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientFinalRequest, error) { - // Check timeout (30 seconds) - if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second { + // Check timeout + if !c.startTime.IsZero() && time.Since(c.startTime) > ScramHandshakeTimeout { return ClientFinalRequest{}, ErrSCRAMTimeout } c.serverFirst = &msg - // Handle enumeration prevention - server may send fake response - // We still process it normally and let verification fail later - // Decode salt salt, err := base64.StdEncoding.DecodeString(msg.Salt) if err != nil { @@ -414,10 +456,11 @@ func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientF // VerifyServerFinalMessage validates server signature func (c *ScramClient) VerifyServerFinalMessage(msg ServerFinalMessage) error { // Check timeout - if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second { + if !c.startTime.IsZero() && time.Since(c.startTime) > ScramHandshakeTimeout { return ErrSCRAMTimeout } + // [rest unchanged] if c.authMessage == "" || c.serverKey == nil { return ErrSCRAMInvalidState } diff --git a/scram_test.go b/scram_test.go index 40c8fc2..40c1e4e 100644 --- a/scram_test.go +++ b/scram_test.go @@ -2,7 +2,10 @@ package auth import ( + "fmt" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -143,4 +146,174 @@ func TestScram_CredentialImportExport(t *testing.T) { assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey) t.Log("SCRAM credential import/export successful") +} + +// TestScramServerCleanup verifies automatic cleanup of expired handshakes +func TestScramServerCleanup(t *testing.T) { + // Create server with short cleanup interval for testing + server := NewScramServer() + defer server.Stop() + + // Add a test credential + cred := &Credential{ + Username: "testuser", + Salt: []byte("salt1234567890123456"), + ArgonTime: 1, + ArgonMemory: 64, + ArgonThreads: 1, + StoredKey: []byte("stored_key_placeholder"), + ServerKey: []byte("server_key_placeholder"), + } + server.AddCredential(cred) + + // Start multiple handshakes + var nonces []string + for i := 0; i < 5; i++ { + clientNonce := fmt.Sprintf("client-nonce-%d", i) + msg, err := server.ProcessClientFirstMessage("testuser", clientNonce) + require.NoError(t, err) + nonces = append(nonces, msg.FullNonce) + } + + // Verify all handshakes exist + server.mu.RLock() + assert.Len(t, server.handshakes, 5) + server.mu.RUnlock() + + // Manually set old timestamp for first 3 handshakes + server.mu.Lock() + oldTime := time.Now().Add(-2 * ScramHandshakeTimeout) + count := 0 + for nonce := range server.handshakes { + if count < 3 { + server.handshakes[nonce].CreatedAt = oldTime + count++ + } + } + server.mu.Unlock() + + // Trigger cleanup manually + server.cleanupExpiredHandshakes() + + // Verify only 2 handshakes remain + server.mu.RLock() + assert.Len(t, server.handshakes, 2, "Expired handshakes should be cleaned up") + server.mu.RUnlock() +} + +// TestScramConcurrentSameUser verifies multiple concurrent authentications for same user +func TestScramConcurrentSameUser(t *testing.T) { + server, username, password, _ := setupScramTest(t) + defer server.Stop() + + // Number of concurrent authentication attempts + numAttempts := 10 + results := make(chan error, numAttempts) + + var wg sync.WaitGroup + for i := 0; i < numAttempts; i++ { + wg.Add(1) + go func(attempt int) { + defer wg.Done() + + // Each goroutine performs full authentication + client := NewScramClient(username, password) + + // Step 1: Client first + clientFirst, err := client.StartAuthentication() + if err != nil { + results <- err + return + } + + // Step 2: Server first + serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce) + if err != nil { + results <- err + return + } + + // Step 3: Client final + clientFinal, err := client.ProcessServerFirstMessage(serverFirst) + if err != nil { + results <- err + return + } + + // Step 4: Server final + serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof) + if err != nil { + results <- err + return + } + + // Step 5: Client verify + err = client.VerifyServerFinalMessage(serverFinal) + results <- err + }(i) + } + + wg.Wait() + close(results) + + // Verify all attempts succeeded + successCount := 0 + for err := range results { + if err == nil { + successCount++ + } else { + t.Logf("Auth attempt failed: %v", err) + } + } + + assert.Equal(t, numAttempts, successCount, + "All concurrent authentication attempts should succeed") + + // Verify no handshakes remain after completion + server.mu.RLock() + assert.Empty(t, server.handshakes, "All handshakes should be cleaned up after completion") + server.mu.RUnlock() +} + +// TestScramExplicitTimeout verifies timeout enforcement +func TestScramExplicitTimeout(t *testing.T) { + // Save original timeout and set shorter one for testing + originalTimeout := ScramHandshakeTimeout + // Note: Can't modify const at runtime, so we test with delay instead + + server, username, password, _ := setupScramTest(t) + defer server.Stop() + + client := NewScramClient(username, password) + + // Start authentication + clientFirst, err := client.StartAuthentication() + require.NoError(t, err) + + serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce) + require.NoError(t, err) + + // Manually expire the handshake + server.mu.Lock() + for nonce := range server.handshakes { + server.handshakes[nonce].CreatedAt = time.Now().Add(-2 * ScramHandshakeTimeout) + } + server.mu.Unlock() + + // Client processes server message (should work, client tracks own timeout) + clientFinal, err := client.ProcessServerFirstMessage(serverFirst) + require.NoError(t, err) + + // Server should reject due to timeout + _, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof) + assert.ErrorIs(t, err, ErrSCRAMTimeout, "Server should reject expired handshake") + + // Test client-side timeout + client2 := NewScramClient(username, password) + client2.startTime = time.Now().Add(-2 * ScramHandshakeTimeout) + + _, err = client2.ProcessServerFirstMessage(serverFirst) + assert.ErrorIs(t, err, ErrSCRAMTimeout, "Client should reject after timeout") + + _ = originalTimeout // Suppress unused variable warning } \ No newline at end of file