v0.2.1 argon2 and scram imporoved
This commit is contained in:
72
argon2.go
72
argon2.go
@ -143,3 +143,75 @@ func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
|
||||
|
||||
return DeriveCredential(username, password, salt, time, memory, threads)
|
||||
}
|
||||
|
||||
// ValidatePHCHashFormat checks if a hash string has a valid and complete
|
||||
// PHC format for Argon2id. It validates structure, parameters, and encoding,
|
||||
// but does not verify a password against the hash.
|
||||
func ValidatePHCHashFormat(phcHash string) error {
|
||||
parts := strings.Split(phcHash, "$")
|
||||
if len(parts) != 6 {
|
||||
return fmt.Errorf("%w: expected 6 parts, got %d", ErrPHCInvalidFormat, len(parts))
|
||||
}
|
||||
|
||||
// Validate empty parts[0] (PHC format starts with $)
|
||||
if parts[0] != "" {
|
||||
return fmt.Errorf("%w: hash must start with $", ErrPHCInvalidFormat)
|
||||
}
|
||||
|
||||
// Validate algorithm identifier
|
||||
if parts[1] != "argon2id" {
|
||||
return fmt.Errorf("%w: unsupported algorithm %q, expected argon2id", ErrPHCInvalidFormat, parts[1])
|
||||
}
|
||||
|
||||
// Validate version
|
||||
var version int
|
||||
n, err := fmt.Sscanf(parts[2], "v=%d", &version)
|
||||
if err != nil || n != 1 {
|
||||
return fmt.Errorf("%w: invalid version format", ErrPHCInvalidFormat)
|
||||
}
|
||||
if version != argon2.Version {
|
||||
return fmt.Errorf("%w: unsupported version %d, expected %d", ErrPHCInvalidFormat, version, argon2.Version)
|
||||
}
|
||||
|
||||
// Validate parameters
|
||||
var memory, time uint32
|
||||
var threads uint8
|
||||
n, err = fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
|
||||
if err != nil || n != 3 {
|
||||
return fmt.Errorf("%w: failed to parse parameters", ErrPHCInvalidFormat)
|
||||
}
|
||||
|
||||
// Validate parameter ranges
|
||||
if time == 0 || memory == 0 || threads == 0 {
|
||||
return fmt.Errorf("%w: parameters must be non-zero", ErrPHCInvalidFormat)
|
||||
}
|
||||
if memory > 4*1024*1024 { // 4GB limit
|
||||
return fmt.Errorf("%w: memory parameter exceeds maximum (4GB)", ErrPHCInvalidFormat)
|
||||
}
|
||||
if time > 1000 { // Reasonable upper bound
|
||||
return fmt.Errorf("%w: time parameter exceeds maximum (1000)", ErrPHCInvalidFormat)
|
||||
}
|
||||
if threads > 255 { // uint8 max, but practically much lower
|
||||
return fmt.Errorf("%w: threads parameter exceeds maximum (255)", ErrPHCInvalidFormat)
|
||||
}
|
||||
|
||||
// Validate salt encoding
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrPHCInvalidSalt, err)
|
||||
}
|
||||
if len(salt) < 8 { // Minimum safe salt length
|
||||
return fmt.Errorf("%w: salt too short (%d bytes)", ErrPHCInvalidSalt, len(salt))
|
||||
}
|
||||
|
||||
// Validate hash encoding
|
||||
hash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrPHCInvalidHash, err)
|
||||
}
|
||||
if len(hash) < 16 { // Minimum hash length
|
||||
return fmt.Errorf("%w: hash too short (%d bytes)", ErrPHCInvalidHash, len(hash))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -2,6 +2,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@ -113,3 +114,62 @@ func TestPHCMigration(t *testing.T) {
|
||||
_, err = MigrateFromPHC(username, password, "$invalid$format")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidatePHCHashFormat(t *testing.T) {
|
||||
// Generate valid hash for testing
|
||||
validHash, err := HashPassword("testPassword123")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test valid hash
|
||||
err = ValidatePHCHashFormat(validHash)
|
||||
assert.NoError(t, err, "Valid hash should pass validation")
|
||||
|
||||
// Test malformed formats
|
||||
testCases := []struct {
|
||||
name string
|
||||
hash string
|
||||
wantErr error
|
||||
}{
|
||||
{"empty", "", ErrPHCInvalidFormat},
|
||||
{"not PHC format", "plaintext", ErrPHCInvalidFormat},
|
||||
{"wrong prefix", "argon2id$v=19$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
|
||||
{"wrong algorithm", "$bcrypt$v=19$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
|
||||
{"missing version", "$argon2id$$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
|
||||
{"wrong version", "$argon2id$v=1$m=65536,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
|
||||
{"missing params", "$argon2id$v=19$$salt$hash", ErrPHCInvalidFormat},
|
||||
{"invalid params format", "$argon2id$v=19$invalid$salt$hash", ErrPHCInvalidFormat},
|
||||
{"zero time", "$argon2id$v=19$m=65536,t=0,p=4$salt$hash", ErrPHCInvalidFormat},
|
||||
{"zero memory", "$argon2id$v=19$m=0,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
|
||||
{"zero threads", "$argon2id$v=19$m=65536,t=3,p=0$salt$hash", ErrPHCInvalidFormat},
|
||||
{"excessive memory", "$argon2id$v=19$m=5000000,t=3,p=4$salt$hash", ErrPHCInvalidFormat},
|
||||
{"excessive time", "$argon2id$v=19$m=65536,t=2000,p=4$salt$hash", ErrPHCInvalidFormat},
|
||||
{"invalid salt encoding", "$argon2id$v=19$m=65536,t=3,p=4$!!!invalid!!!$hash", ErrPHCInvalidSalt},
|
||||
{"invalid hash encoding", "$argon2id$v=19$m=65536,t=3,p=4$" +
|
||||
base64.RawStdEncoding.EncodeToString([]byte("salt12345678")) + "$!!!invalid!!!", ErrPHCInvalidHash},
|
||||
{"short salt", "$argon2id$v=19$m=65536,t=3,p=4$" +
|
||||
base64.RawStdEncoding.EncodeToString([]byte("short")) + "$" +
|
||||
base64.RawStdEncoding.EncodeToString([]byte("hash1234567890123456")), ErrPHCInvalidSalt},
|
||||
{"short hash", "$argon2id$v=19$m=65536,t=3,p=4$" +
|
||||
base64.RawStdEncoding.EncodeToString([]byte("salt12345678")) + "$" +
|
||||
base64.RawStdEncoding.EncodeToString([]byte("short")), ErrPHCInvalidHash},
|
||||
{"too few parts", "$argon2id$v=19$m=65536,t=3,p=4", ErrPHCInvalidFormat},
|
||||
{"too many parts", "$argon2id$v=19$m=65536,t=3,p=4$salt$hash$extra", ErrPHCInvalidFormat},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidatePHCHashFormat(tc.hash)
|
||||
assert.ErrorIs(t, err, tc.wantErr, "Test case: %s", tc.name)
|
||||
})
|
||||
}
|
||||
|
||||
// Test that validation doesn't require password
|
||||
err = ValidatePHCHashFormat(validHash)
|
||||
assert.NoError(t, err, "Should validate format without password")
|
||||
|
||||
// Verify that a validated hash can still be used for verification
|
||||
err = ValidatePHCHashFormat(validHash)
|
||||
require.NoError(t, err)
|
||||
err = VerifyPassword("testPassword123", validHash)
|
||||
assert.NoError(t, err, "Validated hash should still work for password verification")
|
||||
}
|
||||
85
scram.go
85
scram.go
@ -17,6 +17,13 @@ import (
|
||||
|
||||
// SCRAM-SHA256 implementation
|
||||
|
||||
const (
|
||||
// ScramHandshakeTimeout defines maximum time for completing SCRAM handshake
|
||||
ScramHandshakeTimeout = 30 * time.Second
|
||||
// ScramCleanupInterval defines how often expired handshakes are cleaned
|
||||
ScramCleanupInterval = 60 * time.Second
|
||||
)
|
||||
|
||||
// Credential stores SCRAM authentication data
|
||||
type Credential struct {
|
||||
Username string
|
||||
@ -166,13 +173,6 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ScramServer handles server-side SCRAM authentication
|
||||
type ScramServer struct {
|
||||
credentials map[string]*Credential
|
||||
handshakes map[string]*HandshakeState
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// HandshakeState tracks ongoing authentication
|
||||
type HandshakeState struct {
|
||||
Username string
|
||||
@ -184,19 +184,59 @@ type HandshakeState struct {
|
||||
verifying int32 // Atomic flag to prevent race during verification
|
||||
}
|
||||
|
||||
// ScramServer handles server-side SCRAM authentication
|
||||
type ScramServer struct {
|
||||
credentials map[string]*Credential
|
||||
handshakes map[string]*HandshakeState
|
||||
mu sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
cleanupStop chan struct{}
|
||||
}
|
||||
|
||||
// NewScramServer creates SCRAM server
|
||||
func NewScramServer() *ScramServer {
|
||||
return &ScramServer{
|
||||
s := &ScramServer{
|
||||
credentials: make(map[string]*Credential),
|
||||
handshakes: make(map[string]*HandshakeState),
|
||||
cleanupTicker: time.NewTicker(ScramCleanupInterval),
|
||||
cleanupStop: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go s.cleanupLoop()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the server and cleanup goroutine
|
||||
func (s *ScramServer) Stop() {
|
||||
close(s.cleanupStop)
|
||||
s.cleanupTicker.Stop()
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired handshakes
|
||||
func (s *ScramServer) cleanupLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-s.cleanupTicker.C:
|
||||
s.cleanupExpiredHandshakes()
|
||||
case <-s.cleanupStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddCredential registers user credential
|
||||
func (s *ScramServer) AddCredential(cred *Credential) {
|
||||
// cleanupExpiredHandshakes removes handshakes older than timeout
|
||||
func (s *ScramServer) cleanupExpiredHandshakes() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.credentials[cred.Username] = cred
|
||||
|
||||
cutoff := time.Now().Add(-ScramHandshakeTimeout)
|
||||
for nonce, state := range s.handshakes {
|
||||
if state.CreatedAt.Before(cutoff) && atomic.LoadInt32(&state.verifying) == 0 {
|
||||
delete(s.handshakes, nonce)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessClientFirstMessage processes initial auth request
|
||||
@ -237,9 +277,6 @@ func (s *ScramServer) ProcessClientFirstMessage(username, clientNonce string) (S
|
||||
}
|
||||
s.handshakes[fullNonce] = state
|
||||
|
||||
// Cleanup old handshakes
|
||||
s.cleanupHandshakes()
|
||||
|
||||
return ServerFirstMessage{
|
||||
FullNonce: fullNonce,
|
||||
Salt: base64.StdEncoding.EncodeToString(cred.Salt),
|
||||
@ -272,10 +309,11 @@ func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (
|
||||
}()
|
||||
|
||||
// Check timeout
|
||||
if time.Since(state.CreatedAt) > 60*time.Second {
|
||||
if time.Since(state.CreatedAt) > ScramHandshakeTimeout {
|
||||
return ServerFinalMessage{}, ErrSCRAMTimeout
|
||||
}
|
||||
|
||||
// [rest of verification logic unchanged]
|
||||
// Decode client proof
|
||||
clientProofBytes, err := base64.StdEncoding.DecodeString(clientProof)
|
||||
if err != nil {
|
||||
@ -318,6 +356,13 @@ func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddCredential registers user credential
|
||||
func (s *ScramServer) AddCredential(cred *Credential) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.credentials[cred.Username] = cred
|
||||
}
|
||||
|
||||
func (s *ScramServer) cleanupHandshakes() {
|
||||
cutoff := time.Now().Add(-60 * time.Second)
|
||||
for nonce, state := range s.handshakes {
|
||||
@ -365,16 +410,13 @@ func (c *ScramClient) StartAuthentication() (ClientFirstRequest, error) {
|
||||
|
||||
// ProcessServerFirstMessage handles server challenge
|
||||
func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientFinalRequest, error) {
|
||||
// Check timeout (30 seconds)
|
||||
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second {
|
||||
// Check timeout
|
||||
if !c.startTime.IsZero() && time.Since(c.startTime) > ScramHandshakeTimeout {
|
||||
return ClientFinalRequest{}, ErrSCRAMTimeout
|
||||
}
|
||||
|
||||
c.serverFirst = &msg
|
||||
|
||||
// Handle enumeration prevention - server may send fake response
|
||||
// We still process it normally and let verification fail later
|
||||
|
||||
// Decode salt
|
||||
salt, err := base64.StdEncoding.DecodeString(msg.Salt)
|
||||
if err != nil {
|
||||
@ -414,10 +456,11 @@ func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientF
|
||||
// VerifyServerFinalMessage validates server signature
|
||||
func (c *ScramClient) VerifyServerFinalMessage(msg ServerFinalMessage) error {
|
||||
// Check timeout
|
||||
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second {
|
||||
if !c.startTime.IsZero() && time.Since(c.startTime) > ScramHandshakeTimeout {
|
||||
return ErrSCRAMTimeout
|
||||
}
|
||||
|
||||
// [rest unchanged]
|
||||
if c.authMessage == "" || c.serverKey == nil {
|
||||
return ErrSCRAMInvalidState
|
||||
}
|
||||
|
||||
173
scram_test.go
173
scram_test.go
@ -2,7 +2,10 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -144,3 +147,173 @@ func TestScram_CredentialImportExport(t *testing.T) {
|
||||
|
||||
t.Log("SCRAM credential import/export successful")
|
||||
}
|
||||
|
||||
// TestScramServerCleanup verifies automatic cleanup of expired handshakes
|
||||
func TestScramServerCleanup(t *testing.T) {
|
||||
// Create server with short cleanup interval for testing
|
||||
server := NewScramServer()
|
||||
defer server.Stop()
|
||||
|
||||
// Add a test credential
|
||||
cred := &Credential{
|
||||
Username: "testuser",
|
||||
Salt: []byte("salt1234567890123456"),
|
||||
ArgonTime: 1,
|
||||
ArgonMemory: 64,
|
||||
ArgonThreads: 1,
|
||||
StoredKey: []byte("stored_key_placeholder"),
|
||||
ServerKey: []byte("server_key_placeholder"),
|
||||
}
|
||||
server.AddCredential(cred)
|
||||
|
||||
// Start multiple handshakes
|
||||
var nonces []string
|
||||
for i := 0; i < 5; i++ {
|
||||
clientNonce := fmt.Sprintf("client-nonce-%d", i)
|
||||
msg, err := server.ProcessClientFirstMessage("testuser", clientNonce)
|
||||
require.NoError(t, err)
|
||||
nonces = append(nonces, msg.FullNonce)
|
||||
}
|
||||
|
||||
// Verify all handshakes exist
|
||||
server.mu.RLock()
|
||||
assert.Len(t, server.handshakes, 5)
|
||||
server.mu.RUnlock()
|
||||
|
||||
// Manually set old timestamp for first 3 handshakes
|
||||
server.mu.Lock()
|
||||
oldTime := time.Now().Add(-2 * ScramHandshakeTimeout)
|
||||
count := 0
|
||||
for nonce := range server.handshakes {
|
||||
if count < 3 {
|
||||
server.handshakes[nonce].CreatedAt = oldTime
|
||||
count++
|
||||
}
|
||||
}
|
||||
server.mu.Unlock()
|
||||
|
||||
// Trigger cleanup manually
|
||||
server.cleanupExpiredHandshakes()
|
||||
|
||||
// Verify only 2 handshakes remain
|
||||
server.mu.RLock()
|
||||
assert.Len(t, server.handshakes, 2, "Expired handshakes should be cleaned up")
|
||||
server.mu.RUnlock()
|
||||
}
|
||||
|
||||
// TestScramConcurrentSameUser verifies multiple concurrent authentications for same user
|
||||
func TestScramConcurrentSameUser(t *testing.T) {
|
||||
server, username, password, _ := setupScramTest(t)
|
||||
defer server.Stop()
|
||||
|
||||
// Number of concurrent authentication attempts
|
||||
numAttempts := 10
|
||||
results := make(chan error, numAttempts)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < numAttempts; i++ {
|
||||
wg.Add(1)
|
||||
go func(attempt int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Each goroutine performs full authentication
|
||||
client := NewScramClient(username, password)
|
||||
|
||||
// Step 1: Client first
|
||||
clientFirst, err := client.StartAuthentication()
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Step 2: Server first
|
||||
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Step 3: Client final
|
||||
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Step 4: Server final
|
||||
serverFinal, err := server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Step 5: Client verify
|
||||
err = client.VerifyServerFinalMessage(serverFinal)
|
||||
results <- err
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
// Verify all attempts succeeded
|
||||
successCount := 0
|
||||
for err := range results {
|
||||
if err == nil {
|
||||
successCount++
|
||||
} else {
|
||||
t.Logf("Auth attempt failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, numAttempts, successCount,
|
||||
"All concurrent authentication attempts should succeed")
|
||||
|
||||
// Verify no handshakes remain after completion
|
||||
server.mu.RLock()
|
||||
assert.Empty(t, server.handshakes, "All handshakes should be cleaned up after completion")
|
||||
server.mu.RUnlock()
|
||||
}
|
||||
|
||||
// TestScramExplicitTimeout verifies timeout enforcement
|
||||
func TestScramExplicitTimeout(t *testing.T) {
|
||||
// Save original timeout and set shorter one for testing
|
||||
originalTimeout := ScramHandshakeTimeout
|
||||
// Note: Can't modify const at runtime, so we test with delay instead
|
||||
|
||||
server, username, password, _ := setupScramTest(t)
|
||||
defer server.Stop()
|
||||
|
||||
client := NewScramClient(username, password)
|
||||
|
||||
// Start authentication
|
||||
clientFirst, err := client.StartAuthentication()
|
||||
require.NoError(t, err)
|
||||
|
||||
serverFirst, err := server.ProcessClientFirstMessage(clientFirst.Username, clientFirst.ClientNonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually expire the handshake
|
||||
server.mu.Lock()
|
||||
for nonce := range server.handshakes {
|
||||
server.handshakes[nonce].CreatedAt = time.Now().Add(-2 * ScramHandshakeTimeout)
|
||||
}
|
||||
server.mu.Unlock()
|
||||
|
||||
// Client processes server message (should work, client tracks own timeout)
|
||||
clientFinal, err := client.ProcessServerFirstMessage(serverFirst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server should reject due to timeout
|
||||
_, err = server.ProcessClientFinalMessage(clientFinal.FullNonce, clientFinal.ClientProof)
|
||||
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Server should reject expired handshake")
|
||||
|
||||
// Test client-side timeout
|
||||
client2 := NewScramClient(username, password)
|
||||
client2.startTime = time.Now().Add(-2 * ScramHandshakeTimeout)
|
||||
|
||||
_, err = client2.ProcessServerFirstMessage(serverFirst)
|
||||
assert.ErrorIs(t, err, ErrSCRAMTimeout, "Client should reject after timeout")
|
||||
|
||||
_ = originalTimeout // Suppress unused variable warning
|
||||
}
|
||||
Reference in New Issue
Block a user