v0.2.1 argon2 and scram imporoved

This commit is contained in:
2025-11-04 08:10:16 -05:00
parent 3471030edd
commit aafa680a35
4 changed files with 371 additions and 23 deletions

View File

@ -142,4 +142,76 @@ func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
} }
return DeriveCredential(username, password, salt, time, memory, threads) return DeriveCredential(username, password, salt, time, memory, threads)
}
// ValidatePHCHashFormat checks if a hash string has a valid and complete
// PHC format for Argon2id. It validates structure, parameters, and encoding,
// but does not verify a password against the hash.
func ValidatePHCHashFormat(phcHash string) error {
parts := strings.Split(phcHash, "$")
if len(parts) != 6 {
return fmt.Errorf("%w: expected 6 parts, got %d", ErrPHCInvalidFormat, len(parts))
}
// Validate empty parts[0] (PHC format starts with $)
if parts[0] != "" {
return fmt.Errorf("%w: hash must start with $", ErrPHCInvalidFormat)
}
// Validate algorithm identifier
if parts[1] != "argon2id" {
return fmt.Errorf("%w: unsupported algorithm %q, expected argon2id", ErrPHCInvalidFormat, parts[1])
}
// Validate version
var version int
n, err := fmt.Sscanf(parts[2], "v=%d", &version)
if err != nil || n != 1 {
return fmt.Errorf("%w: invalid version format", ErrPHCInvalidFormat)
}
if version != argon2.Version {
return fmt.Errorf("%w: unsupported version %d, expected %d", ErrPHCInvalidFormat, version, argon2.Version)
}
// Validate parameters
var memory, time uint32
var threads uint8
n, err = fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
if err != nil || n != 3 {
return fmt.Errorf("%w: failed to parse parameters", ErrPHCInvalidFormat)
}
// Validate parameter ranges
if time == 0 || memory == 0 || threads == 0 {
return fmt.Errorf("%w: parameters must be non-zero", ErrPHCInvalidFormat)
}
if memory > 4*1024*1024 { // 4GB limit
return fmt.Errorf("%w: memory parameter exceeds maximum (4GB)", ErrPHCInvalidFormat)
}
if time > 1000 { // Reasonable upper bound
return fmt.Errorf("%w: time parameter exceeds maximum (1000)", ErrPHCInvalidFormat)
}
if threads > 255 { // uint8 max, but practically much lower
return fmt.Errorf("%w: threads parameter exceeds maximum (255)", ErrPHCInvalidFormat)
}
// Validate salt encoding
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
return fmt.Errorf("%w: %v", ErrPHCInvalidSalt, err)
}
if len(salt) < 8 { // Minimum safe salt length
return fmt.Errorf("%w: salt too short (%d bytes)", ErrPHCInvalidSalt, len(salt))
}
// Validate hash encoding
hash, err := base64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err)
}
if len(hash) < 16 { // Minimum hash length
return fmt.Errorf("%w: hash too short (%d bytes)", ErrPHCInvalidHash, len(hash))
}
return nil
} }

View File

