v0.2.1 argon2 and scram imporoved

This commit is contained in:
2025-11-04 08:10:16 -05:00
parent 3471030edd
commit aafa680a35
4 changed files with 371 additions and 23 deletions

View File

@ -2,7 +2,10 @@
package auth
import (
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -143,4 +146,174 @@ func TestScram_CredentialImportExport(t *testing.T) {
assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey)
t.Log("SCRAM credential import/export successful")
}
// TestScramServerCleanup verifies automatic cleanup of expired handshakes
func TestScramServerCleanup(t *testing.T) {
// Create server with short cleanup interval for testing
server := NewScramServer()
defer server.Stop()
// Add a test credential
cred := &Credential{
Username: "testuser",
Salt: []byte("salt1234567890123456"),
ArgonTime: 1,
ArgonMemory: 64,
ArgonThreads: 1,
StoredKey: []byte("stored_key_placeholder"),
ServerKey: []byte("server_key_placeholder"),
}
server.AddCredential(cred)
// Start multiple handshakes
var nonces []string
for i := 0; i < 5; i++ {
clientNonce := fmt.Sprintf("client-nonce-%d", i)
msg, err := server.ProcessClientFirstMessage("testuser", clientNonce)
require.NoError(t, err)
nonces = append(nonces, msg.FullNonce)
}
// Verify all handshakes exist
server.mu.RLock()
assert.Len(t, server.handshakes, 5)
server.mu.RUnlock()
// Manually set old timestamp for first 3 handshakes
server.mu.Lock()
oldTime := time.Now().Add(-2 * ScramHandshakeTimeout)
count := 0
for nonce := range server.handshakes {
if count < 3 {
server.handshakes[nonce].CreatedAt = oldTime
count++
}
}
server.mu.Unlock()
// Trigger cleanup manually
server.cleanupExpiredHandshakes()
// Verify only 2 handshakes remain
server.mu.RLock()
assert.Len(t, server.handshakes, 2, "Expired handshakes should be cleaned up")
server.mu.RUnlock()
}
// TestScramConcurrentSameUser verifies multiple concurrent authentications for same user
func TestScramConcurrentSameUser(t *testing.T) {
server, username, password, _ := setupScramTest(t)
defer server.Stop()
// Number of concurrent authentication attempts
numAttempts := 10
results := make(chan error, numAttempts)
var wg sync.WaitGroup
for i := 0; i < numAttempts; i++ {
wg.Add(1)
go func(attempt int) {
defer wg.Done()
// Each goroutine performs full authentication
client := NewScramClient(username, password)
// Step 1: Client first
clientFirst, err := client.StartAuthentication()
if err != nil {
results <- err
return
}
// Step 2: Server first
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
if err != nil {
results <- err
return
}
// Step 3: Client final
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
if err != nil {
results <- err
return
}
// Step 4: Server final
serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
if err != nil {
results <- err
return
}
// Step 5: Client verify
err = client.VerifyServerFinalMessage(serverFinal)
results <- err
}(i)
}
wg.Wait()
close(results)
// Verify all attempts succeeded
successCount := 0
for err := range results {
if err == nil {
successCount++
} else {
t.Logf("Auth attempt failed: %v", err)
}
}
assert.Equal(t, numAttempts, successCount,
"All concurrent authentication attempts should succeed")
// Verify no handshakes remain after completion
server.mu.RLock()
assert.Empty(t, server.handshakes, "All handshakes should be cleaned up after completion")
server.mu.RUnlock()
}
// TestScramExplicitTimeout verifies timeout enforcement
func TestScramExplicitTimeout(t *testing.T) {
// Save original timeout and set shorter one for testing
originalTimeout := ScramHandshakeTimeout
// Note: Can't modify const at runtime, so we test with delay instead
server, username, password, _ := setupScramTest(t)
defer server.Stop()
client := NewScramClient(username, password)
// Start authentication
clientFirst, err := client.StartAuthentication()
require.NoError(t, err)
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
require.NoError(t, err)
// Manually expire the handshake
server.mu.Lock()
for nonce := range server.handshakes {
server.handshakes[nonce].CreatedAt = time.Now().Add(-2 * ScramHandshakeTimeout)
}
server.mu.Unlock()
// Client processes server message (should work, client tracks own timeout)
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
require.NoError(t, err)
// Server should reject due to timeout
_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Server should reject expired handshake")
// Test client-side timeout
client2 := NewScramClient(username, password)
client2.startTime = time.Now().Add(-2 * ScramHandshakeTimeout)
_, err = client2.ProcessServerFirstMessage(serverFirst)
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Client should reject after timeout")
_ = originalTimeout // Suppress unused variable warning
}