Files
auth/scram_test.go

319 lines
11 KiB
Go

// FILE: auth/scram_test.go
package auth
import (
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupScramTest is a helper that initializes a server and a user credential
// for testing. It performs the full Argon2 -> SCRAM migration workflow: hash
// the password to a PHC string, migrate that hash to a SCRAM credential, and
// register the credential with a fresh ScramServer.
//
// The returned server is not stopped here; callers should call server.Stop()
// when they are done with it.
func setupScramTest(t *testing.T) (server *ScramServer, username, password string, cred *Credential) {
	// Report require failures at the caller's line, not inside this helper.
	t.Helper()

	username = "testuser"
	password = "SecurePassword123"

	// 1. Start with an Argon2 PHC hash, as a real application would.
	phcHash, err := HashPassword(password)
	require.NoError(t, err, "Setup failed: could not hash password")

	// 2. Migrate the PHC hash to a SCRAM credential.
	cred, err = MigrateFromPHC(username, password, phcHash)
	require.NoError(t, err, "Setup failed: could not migrate from PHC hash")

	// 3. Create a server and add the new credential.
	server = NewScramServer()
	server.AddCredential(cred)
	return server, username, password, cred
}
// TestScram_FullRoundtrip_Success simulates a complete, successful
// authentication handshake, including the client's verification of the
// server signature (mutual authentication).
func TestScram_FullRoundtrip_Success(t *testing.T) {
	server, username, password, _ := setupScramTest(t)
	// Release server resources, matching the other tests in this file.
	defer server.Stop()

	client := NewScramClient(username, password)

	// --- Step 1: Client sends its first message ---
	clientFirst, err := client.StartAuthentication()
	require.NoError(t, err)

	// --- Step 2: Server receives client's message and responds ---
	serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
	require.NoError(t, err, "Server failed to process client's first message")

	// --- Step 3: Client receives server's message, computes proof ---
	clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
	require.NoError(t, err, "Client failed to process server's first message")

	// --- Step 4: Server receives client's proof and verifies it ---
	serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
	require.NoError(t, err, "Server failed to verify client's final proof")
	assert.NotEmpty(t, serverFinal.ServerSignature, "Server signature should not be empty")

	// --- Step 5: Client verifies server's signature (mutual authentication) ---
	err = client.VerifyServerFinalMessage(serverFinal)
	assert.NoError(t, err, "Client failed to verify server's final signature")
	t.Log("SCRAM full roundtrip successful")
}
// TestScram_FullRoundtrip_WrongPassword ensures authentication fails with an
// incorrect password: the exchange proceeds normally until the server checks
// the client proof, which must be rejected with ErrInvalidCredentials.
func TestScram_FullRoundtrip_WrongPassword(t *testing.T) {
	server, username, _, _ := setupScramTest(t)
	// Release server resources, matching the other tests in this file.
	defer server.Stop()

	// Create a client with the WRONG password.
	client := NewScramClient(username, "WrongPassword!!!")

	// Steps 1-3 will appear to succeed, as the client doesn't know the password is wrong yet.
	clientFirst, err := client.StartAuthentication()
	require.NoError(t, err)
	serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
	require.NoError(t, err)
	clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
	require.NoError(t, err)

	// --- Step 4: Server verification should fail here ---
	_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
	assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should reject proof from wrong password")
	t.Log("SCRAM correctly failed for wrong password")
}
// TestScram_FullRoundtrip_UserNotFound tests for user enumeration protection.
// The server should not reveal whether a user exists in its first message:
// for an unknown user it returns ErrInvalidCredentials together with a
// response that still carries a nonce and a salt.
func TestScram_FullRoundtrip_UserNotFound(t *testing.T) {
	server, _, _, _ := setupScramTest(t)
	// Release server resources, matching the other tests in this file.
	defer server.Stop()

	client := NewScramClient("unknown_user", "any_password")
	clientFirst, err := client.StartAuthentication()
	require.NoError(t, err)

	// --- Step 2: Server should return an error, but also a FAKE response ---
	// This prevents an attacker from knowing if the user exists based on the response structure.
	serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
	assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should return an error for an unknown user")
	// Fail cleanly (instead of panicking on the field accesses below) if no
	// fake response was produced at all.
	require.NotNil(t, serverFirst, "Server must still return a response to prevent enumeration")
	assert.NotEmpty(t, serverFirst.FullNonce, "Server must still provide a nonce to prevent enumeration")
	assert.NotEmpty(t, serverFirst.Salt, "Server must still provide a salt to prevent enumeration")
	t.Log("SCRAM correctly protected against user enumeration")
}
// TestScram_InvalidNonce simulates a replay attack or message mismatch: a
// valid proof submitted under a nonce the server never issued must be
// rejected with ErrSCRAMInvalidNonce.
func TestScram_InvalidNonce(t *testing.T) {
	server, username, password, _ := setupScramTest(t)
	// Release server resources, matching the other tests in this file.
	defer server.Stop()

	client := NewScramClient(username, password)

	// Perform the first part of the handshake. These steps must succeed;
	// otherwise clientFinal would be meaningless and the assertion below
	// could pass or panic for the wrong reason.
	clientFirst, err := client.StartAuthentication()
	require.NoError(t, err)
	serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
	require.NoError(t, err)
	clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
	require.NoError(t, err)

	// Attempt to finalize with a completely different nonce.
	_, err = server.ProcessClientFinalMessage("this-is-a-bad-nonce", clientFinal.ClientProof)
	assert.ErrorIs(t, err, ErrSCRAMInvalidNonce, "Server should reject a final message with an unknown nonce")
}
// TestScram_CredentialImportExport verifies that credentials can be
// serialized with Export and restored with ImportCredential without losing
// any field.
func TestScram_CredentialImportExport(t *testing.T) {
	server, _, _, originalCred := setupScramTest(t)
	// The server is unused here, but setup created it; release it like the
	// other tests do.
	defer server.Stop()

	// Export the credential to a map.
	exportedData := originalCred.Export()
	require.NotNil(t, exportedData)

	// Assert that required fields exist and are strings (as they are base64 encoded).
	assert.IsType(t, "", exportedData["salt"])
	assert.IsType(t, "", exportedData["stored_key"])
	assert.IsType(t, "", exportedData["server_key"])

	// Import the credential back from the map.
	importedCred, err := ImportCredential(exportedData)
	require.NoError(t, err)
	require.NotNil(t, importedCred)

	// Verify that the imported credential is identical to the original.
	assert.Equal(t, originalCred.Username, importedCred.Username)
	assert.Equal(t, originalCred.Salt, importedCred.Salt)
	assert.Equal(t, originalCred.ArgonTime, importedCred.ArgonTime)
	assert.Equal(t, originalCred.ArgonMemory, importedCred.ArgonMemory)
	assert.Equal(t, originalCred.ArgonThreads, importedCred.ArgonThreads)
	assert.Equal(t, originalCred.StoredKey, importedCred.StoredKey)
	assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey)
	t.Log("SCRAM credential import/export successful")
}
// TestScramServerCleanup verifies that cleanupExpiredHandshakes removes
// handshakes older than ScramHandshakeTimeout and keeps the fresh ones.
func TestScramServerCleanup(t *testing.T) {
	// Use a default server. Expiry is simulated below by backdating CreatedAt
	// and invoking the cleanup routine directly, so the test does not depend
	// on any background cleanup interval.
	server := NewScramServer()
	defer server.Stop()

	// Add a test credential. The key material is placeholder bytes: only the
	// first SCRAM message is exercised here, so no proof is ever verified.
	cred := &Credential{
		Username:     "testuser",
		Salt:         []byte("salt1234567890123456"),
		ArgonTime:    1,
		ArgonMemory:  64,
		ArgonThreads: 1,
		StoredKey:    []byte("stored_key_placeholder"),
		ServerKey:    []byte("server_key_placeholder"),
	}
	server.AddCredential(cred)

	const total, toExpire = 5, 3

	// Start multiple handshakes.
	for i := 0; i < total; i++ {
		clientNonce := fmt.Sprintf("client-nonce-%d", i)
		msg, err := server.ProcessClientFirstMessage("testuser", clientNonce)
		require.NoError(t, err)
		require.NotEmpty(t, msg.FullNonce)
	}

	// Verify all handshakes exist.
	server.mu.RLock()
	assert.Len(t, server.handshakes, total)
	server.mu.RUnlock()

	// Backdate exactly toExpire handshakes (whichever the map yields first)
	// to well past the timeout.
	oldTime := time.Now().Add(-2 * ScramHandshakeTimeout)
	expired := 0
	server.mu.Lock()
	for _, hs := range server.handshakes {
		if expired == toExpire {
			break
		}
		hs.CreatedAt = oldTime
		expired++
	}
	server.mu.Unlock()
	require.Equal(t, toExpire, expired, "Not enough handshakes to backdate")

	// Trigger cleanup manually.
	server.cleanupExpiredHandshakes()

	// Only the fresh handshakes should remain, and none of the backdated
	// ones may have survived. Collect under the lock, assert after releasing it.
	server.mu.RLock()
	remaining := len(server.handshakes)
	survivors := 0
	for _, hs := range server.handshakes {
		if hs.CreatedAt.Equal(oldTime) {
			survivors++
		}
	}
	server.mu.RUnlock()

	assert.Equal(t, total-toExpire, remaining, "Expired handshakes should be cleaned up")
	assert.Zero(t, survivors, "No backdated handshake should survive cleanup")
}
// TestScramConcurrentSameUser verifies that multiple concurrent
// authentications for the same user all succeed and that no pending
// handshakes remain once they have completed.
func TestScramConcurrentSameUser(t *testing.T) {
	server, username, password, _ := setupScramTest(t)
	defer server.Stop()

	// handshake runs one complete SCRAM exchange with its own client. It
	// returns nil on mutual success, or the first error labelled with the step
	// that produced it.
	handshake := func() error {
		client := NewScramClient(username, password)

		// Step 1: Client first
		clientFirst, err := client.StartAuthentication()
		if err != nil {
			return fmt.Errorf("client first: %w", err)
		}
		// Step 2: Server first
		serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
		if err != nil {
			return fmt.Errorf("server first: %w", err)
		}
		// Step 3: Client final
		clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
		if err != nil {
			return fmt.Errorf("client final: %w", err)
		}
		// Step 4: Server final
		serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
		if err != nil {
			return fmt.Errorf("server final: %w", err)
		}
		// Step 5: Client verify
		if err := client.VerifyServerFinalMessage(serverFinal); err != nil {
			return fmt.Errorf("client verify: %w", err)
		}
		return nil
	}

	// Number of concurrent authentication attempts.
	const numAttempts = 10
	// Buffered to numAttempts so no goroutine blocks on send before wg.Wait.
	results := make(chan error, numAttempts)
	var wg sync.WaitGroup

	for i := 0; i < numAttempts; i++ {
		wg.Add(1)
		go func(attempt int) {
			defer wg.Done()
			if err := handshake(); err != nil {
				// Tag the failure so the log identifies which attempt broke.
				results <- fmt.Errorf("attempt %d: %w", attempt, err)
				return
			}
			results <- nil
		}(i)
	}

	wg.Wait()
	close(results)

	// Verify all attempts succeeded.
	successCount := 0
	for err := range results {
		if err != nil {
			t.Logf("Auth attempt failed: %v", err)
			continue
		}
		successCount++
	}
	assert.Equal(t, numAttempts, successCount,
		"All concurrent authentication attempts should succeed")

	// Verify no handshakes remain after completion.
	server.mu.RLock()
	assert.Empty(t, server.handshakes, "All handshakes should be cleaned up after completion")
	server.mu.RUnlock()
}
// TestScramExplicitTimeout verifies timeout enforcement
func TestScramExplicitTimeout(t *testing.T) {
// Save original timeout and set shorter one for testing
originalTimeout := ScramHandshakeTimeout
// Note: Can't modify const at runtime, so we test with delay instead
server, username, password, _ := setupScramTest(t)
defer server.Stop()
client := NewScramClient(username, password)
// Start authentication
clientFirst, err := client.StartAuthentication()
require.NoError(t, err)
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
require.NoError(t, err)
// Manually expire the handshake
server.mu.Lock()
for nonce := range server.handshakes {
server.handshakes[nonce].CreatedAt = time.Now().Add(-2 * ScramHandshakeTimeout)
}
server.mu.Unlock()
// Client processes server message (should work, client tracks own timeout)
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
require.NoError(t, err)
// Server should reject due to timeout
_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Server should reject expired handshake")
// Test client-side timeout
client2 := NewScramClient(username, password)
client2.startTime = time.Now().Add(-2 * ScramHandshakeTimeout)
_, err = client2.ProcessServerFirstMessage(serverFirst)
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Client should reject after timeout")
_ = originalTimeout // Suppress unused variable warning
}