v0.2.1 argon2 and scram imporoved

This commit is contained in:
2025-11-04 08:10:16 -05:00
parent 3471030edd
commit aafa680a35
4 changed files with 371 additions and 23 deletions

View File

@ -17,6 +17,13 @@ import (
// SCRAM-SHA256 implementation
const (
// ScramHandshakeTimeout defines maximum time for completing SCRAM handshake
ScramHandshakeTimeout = 30 * time.Second
// ScramCleanupInterval defines how often expired handshakes are cleaned
ScramCleanupInterval = 60 * time.Second
)
// Credential stores SCRAM authentication data
type Credential struct {
Username string
@ -166,13 +173,6 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
}, nil
}
// ScramServer handles server-side SCRAM authentication
type ScramServer struct {
credentials map[string]*Credential
handshakes map[string]*HandshakeState
mu sync.RWMutex
}
// HandshakeState tracks ongoing authentication
type HandshakeState struct {
Username string
@ -184,19 +184,59 @@ type HandshakeState struct {
verifying int32 // Atomic flag to prevent race during verification
}
// ScramServer handles server-side SCRAM authentication
type ScramServer struct {
credentials map[string]*Credential
handshakes map[string]*HandshakeState
mu sync.RWMutex
cleanupTicker *time.Ticker
cleanupStop chan struct{}
}
// NewScramServer creates SCRAM server
func NewScramServer() *ScramServer {
return &ScramServer{
credentials: make(map[string]*Credential),
handshakes: make(map[string]*HandshakeState),
s := &ScramServer{
credentials: make(map[string]*Credential),
handshakes: make(map[string]*HandshakeState),
cleanupTicker: time.NewTicker(ScramCleanupInterval),
cleanupStop: make(chan struct{}),
}
// Start background cleanup goroutine
go s.cleanupLoop()
return s
}
// Stop gracefully shuts down the server and cleanup goroutine
func (s *ScramServer) Stop() {
close(s.cleanupStop)
s.cleanupTicker.Stop()
}
// cleanupLoop runs periodic cleanup of expired handshakes
func (s *ScramServer) cleanupLoop() {
for {
select {
case <-s.cleanupTicker.C:
s.cleanupExpiredHandshakes()
case <-s.cleanupStop:
return
}
}
}
// AddCredential registers user credential
func (s *ScramServer) AddCredential(cred *Credential) {
// cleanupExpiredHandshakes removes handshakes older than timeout
func (s *ScramServer) cleanupExpiredHandshakes() {
s.mu.Lock()
defer s.mu.Unlock()
s.credentials[cred.Username] = cred
cutoff := time.Now().Add(-ScramHandshakeTimeout)
for nonce, state := range s.handshakes {
if state.CreatedAt.Before(cutoff) && atomic.LoadInt32(&state.verifying) == 0 {
delete(s.handshakes, nonce)
}
}
}
// ProcessClientFirstMessage processes initial auth request
@ -237,9 +277,6 @@ func (s *ScramServer) ProcessClientFirstMessage(username, clientNonce string) (S
}
s.handshakes[fullNonce] = state
// Cleanup old handshakes
s.cleanupHandshakes()
return ServerFirstMessage{
FullNonce: fullNonce,
Salt: base64.StdEncoding.EncodeToString(cred.Salt),
@ -272,10 +309,11 @@ func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (
}()
// Check timeout
if time.Since(state.CreatedAt) > 60*time.Second {
if time.Since(state.CreatedAt) > ScramHandshakeTimeout {
return ServerFinalMessage{}, ErrSCRAMTimeout
}
// [rest of verification logic unchanged]
// Decode client proof
clientProofBytes, err := base64.StdEncoding.DecodeString(clientProof)
if err != nil {
@ -318,6 +356,13 @@ func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (
}, nil
}
// AddCredential registers user credential
func (s *ScramServer) AddCredential(cred *Credential) {
s.mu.Lock()
defer s.mu.Unlock()
s.credentials[cred.Username] = cred
}
func (s *ScramServer) cleanupHandshakes() {
cutoff := time.Now().Add(-60 * time.Second)
for nonce, state := range s.handshakes {
@ -365,16 +410,13 @@ func (c *ScramClient) StartAuthentication() (ClientFirstRequest, error) {
// ProcessServerFirstMessage handles server challenge
func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientFinalRequest, error) {
// Check timeout (30 seconds)
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second {
// Check timeout
if !c.startTime.IsZero() && time.Since(c.startTime) > ScramHandshakeTimeout {
return ClientFinalRequest{}, ErrSCRAMTimeout
}
c.serverFirst = &msg
// Handle enumeration prevention - server may send fake response
// We still process it normally and let verification fail later
// Decode salt
salt, err := base64.StdEncoding.DecodeString(msg.Salt)
if err != nil {
@ -414,10 +456,11 @@ func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientF
// VerifyServerFinalMessage validates server signature
func (c *ScramClient) VerifyServerFinalMessage(msg ServerFinalMessage) error {
// Check timeout
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second {
if !c.startTime.IsZero() && time.Since(c.startTime) > ScramHandshakeTimeout {
return ErrSCRAMTimeout
}
// [rest unchanged]
if c.authMessage == "" || c.serverKey == nil {
return ErrSCRAMInvalidState
}