v0.2.1 argon2 and scram imporoved
This commit is contained in:
173
scram_test.go
173
scram_test.go
@ -2,7 +2,10 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -143,4 +146,174 @@ func TestScram_CredentialImportExport(t *testing.T) {
|
||||
assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey)
|
||||
|
||||
t.Log("SCRAM credential import/export successful")
|
||||
}
|
||||
|
||||
// TestScramServerCleanup verifies automatic cleanup of expired handshakes
|
||||
func TestScramServerCleanup(t *testing.T) {
|
||||
// Create server with short cleanup interval for testing
|
||||
server := NewScramServer()
|
||||
defer server.Stop()
|
||||
|
||||
// Add a test credential
|
||||
cred := &Credential{
|
||||
Username: "testuser",
|
||||
Salt: []byte("salt1234567890123456"),
|
||||
ArgonTime: 1,
|
||||
ArgonMemory: 64,
|
||||
ArgonThreads: 1,
|
||||
StoredKey: []byte("stored_key_placeholder"),
|
||||
ServerKey: []byte("server_key_placeholder"),
|
||||
}
|
||||
server.AddCredential(cred)
|
||||
|
||||
// Start multiple handshakes
|
||||
var nonces []string
|
||||
for i := 0; i < 5; i++ {
|
||||
clientNonce := fmt.Sprintf("client-nonce-%d", i)
|
||||
msg, err := server.ProcessClientFirstMessage("testuser", clientNonce)
|
||||
require.NoError(t, err)
|
||||
nonces = append(nonces, msg.FullNonce)
|
||||
}
|
||||
|
||||
// Verify all handshakes exist
|
||||
server.mu.RLock()
|
||||
assert.Len(t, server.handshakes, 5)
|
||||
server.mu.RUnlock()
|
||||
|
||||
// Manually set old timestamp for first 3 handshakes
|
||||
server.mu.Lock()
|
||||
oldTime := time.Now().Add(-2 * ScramHandshakeTimeout)
|
||||
count := 0
|
||||
for nonce := range server.handshakes {
|
||||
if count < 3 {
|
||||
server.handshakes[nonce].CreatedAt = oldTime
|
||||
count++
|
||||
}
|
||||
}
|
||||
server.mu.Unlock()
|
||||
|
||||
// Trigger cleanup manually
|
||||
server.cleanupExpiredHandshakes()
|
||||
|
||||
// Verify only 2 handshakes remain
|
||||
server.mu.RLock()
|
||||
assert.Len(t, server.handshakes, 2, "Expired handshakes should be cleaned up")
|
||||
server.mu.RUnlock()
|
||||
}
|
||||
|
||||
// TestScramConcurrentSameUser verifies multiple concurrent authentications for same user
|
||||
func TestScramConcurrentSameUser(t *testing.T) {
|
||||
server, username, password, _ := setupScramTest(t)
|
||||
defer server.Stop()
|
||||
|
||||
// Number of concurrent authentication attempts
|
||||
numAttempts := 10
|
||||
results := make(chan error, numAttempts)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < numAttempts; i++ {
|
||||
wg.Add(1)
|
||||
go func(attempt int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Each goroutine performs full authentication
|
||||
client := NewScramClient(username, password)
|
||||
|
||||
// Step 1: Client first
|
||||
clientFirst, err := client.StartAuthentication()
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Step 2: Server first
|
||||
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Step 3: Client final
|
||||
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Step 4: Server final
|
||||
serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Step 5: Client verify
|
||||
err = client.VerifyServerFinalMessage(serverFinal)
|
||||
results <- err
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
// Verify all attempts succeeded
|
||||
successCount := 0
|
||||
for err := range results {
|
||||
if err == nil {
|
||||
successCount++
|
||||
} else {
|
||||
t.Logf("Auth attempt failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, numAttempts, successCount,
|
||||
"All concurrent authentication attempts should succeed")
|
||||
|
||||
// Verify no handshakes remain after completion
|
||||
server.mu.RLock()
|
||||
assert.Empty(t, server.handshakes, "All handshakes should be cleaned up after completion")
|
||||
server.mu.RUnlock()
|
||||
}
|
||||
|
||||
// TestScramExplicitTimeout verifies timeout enforcement
|
||||
func TestScramExplicitTimeout(t *testing.T) {
|
||||
// Save original timeout and set shorter one for testing
|
||||
originalTimeout := ScramHandshakeTimeout
|
||||
// Note: Can't modify const at runtime, so we test with delay instead
|
||||
|
||||
server, username, password, _ := setupScramTest(t)
|
||||
defer server.Stop()
|
||||
|
||||
client := NewScramClient(username, password)
|
||||
|
||||
// Start authentication
|
||||
clientFirst, err := client.StartAuthentication()
|
||||
require.NoError(t, err)
|
||||
|
||||
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually expire the handshake
|
||||
server.mu.Lock()
|
||||
for nonce := range server.handshakes {
|
||||
server.handshakes[nonce].CreatedAt = time.Now().Add(-2 * ScramHandshakeTimeout)
|
||||
}
|
||||
server.mu.Unlock()
|
||||
|
||||
// Client processes server message (should work, client tracks own timeout)
|
||||
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server should reject due to timeout
|
||||
_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
||||
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Server should reject expired handshake")
|
||||
|
||||
// Test client-side timeout
|
||||
client2 := NewScramClient(username, password)
|
||||
client2.startTime = time.Now().Add(-2 * ScramHandshakeTimeout)
|
||||
|
||||
_, err = client2.ProcessServerFirstMessage(serverFirst)
|
||||
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Client should reject after timeout")
|
||||
|
||||
_ = originalTimeout // Suppress unused variable warning
|
||||
}
|
||||
Reference in New Issue
Block a user