@ -2,6 +2,7 @@
package auth package auth
import ( import (
"encoding/base64"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@ -112,4 +113,63 @@ func TestPHCMigration(t *testing.T) {
// Test with invalid PHC format // Test with invalid PHC format
_, err = MigrateFromPHC(username, password, "$invalid$format") _, err = MigrateFromPHC(username, password, "$invalid$format")
assert.Error(t, err) assert.Error(t, err)
}
func TestValidatePHCHashFormat(t *testing.T) {
// Generate valid hash for testing
validHash, err := HashPassword("testPassword123")
require.NoError(t, err)
// Test valid hash
err = ValidatePHCHashFormat(validHash)
assert.NoError(t, err, "Valid hash should pass validation")
// Test malformed formats
testCases := []struct {
name string
hash string
wantErr error
}{
{"empty", "", ErrPHCInvalidFormat},
{"not PHC format", "plaintext", ErrPHCInvalidFormat},
{"wrong prefix", "argon2id$v=19$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
{"wrong algorithm", "$bcrypt$v=19$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
{"missing version", "$argon2id$$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
{"wrong version", "$argon2id$v=1$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
{"missing params", "$argon2id$v=19$$salt$hash", ErrPHCInvalidFormat},
{"invalid params format", "$argon2id$v=19$invalid$salt$hash", ErrPHCInvalidFormat},
{"zero time", "$argon2id$v=19$m=65536,t=0,p=4$salt$hash", ErrPHCInvalidFormat},
{"zero memory", "$argon2id$v=19$m=0,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
{"zero threads", "$argon2id$v=19$m=65536,t=3,p=0$salt$hash", ErrPHCInvalidFormat},
{"excessive memory", "$argon2id$v=19$m=5000000,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
{"excessive time", "$argon2id$v=19$m=65536,t=2000,p=4$salt$hash", ErrPHCInvalidFormat},
{"invalid salt encoding", "$argon2id$v=19$m=65536,t=3,p=4$!!!invalid!!!$hash", ErrPHCInvalidSalt},
{"invalid hash encoding", "$argon2id$v=19$m=65536,t=3,p=4$" +
base64.RawStdEncoding.EncodeToString([]byte("salt12345678")) + "$!!!invalid!!!", ErrPHCInvalidHash},
{"short salt", "$argon2id$v=19$m=65536,t=3,p=4$" +
base64.RawStdEncoding.EncodeToString([]byte("short")) + "$" +
base64.RawStdEncoding.EncodeToString([]byte("hash1234567890123456")), ErrPHCInvalidSalt},
{"short hash", "$argon2id$v=19$m=65536,t=3,p=4$" +
base64.RawStdEncoding.EncodeToString([]byte("salt12345678")) + "$" +
base64.RawStdEncoding.EncodeToString([]byte("short")), ErrPHCInvalidHash},
{"too few parts", "$argon2id$v=19$m=65536,t=3,p=4", ErrPHCInvalidFormat},
{"too many parts", "$argon2id$v=19$m=65536,t=3,p=4$salt$hash$extra", ErrPHCInvalidFormat},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidatePHCHashFormat(tc.hash)
assert.ErrorIs(t, err, tc.wantErr, "Test case: %s", tc.name)
})
}
// Test that validation doesn't require password
err = ValidatePHCHashFormat(validHash)
assert.NoError(t, err, "Should validate format without password")
// Verify that a validated hash can still be used for verification
err = ValidatePHCHashFormat(validHash)
require.NoError(t, err)
err = VerifyPassword("testPassword123", validHash)
assert.NoError(t, err, "Validated hash should still work for password verification")
} }

View File

@ -17,6 +17,13 @@ import (
// SCRAM-SHA256 implementation // SCRAM-SHA256 implementation
const (
// ScramHandshakeTimeout defines maximum time for completing SCRAM handshake
ScramHandshakeTimeout = 30 * time.Second
// ScramCleanupInterval defines how often expired handshakes are cleaned
ScramCleanupInterval = 60 * time.Second
)
// Credential stores SCRAM authentication data // Credential stores SCRAM authentication data
type Credential struct { type Credential struct {
Username string Username string
@ -166,13 +173,6 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
}, nil }, nil
} }
// ScramServer handles server-side SCRAM authentication
type ScramServer struct {
credentials map[string]*Credential
handshakes map[string]*HandshakeState
mu sync.RWMutex
}
// HandshakeState tracks ongoing authentication // HandshakeState tracks ongoing authentication
type HandshakeState struct { type HandshakeState struct {
Username string Username string
@ -184,19 +184,59 @@ type HandshakeState struct {
verifying int32 // Atomic flag to prevent race during verification verifying int32 // Atomic flag to prevent race during verification
} }
// ScramServer handles server-side SCRAM authentication
type ScramServer struct {
credentials map[string]*Credential
handshakes map[string]*HandshakeState
mu sync.RWMutex
cleanupTicker *time.Ticker
cleanupStop chan struct{}
}
// NewScramServer creates SCRAM server // NewScramServer creates SCRAM server
func NewScramServer() *ScramServer { func NewScramServer() *ScramServer {
return &ScramServer{ s := &ScramServer{
credentials: make(map[string]*Credential), credentials: make(map[string]*Credential),
handshakes: make(map[string]*HandshakeState), handshakes: make(map[string]*HandshakeState),
cleanupTicker: time.NewTicker(ScramCleanupInterval),
cleanupStop: make(chan struct{}),
}
// Start background cleanup goroutine
go s.cleanupLoop()
return s
}
// Stop gracefully shuts down the server and cleanup goroutine
func (s *ScramServer) Stop() {
close(s.cleanupStop)
s.cleanupTicker.Stop()
}
// cleanupLoop runs periodic cleanup of expired handshakes
func (s *ScramServer) cleanupLoop() {
for {
select {
case <-s.cleanupTicker.C:
s.cleanupExpiredHandshakes()
case <-s.cleanupStop:
return
}
} }
} }
// AddCredential registers user credential // cleanupExpiredHandshakes removes handshakes older than timeout
func (s *ScramServer) AddCredential(cred *Credential) { func (s *ScramServer) cleanupExpiredHandshakes() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.credentials[cred.Username] = cred
cutoff := time.Now().Add(-ScramHandshakeTimeout)
for nonce, state := range s.handshakes {
if state.CreatedAt.Before(cutoff) && atomic.LoadInt32(&state.verifying) == 0 {
delete(s.handshakes, nonce)
}
}
} }
// ProcessClientFirstMessage processes initial auth request // ProcessClientFirstMessage processes initial auth request
@ -237,9 +277,6 @@ func (s *ScramServer) ProcessClientFirstMessage(username, clientNonce string) (S
} }
s.handshakes[fullNonce] = state s.handshakes[fullNonce] = state
// Cleanup old handshakes
s.cleanupHandshakes()
return ServerFirstMessage{ return ServerFirstMessage{
FullNonce: fullNonce, FullNonce: fullNonce,
Salt: base64.StdEncoding.EncodeToString(cred.Salt), Salt: base64.StdEncoding.EncodeToString(cred.Salt),
@ -272,10 +309,11 @@ func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (
}() }()
// Check timeout // Check timeout
if time.Since(state.CreatedAt) > 60*time.Second { if time.Since(state.CreatedAt) > ScramHandshakeTimeout {
return ServerFinalMessage{}, ErrSCRAMTimeout return ServerFinalMessage{}, ErrSCRAMTimeout
} }
// [rest of verification logic unchanged]
// Decode client proof // Decode client proof
clientProofBytes, err := base64.StdEncoding.DecodeString(clientProof) clientProofBytes, err := base64.StdEncoding.DecodeString(clientProof)
if err != nil { if err != nil {
@ -318,6 +356,13 @@ func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (
}, nil }, nil
} }
// AddCredential registers user credential
func (s *ScramServer) AddCredential(cred *Credential) {
s.mu.Lock()
defer s.mu.Unlock()
s.credentials[cred.Username] = cred
}
func (s *ScramServer) cleanupHandshakes() { func (s *ScramServer) cleanupHandshakes() {
cutoff := time.Now().Add(-60 * time.Second) cutoff := time.Now().Add(-60 * time.Second)
for nonce, state := range s.handshakes { for nonce, state := range s.handshakes {
@ -365,16 +410,13 @@ func (c *ScramClient) StartAuthentication() (ClientFirstRequest, error) {
// ProcessServerFirstMessage handles server challenge // ProcessServerFirstMessage handles server challenge
func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientFinalRequest, error) { func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientFinalRequest, error) {
// Check timeout (30 seconds) // Check timeout
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second { if !c.startTime.IsZero() && time.Since(c.startTime) > ScramHandshakeTimeout {
return ClientFinalRequest{}, ErrSCRAMTimeout return ClientFinalRequest{}, ErrSCRAMTimeout
} }
c.serverFirst = &msg c.serverFirst = &msg
// Handle enumeration prevention - server may send fake response
// We still process it normally and let verification fail later
// Decode salt // Decode salt
salt, err := base64.StdEncoding.DecodeString(msg.Salt) salt, err := base64.StdEncoding.DecodeString(msg.Salt)
if err != nil { if err != nil {
@ -414,10 +456,11 @@ func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientF
// VerifyServerFinalMessage validates server signature // VerifyServerFinalMessage validates server signature
func (c *ScramClient) VerifyServerFinalMessage(msg ServerFinalMessage) error { func (c *ScramClient) VerifyServerFinalMessage(msg ServerFinalMessage) error {
// Check timeout // Check timeout
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second { if !c.startTime.IsZero() && time.Since(c.startTime) > ScramHandshakeTimeout {
return ErrSCRAMTimeout return ErrSCRAMTimeout
} }
// [rest unchanged]
if c.authMessage == "" || c.serverKey == nil { if c.authMessage == "" || c.serverKey == nil {
return ErrSCRAMInvalidState return ErrSCRAMInvalidState
} }

View File

@ -2,7 +2,10 @@
package auth package auth
import ( import (
"fmt"
"sync"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -143,4 +146,174 @@ func TestScram_CredentialImportExport(t *testing.T) {
assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey) assert.Equal(t, originalCred.ServerKey, importedCred.ServerKey)
t.Log("SCRAM credential import/export successful") t.Log("SCRAM credential import/export successful")
}
// TestScramServerCleanup verifies automatic cleanup of expired handshakes
func TestScramServerCleanup(t *testing.T) {
// Create server with short cleanup interval for testing
server := NewScramServer()
defer server.Stop()
// Add a test credential
cred := &Credential{
Username: "testuser",
Salt: []byte("salt1234567890123456"),
ArgonTime: 1,
ArgonMemory: 64,
ArgonThreads: 1,
StoredKey: []byte("stored_key_placeholder"),
ServerKey: []byte("server_key_placeholder"),
}
server.AddCredential(cred)
// Start multiple handshakes
var nonces []string
for i := 0; i < 5; i++ {
clientNonce := fmt.Sprintf("client-nonce-%d", i)
msg, err := server.ProcessClientFirstMessage("testuser", clientNonce)
require.NoError(t, err)
nonces = append(nonces, msg.FullNonce)
}
// Verify all handshakes exist
server.mu.RLock()
assert.Len(t, server.handshakes, 5)
server.mu.RUnlock()
// Manually set old timestamp for first 3 handshakes
server.mu.Lock()
oldTime := time.Now().Add(-2 * ScramHandshakeTimeout)
count := 0
for nonce := range server.handshakes {
if count < 3 {
server.handshakes[nonce].CreatedAt = oldTime
count++
}
}
server.mu.Unlock()
// Trigger cleanup manually
server.cleanupExpiredHandshakes()
// Verify only 2 handshakes remain
server.mu.RLock()
assert.Len(t, server.handshakes, 2, "Expired handshakes should be cleaned up")
server.mu.RUnlock()
}
// TestScramConcurrentSameUser verifies multiple concurrent authentications for same user
func TestScramConcurrentSameUser(t *testing.T) {
server, username, password, _ := setupScramTest(t)
defer server.Stop()
// Number of concurrent authentication attempts
numAttempts := 10
results := make(chan error, numAttempts)
var wg sync.WaitGroup
for i := 0; i < numAttempts; i++ {
wg.Add(1)
go func(attempt int) {
defer wg.Done()
// Each goroutine performs full authentication
client := NewScramClient(username, password)
// Step 1: Client first
clientFirst, err := client.StartAuthentication()
if err != nil {
results <- err
return
}
// Step 2: Server first
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
if err != nil {
results <- err
return
}
// Step 3: Client final
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
if err != nil {
results <- err
return
}
// Step 4: Server final
serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
if err != nil {
results <- err
return
}
// Step 5: Client verify
err = client.VerifyServerFinalMessage(serverFinal)
results <- err
}(i)
}
wg.Wait()
close(results)
// Verify all attempts succeeded
successCount := 0
for err := range results {
if err == nil {
successCount++
} else {
t.Logf("Auth attempt failed: %v", err)
}
}
assert.Equal(t, numAttempts, successCount,
"All concurrent authentication attempts should succeed")
// Verify no handshakes remain after completion
server.mu.RLock()
assert.Empty(t, server.handshakes, "All handshakes should be cleaned up after completion")
server.mu.RUnlock()
}
// TestScramExplicitTimeout verifies timeout enforcement
func TestScramExplicitTimeout(t *testing.T) {
// Save original timeout and set shorter one for testing
originalTimeout := ScramHandshakeTimeout
// Note: Can't modify const at runtime, so we test with delay instead
server, username, password, _ := setupScramTest(t)
defer server.Stop()
client := NewScramClient(username, password)
// Start authentication
clientFirst, err := client.StartAuthentication()
require.NoError(t, err)
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
require.NoError(t, err)
// Manually expire the handshake
server.mu.Lock()
for nonce := range server.handshakes {
server.handshakes[nonce].CreatedAt = time.Now().Add(-2 * ScramHandshakeTimeout)
}
server.mu.Unlock()
// Client processes server message (should work, client tracks own timeout)
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
require.NoError(t, err)
// Server should reject due to timeout
_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Server should reject expired handshake")
// Test client-side timeout
client2 := NewScramClient(username, password)
client2.startTime = time.Now().Add(-2 * ScramHandshakeTimeout)
_, err = client2.ProcessServerFirstMessage(serverFirst)
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Client should reject after timeout")
_ = originalTimeout // Suppress unused variable warning
} }