v0.2.1 argon2 and scram imporoved
This commit is contained in:
89
scram.go
89
scram.go
@ -17,6 +17,13 @@ import (
|
||||
|
||||
// SCRAM-SHA256 implementation
|
||||
|
||||
const (
|
||||
// ScramHandshakeTimeout defines maximum time for completing SCRAM handshake
|
||||
ScramHandshakeTimeout = 30 * time.Second
|
||||
// ScramCleanupInterval defines how often expired handshakes are cleaned
|
||||
ScramCleanupInterval = 60 * time.Second
|
||||
)
|
||||
|
||||
// Credential stores SCRAM authentication data
|
||||
type Credential struct {
|
||||
Username string
|
||||
@ -166,13 +173,6 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ScramServer handles server-side SCRAM authentication
|
||||
type ScramServer struct {
|
||||
credentials map[string]*Credential
|
||||
handshakes map[string]*HandshakeState
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// HandshakeState tracks ongoing authentication
|
||||
type HandshakeState struct {
|
||||
Username string
|
||||
@ -184,19 +184,59 @@ type HandshakeState struct {
|
||||
verifying int32 // Atomic flag to prevent race during verification
|
||||
}
|
||||
|
||||
// ScramServer handles server-side SCRAM authentication
|
||||
type ScramServer struct {
|
||||
credentials map[string]*Credential
|
||||
handshakes map[string]*HandshakeState
|
||||
mu sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
cleanupStop chan struct{}
|
||||
}
|
||||
|
||||
// NewScramServer creates SCRAM server
|
||||
func NewScramServer() *ScramServer {
|
||||
return &ScramServer{
|
||||
credentials: make(map[string]*Credential),
|
||||
handshakes: make(map[string]*HandshakeState),
|
||||
s := &ScramServer{
|
||||
credentials: make(map[string]*Credential),
|
||||
handshakes: make(map[string]*HandshakeState),
|
||||
cleanupTicker: time.NewTicker(ScramCleanupInterval),
|
||||
cleanupStop: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go s.cleanupLoop()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the server and cleanup goroutine
|
||||
func (s *ScramServer) Stop() {
|
||||
close(s.cleanupStop)
|
||||
s.cleanupTicker.Stop()
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired handshakes
|
||||
func (s *ScramServer) cleanupLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-s.cleanupTicker.C:
|
||||
s.cleanupExpiredHandshakes()
|
||||
case <-s.cleanupStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddCredential registers user credential
|
||||
func (s *ScramServer) AddCredential(cred *Credential) {
|
||||
// cleanupExpiredHandshakes removes handshakes older than timeout
|
||||
func (s *ScramServer) cleanupExpiredHandshakes() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.credentials[cred.Username] = cred
|
||||
|
||||
cutoff := time.Now().Add(-ScramHandshakeTimeout)
|
||||
for nonce, state := range s.handshakes {
|
||||
if state.CreatedAt.Before(cutoff) && atomic.LoadInt32(&state.verifying) == 0 {
|
||||
delete(s.handshakes, nonce)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessClientFirstMessage processes initial auth request
|
||||
@ -237,9 +277,6 @@ func (s *ScramServer) ProcessClientFirstMessage(username, clientNonce string) (S
|
||||
}
|
||||
s.handshakes[fullNonce] = state
|
||||
|
||||
// Cleanup old handshakes
|
||||
s.cleanupHandshakes()
|
||||
|
||||
return ServerFirstMessage{
|
||||
FullNonce: fullNonce,
|
||||
Salt: base64.StdEncoding.EncodeToString(cred.Salt),
|
||||
@ -272,10 +309,11 @@ func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (
|
||||
}()
|
||||
|
||||
// Check timeout
|
||||
if time.Since(state.CreatedAt) > 60*time.Second {
|
||||
if time.Since(state.CreatedAt) > ScramHandshakeTimeout {
|
||||
return ServerFinalMessage{}, ErrSCRAMTimeout
|
||||
}
|
||||
|
||||
// [rest of verification logic unchanged]
|
||||
// Decode client proof
|
||||
clientProofBytes, err := base64.StdEncoding.DecodeString(clientProof)
|
||||
if err != nil {
|
||||
@ -318,6 +356,13 @@ func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddCredential registers user credential
|
||||
func (s *ScramServer) AddCredential(cred *Credential) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.credentials[cred.Username] = cred
|
||||
}
|
||||
|
||||
func (s *ScramServer) cleanupHandshakes() {
|
||||
cutoff := time.Now().Add(-60 * time.Second)
|
||||
for nonce, state := range s.handshakes {
|
||||
@ -365,16 +410,13 @@ func (c *ScramClient) StartAuthentication() (ClientFirstRequest, error) {
|
||||
|
||||
// ProcessServerFirstMessage handles server challenge
|
||||
func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientFinalRequest, error) {
|
||||
// Check timeout (30 seconds)
|
||||
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second {
|
||||
// Check timeout
|
||||
if !c.startTime.IsZero() && time.Since(c.startTime) > ScramHandshakeTimeout {
|
||||
return ClientFinalRequest{}, ErrSCRAMTimeout
|
||||
}
|
||||
|
||||
c.serverFirst = &msg
|
||||
|
||||
// Handle enumeration prevention - server may send fake response
|
||||
// We still process it normally and let verification fail later
|
||||
|
||||
// Decode salt
|
||||
salt, err := base64.StdEncoding.DecodeString(msg.Salt)
|
||||
if err != nil {
|
||||
@ -414,10 +456,11 @@ func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientF
|
||||
// VerifyServerFinalMessage validates server signature
|
||||
func (c *ScramClient) VerifyServerFinalMessage(msg ServerFinalMessage) error {
|
||||
// Check timeout
|
||||
if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second {
|
||||
if !c.startTime.IsZero() && time.Since(c.startTime) > ScramHandshakeTimeout {
|
||||
return ErrSCRAMTimeout
|
||||
}
|
||||
|
||||
// [rest unchanged]
|
||||
if c.authMessage == "" || c.serverKey == nil {
|
||||
return ErrSCRAMInvalidState
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user