// FILE: auth/scram_test.go (319 lines, 11 KiB, Go)
|
|
package auth
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// setupScramTest is a helper to initialize a server and user credential for testing.
|
|
// It performs the full Argon2 -> SCRAM migration workflow.
|
|
func setupScramTest(t *testing.T) (server *ScramServer, username, password string, cred *Credential) {
|
|
username = "testuser"
|
|
password = "SecurePassword123"
|
|
|
|
// 1. Start with an Argon2 PHC hash, as a real application would.
|
|
phcHash, err := HashPassword(password)
|
|
require.NoError(t, err, "Setup failed: could not hash password")
|
|
|
|
// 2. Migrate the PHC hash to a SCRAM credential.
|
|
cred, err = MigrateFromPHC(username, password, phcHash)
|
|
require.NoError(t, err, "Setup failed: could not migrate from PHC hash")
|
|
|
|
// 3. Create a server and add the new credential.
|
|
server = NewScramServer()
|
|
server.AddCredential(cred)
|
|
|
|
return server, username, password, cred
|
|
}
|
|
|
|
// TestScram_FullRoundtrip_Success simulates a complete, successful authentication handshake.
|
|
func TestScram_FullRoundtrip_Success(t *testing.T) {
|
|
server, username, password, _ := setupScramTest(t)
|
|
client := NewScramClient(username, password)
|
|
|
|
// --- Step 1: Client sends its first message ---
|
|
clientFirst, err := client.StartAuthentication()
|
|
require.NoError(t, err)
|
|
|
|
// --- Step 2: Server receives client's message and responds ---
|
|
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
|
require.NoError(t, err, "Server failed to process client's first message")
|
|
|
|
// --- Step 3: Client receives server's message, computes proof ---
|
|
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
|
require.NoError(t, err, "Client failed to process server's first message")
|
|
|
|
// --- Step 4: Server receives client's proof and verifies it ---
|
|
serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
|
require.NoError(t, err, "Server failed to verify client's final proof")
|
|
assert.NotEmpty(t, serverFinal.ServerSignature, "Server signature should not be empty")
|
|
|
|
// --- Step 5: Client verifies server's signature (mutual authentication) ---
|
|
err = client.VerifyServerFinalMessage(serverFinal)
|
|
assert.NoError(t, err, "Client failed to verify server's final signature")
|
|
|
|
t.Log("SCRAM full roundtrip successful")
|
|
}
|
|
|
|
// TestScram_FullRoundtrip_WrongPassword ensures authentication fails with an incorrect password.
|
|
func TestScram_FullRoundtrip_WrongPassword(t *testing.T) {
|
|
server, username, _, _ := setupScramTest(t)
|
|
// Create a client with the WRONG password
|
|
client := NewScramClient(username, "WrongPassword!!!")
|
|
|
|
// Steps 1-3 will appear to succeed, as the client doesn't know the password is wrong yet.
|
|
clientFirst, err := client.StartAuthentication()
|
|
require.NoError(t, err)
|
|
|
|
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
|
require.NoError(t, err)
|
|
|
|
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
|
require.NoError(t, err)
|
|
|
|
// --- Step 4: Server verification should fail here ---
|
|
_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
|
assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should reject proof from wrong password")
|
|
|
|
t.Log("SCRAM correctly failed for wrong password")
|
|
}
|
|
|
|
// TestScram_FullRoundtrip_UserNotFound tests for user enumeration protection.
|
|
// The server should not reveal whether a user exists or not in its first message.
|
|
func TestScram_FullRoundtrip_UserNotFound(t *testing.T) {
|
|
server, _, _, _ := setupScramTest(t)
|
|
client := NewScramClient("unknown_user", "any_password")
|
|
|
|
clientFirst, err := client.StartAuthentication()
|
|
require.NoError(t, err)
|
|
|
|
// --- Step 2: Server should return an error, but also a FAKE response ---
|
|
// This prevents an attacker from knowing if the user exists based on the response structure.
|
|
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
|
assert.ErrorIs(t, err, ErrInvalidCredentials, "Server should return an error for an unknown user")
|
|
assert.NotEmpty(t, serverFirst.FullNonce, "Server must still provide a nonce to prevent enumeration")
|
|
assert.NotEmpty(t, serverFirst.Salt, "Server must still provide a salt to prevent enumeration")
|
|
|
|
t.Log("SCRAM correctly protected against user enumeration")
|
|
}
|
|
|
|
// TestScram_InvalidNonce simulates a replay attack or message mismatch.
|
|
func TestScram_InvalidNonce(t *testing.T) {
|
|
server, username, password, _ := setupScramTest(t)
|
|
client := NewScramClient(username, password)
|
|
|
|
// Perform the first part of the handshake
|
|
clientFirst, _ := client.StartAuthentication()
|
|
serverFirst, _ := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
|
clientFinal, _ := client.ProcessServerFirstMessage(serverFirst)
|
|
|
|
// Attempt to finalize with a completely different nonce
|
|
_, err := server.ProcessClientFinalMessage("this-is-a-bad-nonce", clientFinal.ClientProof)
|
|
assert.ErrorIs(t, err, ErrSCRAMInvalidNonce, "Server should reject a final message with an unknown nonce")
|
|
}
|
|
|
|
// TestScram_CredentialImportExport verifies that credentials can be serialized and deserialized correctly.
|
|
func TestScram_CredentialImportExport(t *testing.T) {
|
|
_, _, _, originalCred := setupScramTest(t)
|
|
|
|
// Export the credential to a map
|
|
exportedData := originalCred.Export()
|
|
require.NotNil(t, exportedData)
|
|
|
|
// Assert that required fields exist and are strings (as they are base64 encoded)
|
|
assert.IsType(t, "", exportedData["salt"])
|
|
assert.IsType(t, "", exportedData["stored_key"])
|
|
assert.IsType(t, "", exportedData["server_key"])
|
|
|
|
// Import the credential back from the map
|
|
importedCred, err := ImportCredential(exportedData)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, importedCred)
|
|
|
|
// Verify that the imported credential is identical to the original
|
|
assert.Equal(t, originalCred.Username, importedCred.Username)
|
|
assert.Equal(t, originalCred.Salt, importedCred.Salt)
|
|
assert.Equal(t, originalCred.ArgonTime, importedCred.ArgonTime)
|
|
assert.Equal(t, originalCred.ArgonMemory, importedCred.ArgonMemory)
|
|
assert.Equal(t, originalCred.ArgonThreads, importedCred.ArgonThreads)
|
|
assert.Equal(t, originalCred.StoredKey, importedCred.StoredKey)
|
|
assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey)
|
|
|
|
t.Log("SCRAM credential import/export successful")
|
|
}
|
|
|
|
// TestScramServerCleanup verifies automatic cleanup of expired handshakes
|
|
func TestScramServerCleanup(t *testing.T) {
|
|
// Create server with short cleanup interval for testing
|
|
server := NewScramServer()
|
|
defer server.Stop()
|
|
|
|
// Add a test credential
|
|
cred := &Credential{
|
|
Username: "testuser",
|
|
Salt: []byte("salt1234567890123456"),
|
|
ArgonTime: 1,
|
|
ArgonMemory: 64,
|
|
ArgonThreads: 1,
|
|
StoredKey: []byte("stored_key_placeholder"),
|
|
ServerKey: []byte("server_key_placeholder"),
|
|
}
|
|
server.AddCredential(cred)
|
|
|
|
// Start multiple handshakes
|
|
var nonces []string
|
|
for i := 0; i < 5; i++ {
|
|
clientNonce := fmt.Sprintf("client-nonce-%d", i)
|
|
msg, err := server.ProcessClientFirstMessage("testuser", clientNonce)
|
|
require.NoError(t, err)
|
|
nonces = append(nonces, msg.FullNonce)
|
|
}
|
|
|
|
// Verify all handshakes exist
|
|
server.mu.RLock()
|
|
assert.Len(t, server.handshakes, 5)
|
|
server.mu.RUnlock()
|
|
|
|
// Manually set old timestamp for first 3 handshakes
|
|
server.mu.Lock()
|
|
oldTime := time.Now().Add(-2 * ScramHandshakeTimeout)
|
|
count := 0
|
|
for nonce := range server.handshakes {
|
|
if count < 3 {
|
|
server.handshakes[nonce].CreatedAt = oldTime
|
|
count++
|
|
}
|
|
}
|
|
server.mu.Unlock()
|
|
|
|
// Trigger cleanup manually
|
|
server.cleanupExpiredHandshakes()
|
|
|
|
// Verify only 2 handshakes remain
|
|
server.mu.RLock()
|
|
assert.Len(t, server.handshakes, 2, "Expired handshakes should be cleaned up")
|
|
server.mu.RUnlock()
|
|
}
|
|
|
|
// TestScramConcurrentSameUser verifies multiple concurrent authentications for same user
|
|
func TestScramConcurrentSameUser(t *testing.T) {
|
|
server, username, password, _ := setupScramTest(t)
|
|
defer server.Stop()
|
|
|
|
// Number of concurrent authentication attempts
|
|
numAttempts := 10
|
|
results := make(chan error, numAttempts)
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < numAttempts; i++ {
|
|
wg.Add(1)
|
|
go func(attempt int) {
|
|
defer wg.Done()
|
|
|
|
// Each goroutine performs full authentication
|
|
client := NewScramClient(username, password)
|
|
|
|
// Step 1: Client first
|
|
clientFirst, err := client.StartAuthentication()
|
|
if err != nil {
|
|
results <- err
|
|
return
|
|
}
|
|
|
|
// Step 2: Server first
|
|
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
|
if err != nil {
|
|
results <- err
|
|
return
|
|
}
|
|
|
|
// Step 3: Client final
|
|
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
|
if err != nil {
|
|
results <- err
|
|
return
|
|
}
|
|
|
|
// Step 4: Server final
|
|
serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
|
if err != nil {
|
|
results <- err
|
|
return
|
|
}
|
|
|
|
// Step 5: Client verify
|
|
err = client.VerifyServerFinalMessage(serverFinal)
|
|
results <- err
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(results)
|
|
|
|
// Verify all attempts succeeded
|
|
successCount := 0
|
|
for err := range results {
|
|
if err == nil {
|
|
successCount++
|
|
} else {
|
|
t.Logf("Auth attempt failed: %v", err)
|
|
}
|
|
}
|
|
|
|
assert.Equal(t, numAttempts, successCount,
|
|
"All concurrent authentication attempts should succeed")
|
|
|
|
// Verify no handshakes remain after completion
|
|
server.mu.RLock()
|
|
assert.Empty(t, server.handshakes, "All handshakes should be cleaned up after completion")
|
|
server.mu.RUnlock()
|
|
}
|
|
|
|
// TestScramExplicitTimeout verifies timeout enforcement
|
|
func TestScramExplicitTimeout(t *testing.T) {
|
|
// Save original timeout and set shorter one for testing
|
|
originalTimeout := ScramHandshakeTimeout
|
|
// Note: Can't modify const at runtime, so we test with delay instead
|
|
|
|
server, username, password, _ := setupScramTest(t)
|
|
defer server.Stop()
|
|
|
|
client := NewScramClient(username, password)
|
|
|
|
// Start authentication
|
|
clientFirst, err := client.StartAuthentication()
|
|
require.NoError(t, err)
|
|
|
|
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
|
require.NoError(t, err)
|
|
|
|
// Manually expire the handshake
|
|
server.mu.Lock()
|
|
for nonce := range server.handshakes {
|
|
server.handshakes[nonce].CreatedAt = time.Now().Add(-2 * ScramHandshakeTimeout)
|
|
}
|
|
server.mu.Unlock()
|
|
|
|
// Client processes server message (should work, client tracks own timeout)
|
|
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
|
require.NoError(t, err)
|
|
|
|
// Server should reject due to timeout
|
|
_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
|
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Server should reject expired handshake")
|
|
|
|
// Test client-side timeout
|
|
client2 := NewScramClient(username, password)
|
|
client2.startTime = time.Now().Add(-2 * ScramHandshakeTimeout)
|
|
|
|
_, err = client2.ProcessServerFirstMessage(serverFirst)
|
|
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Client should reject after timeout")
|
|
|
|
_ = originalTimeout // Suppress unused variable warning
|
|
} |