502 lines
13 KiB
Go
502 lines
13 KiB
Go
// FILE: auth/scram.go
|
|
package auth
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/argon2"
|
|
)
|
|
|
|
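// SCRAM-style challenge-response authentication (modelled on RFC 5802) using
// HMAC-SHA-256, with Argon2id instead of PBKDF2 to derive the salted password.
// The message formats are custom, so this is not wire-compatible with standard
// SCRAM-SHA-256 (RFC 7677).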
|
|
|
|
// Credential is the server-side record for one SCRAM user. It stores the
// salt and Argon2id parameters the client needs to re-derive its salted
// password, plus the two derived keys the server verifies against; the
// password itself is not part of the record.
type Credential struct {
	Username string // key under which the credential is registered

	// Salt is sent to the client (base64) in the server-first message.
	Salt []byte

	// Argon2id cost parameters: number of passes and memory size (KiB, per
	// golang.org/x/crypto/argon2), followed by the parallelism degree.
	ArgonTime, ArgonMemory uint32
	ArgonThreads           uint8

	StoredKey []byte // SHA256(ClientKey); checks the client proof
	ServerKey []byte // HMAC(SaltedPassword, "Server Key"); signs the server-final message
}
|
|
|
|
// Export converts the credential into a map suitable for writing to a config
// file. Byte fields are encoded as standard base64 strings; the Argon2id
// parameters keep their native Go types (uint32, uint32, uint8), which
// ImportCredential accepts directly.
func (c *Credential) Export() map[string]any {
	enc := base64.StdEncoding.EncodeToString

	out := make(map[string]any, 7)
	out["username"] = c.Username
	out["salt"] = enc(c.Salt)
	out["argon_time"] = c.ArgonTime
	out["argon_memory"] = c.ArgonMemory
	out["argon_threads"] = c.ArgonThreads
	out["stored_key"] = enc(c.StoredKey)
	out["server_key"] = enc(c.ServerKey)
	return out
}
|
|
|
|
// ImportCredential rebuilds a Credential from a map such as the one produced
// by Export, or one decoded from JSON (where numbers arrive as float64).
//
// Required keys:
//   - "username": string
//   - "salt", "stored_key", "server_key": standard-base64 strings
//   - "argon_time", "argon_memory": whole numbers in [1, 2^32-1]
//   - "argon_threads": whole number in [1, 255]
//
// Numeric values may be float64, int, uint32 or uint8. Negative, fractional,
// NaN and out-of-range numbers are rejected (wrapping ErrCredInvalidType)
// instead of being silently truncated or wrapped. Zero Argon2id parameters
// are rejected with ErrSCRAMZeroParams, matching DeriveCredential.
// Missing keys return the corresponding ErrCredMissing* sentinel.
func ImportCredential(data map[string]any) (*Credential, error) {
	username, ok := data["username"].(string)
	if !ok {
		return nil, ErrCredMissingUsername
	}

	saltStr, ok := data["salt"].(string)
	if !ok {
		return nil, ErrCredMissingSalt
	}
	salt, err := base64.StdEncoding.DecodeString(saltStr)
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrCredInvalidSalt, err)
	}

	// readUint reads data[key] as a whole number in [0, limit]. JSON decoding
	// yields float64; maps built in Go (including Export output) carry int,
	// uint32 or uint8. Out-of-range input must fail rather than wrap (for
	// example, argon_threads: 256 must not quietly become 0).
	readUint := func(key string, missing error, limit uint64) (uint64, error) {
		val, present := data[key]
		if !present {
			return 0, missing
		}

		var n uint64
		valid := true
		switch v := val.(type) {
		case float64:
			// Range-check before converting so the conversion is well-defined.
			// The round-trip comparison rejects fractions; NaN fails v >= 0.
			valid = v >= 0 && v <= float64(limit) && v == float64(uint64(v))
			if valid {
				n = uint64(v)
			}
		case int:
			valid = v >= 0
			n = uint64(v)
		case uint32:
			n = uint64(v)
		case uint8:
			n = uint64(v)
		default:
			return 0, fmt.Errorf("%w: %s", ErrCredInvalidType, key)
		}
		if !valid || n > limit {
			return 0, fmt.Errorf("%w: %s must be an integer in [0, %d]", ErrCredInvalidType, key, limit)
		}
		return n, nil
	}

	const (
		maxUint32 = uint64(^uint32(0))
		maxUint8  = uint64(^uint8(0))
	)

	argonTime, err := readUint("argon_time", ErrCredMissingTime, maxUint32)
	if err != nil {
		return nil, err
	}
	argonMemory, err := readUint("argon_memory", ErrCredMissingMemory, maxUint32)
	if err != nil {
		return nil, err
	}
	argonThreads, err := readUint("argon_threads", ErrCredMissingThreads, maxUint8)
	if err != nil {
		return nil, err
	}

	// Reject zero parameters here, as DeriveCredential does. A credential that
	// advertises them would be refused by every client in
	// ProcessServerFirstMessage.
	if argonTime == 0 || argonMemory == 0 || argonThreads == 0 {
		return nil, ErrSCRAMZeroParams
	}

	storedKeyStr, ok := data["stored_key"].(string)
	if !ok {
		return nil, ErrCredMissingStoredKey
	}
	storedKey, err := base64.StdEncoding.DecodeString(storedKeyStr)
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrCredInvalidStoredKey, err)
	}

	serverKeyStr, ok := data["server_key"].(string)
	if !ok {
		return nil, ErrCredMissingServerKey
	}
	serverKey, err := base64.StdEncoding.DecodeString(serverKeyStr)
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrCredInvalidServerKey, err)
	}

	return &Credential{
		Username:     username,
		Salt:         salt,
		ArgonTime:    uint32(argonTime),
		ArgonMemory:  uint32(argonMemory),
		ArgonThreads: uint8(argonThreads),
		StoredKey:    storedKey,
		ServerKey:    serverKey,
	}, nil
}
|
|
|
|
// DeriveCredential computes the server-side SCRAM credential for a password.
//
// The salted password is Argon2id(password, salt, iterations, memory,
// threads) with a DefaultArgonKeyLen-byte output. ClientKey and ServerKey are
// HMAC-SHA-256 of it over "Client Key" and "Server Key", and StoredKey is
// SHA-256(ClientKey). Only StoredKey and ServerKey are kept; the password and
// ClientKey are not stored.
//
// salt must be at least 16 bytes (otherwise ErrSCRAMSaltTooShort), and
// iterations, memory and threads must all be non-zero (otherwise
// ErrSCRAMZeroParams). The salt is copied, so the caller may reuse or
// overwrite its buffer afterwards.
func DeriveCredential(username, password string, salt []byte, iterations, memory uint32, threads uint8) (*Credential, error) {
	if len(salt) < 16 {
		return nil, ErrSCRAMSaltTooShort
	}
	if iterations == 0 || memory == 0 || threads == 0 {
		return nil, ErrSCRAMZeroParams
	}

	// Take a private copy. Storing the caller's slice would let later writes
	// to that buffer silently change the credential's salt.
	saltCopy := append([]byte(nil), salt...)

	saltedPassword := argon2.IDKey([]byte(password), saltCopy, iterations, memory, threads, DefaultArgonKeyLen)

	clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
	serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
	storedKey := sha256.Sum256(clientKey)

	return &Credential{
		Username:     username,
		Salt:         saltCopy,
		ArgonTime:    iterations,
		ArgonMemory:  memory,
		ArgonThreads: threads,
		StoredKey:    storedKey[:],
		ServerKey:    serverKey,
	}, nil
}
|
|
|
|
// ScramServer performs the server side of the SCRAM exchange. It holds the
// registered credentials and the in-flight handshakes. Both maps are guarded
// by mu, and the exported methods take the lock themselves.
type ScramServer struct {
	// credentials maps a username to its registered Credential.
	credentials map[string]*Credential
	// handshakes maps the combined client+server nonce to the pending
	// handshake created by ProcessClientFirstMessage.
	handshakes map[string]*HandshakeState
	// mu guards credentials and handshakes.
	mu sync.RWMutex
}
|
|
|
|
// HandshakeState tracks ongoing authentication
|
|
type HandshakeState struct {
|
|
Username string
|
|
ClientNonce string
|
|
ServerNonce string
|
|
FullNonce string
|
|
Credential *Credential
|
|
CreatedAt time.Time
|
|
verifying int32 // Atomic flag to prevent race during verification
|
|
}
|
|
|
|
// NewScramServer creates SCRAM server
|
|
func NewScramServer() *ScramServer {
|
|
return &ScramServer{
|
|
credentials: make(map[string]*Credential),
|
|
handshakes: make(map[string]*HandshakeState),
|
|
}
|
|
}
|
|
|
|
// AddCredential registers user credential
|
|
func (s *ScramServer) AddCredential(cred *Credential) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.credentials[cred.Username] = cred
|
|
}
|
|
|
|
// ProcessClientFirstMessage handles the client-first message (username and
// client nonce) and returns the server-first challenge.
//
// For a registered user it appends a fresh server nonce to clientNonce. It
// records a HandshakeState keyed by that combined nonce, which is valid for
// 60 s, and returns the user's salt and Argon2id parameters.
//
// For an unknown user it returns a decoy challenge together with
// ErrInvalidCredentials. The decoy has a random 16-byte salt and the
// default Argon2id parameters, so a caller can send it to the client without
// revealing that the user does not exist. No handshake is recorded for a
// decoy.
//
// NOTE(review): enumeration is only partly prevented. A decoy salt changes
// on every request while a real user's salt is stable, and decoy parameters
// are always the defaults. The decoy's final message also fails with
// ErrSCRAMInvalidNonce rather than ErrInvalidCredentials. A deterministic
// decoy (e.g. HMAC of the username under a server secret) would close this.
//
// Returns ErrSCRAMNonceGenFailed if the system random source fails while
// building the decoy salt.
func (s *ScramServer) ProcessClientFirstMessage(username, clientNonce string) (ServerFirstMessage, error) {
	// Nonce generation touches no shared state, so do it before taking the
	// lock to keep the critical section short.
	serverNonce := generateNonce()
	fullNonce := clientNonce + serverNonce

	s.mu.Lock()
	defer s.mu.Unlock()

	cred, exists := s.credentials[username]
	if !exists {
		// Unknown user: answer with a decoy challenge.
		salt := make([]byte, 16)
		if _, err := rand.Read(salt); err != nil {
			// Ignoring this would leave a zero salt that marks the response
			// as a decoy.
			return ServerFirstMessage{}, ErrSCRAMNonceGenFailed
		}
		return ServerFirstMessage{
			FullNonce:    fullNonce,
			Salt:         base64.StdEncoding.EncodeToString(salt),
			ArgonTime:    DefaultArgonTime,
			ArgonMemory:  DefaultArgonMemory,
			ArgonThreads: DefaultArgonThreads,
		}, ErrInvalidCredentials
	}

	s.handshakes[fullNonce] = &HandshakeState{
		Username:    username,
		ClientNonce: clientNonce,
		ServerNonce: serverNonce,
		FullNonce:   fullNonce,
		Credential:  cred,
		CreatedAt:   time.Now(),
	}

	// Opportunistically evict expired handshakes; s.mu is held here.
	s.cleanupHandshakes()

	return ServerFirstMessage{
		FullNonce:    fullNonce,
		Salt:         base64.StdEncoding.EncodeToString(cred.Salt),
		ArgonTime:    cred.ArgonTime,
		ArgonMemory:  cred.ArgonMemory,
		ArgonThreads: cred.ArgonThreads,
	}, nil
}
|
|
|
|
// ProcessClientFinalMessage verifies the client proof for the handshake
// identified by fullNonce. On success it returns the base64 server signature
// HMAC(ServerKey, AuthMessage), which the client uses for mutual
// authentication, along with the authenticated username.
//
// Each handshake is single-use. It is removed from the pending set under the
// write lock before any verification work, so a proof is accepted at most
// once even if the same final message is submitted concurrently. A failed
// attempt also consumes the handshake.
//
// Errors:
//   - ErrSCRAMInvalidNonce: unknown, already used, or evicted handshake
//   - ErrSCRAMTimeout: handshake older than 60 s
//   - ErrSCRAMInvalidProof: clientProof is not valid base64
//   - ErrSCRAMInvalidProofLen: decoded proof has the wrong length
//   - ErrInvalidCredentials: proof does not match StoredKey
func (s *ScramServer) ProcessClientFinalMessage(fullNonce, clientProof string) (ServerFinalMessage, error) {
	// Look up and remove in one write-locked step. Reading under RLock and
	// deleting later would let two callers holding the same state both run
	// verification, which allows a proof to be replayed.
	s.mu.Lock()
	state, exists := s.handshakes[fullNonce]
	delete(s.handshakes, fullNonce)
	s.mu.Unlock()

	if !exists {
		return ServerFinalMessage{}, ErrSCRAMInvalidNonce
	}

	if time.Since(state.CreatedAt) > 60*time.Second {
		return ServerFinalMessage{}, ErrSCRAMTimeout
	}

	clientProofBytes, err := base64.StdEncoding.DecodeString(clientProof)
	if err != nil {
		return ServerFinalMessage{}, ErrSCRAMInvalidProof
	}

	cred := state.Credential

	// Rebuild the AuthMessage exactly as the client did:
	// client-first-bare "," server-first "," client-final-without-proof.
	clientFirstBare := fmt.Sprintf("u=%s,n=%s", state.Username, state.ClientNonce)
	serverFirst := ServerFirstMessage{
		FullNonce:    state.FullNonce,
		Salt:         base64.StdEncoding.EncodeToString(cred.Salt),
		ArgonTime:    cred.ArgonTime,
		ArgonMemory:  cred.ArgonMemory,
		ArgonThreads: cred.ArgonThreads,
	}
	clientFinalBare := fmt.Sprintf("r=%s", fullNonce)
	authMessage := clientFirstBare + "," + serverFirst.Marshal() + "," + clientFinalBare

	// ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage), so XOR-ing
	// the signature back out recovers the client's ClientKey.
	clientSignature := computeHMAC(cred.StoredKey, []byte(authMessage))
	if len(clientProofBytes) != len(clientSignature) {
		return ServerFinalMessage{}, ErrSCRAMInvalidProofLen
	}
	clientKey := xorBytes(clientProofBytes, clientSignature)

	// The proof is valid iff SHA-256(ClientKey) equals the stored key.
	computedStoredKey := sha256.Sum256(clientKey)
	if subtle.ConstantTimeCompare(computedStoredKey[:], cred.StoredKey) != 1 {
		return ServerFinalMessage{}, ErrInvalidCredentials
	}

	serverSignature := computeHMAC(cred.ServerKey, []byte(authMessage))

	return ServerFinalMessage{
		ServerSignature: base64.StdEncoding.EncodeToString(serverSignature),
		Username:        state.Username,
	}, nil
}
|
|
|
|
func (s *ScramServer) cleanupHandshakes() {
|
|
cutoff := time.Now().Add(-60 * time.Second)
|
|
for nonce, state := range s.handshakes {
|
|
if state.CreatedAt.Before(cutoff) && atomic.LoadInt32(&state.verifying) == 0 {
|
|
delete(s.handshakes, nonce)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ScramClient handles client-side SCRAM authentication
|
|
type ScramClient struct {
|
|
Username string
|
|
Password string
|
|
clientNonce string
|
|
serverFirst *ServerFirstMessage
|
|
authMessage string
|
|
serverKey []byte
|
|
startTime time.Time // Track handshake start
|
|
}
|
|
|
|
// NewScramClient creates SCRAM client
|
|
func NewScramClient(username, password string) *ScramClient {
|
|
return &ScramClient{
|
|
Username: username,
|
|
Password: password,
|
|
}
|
|
}
|
|
|
|
// StartAuthentication begins a new handshake and returns the client-first
// message.
//
// State left over from a previous handshake on this client (nonce, challenge,
// transcript, server key) is discarded first. A stale transcript and server
// key could otherwise verify an old server-final message against the new
// handshake. The client nonce is 32 bytes from crypto/rand, base64-encoded,
// and the 30 s handshake timeout starts here.
//
// Returns ErrSCRAMNonceGenFailed if the random source fails.
func (c *ScramClient) StartAuthentication() (ClientFirstRequest, error) {
	c.Reset()
	c.startTime = time.Now()

	nonce := make([]byte, 32)
	if _, err := rand.Read(nonce); err != nil {
		return ClientFirstRequest{}, ErrSCRAMNonceGenFailed
	}
	c.clientNonce = base64.StdEncoding.EncodeToString(nonce)

	return ClientFirstRequest{
		Username:    c.Username,
		ClientNonce: c.clientNonce,
	}, nil
}
|
|
|
|
// ProcessServerFirstMessage consumes the server-first challenge and returns
// the client-final message carrying the client proof.
//
// Checks, in order:
//   - the 30 s handshake timeout (ErrSCRAMTimeout)
//   - StartAuthentication has run (ErrSCRAMInvalidState)
//   - msg.FullNonce begins with this client's nonce and adds a non-empty
//     server part, as SCRAM requires (ErrSCRAMInvalidNonce)
//   - the salt is valid base64 (ErrSCRAMInvalidSalt)
//   - no Argon2id parameter is zero (ErrSCRAMZeroParams)
//
// Client state is updated only after all checks pass.
//
// The salted password uses Argon2id with DefaultArgonKeyLen, the same output
// length DeriveCredential uses, so the derived keys match the stored
// credential. A decoy challenge for an unknown user is processed like a real
// one; the failure surfaces when the server rejects the proof.
//
// NOTE(review): the Argon2id cost parameters are taken from the server
// unchecked, so a hostile server can request an arbitrarily large memory
// cost. Consider enforcing client-side upper bounds.
func (c *ScramClient) ProcessServerFirstMessage(msg ServerFirstMessage) (ClientFinalRequest, error) {
	if !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second {
		return ClientFinalRequest{}, ErrSCRAMTimeout
	}
	if c.clientNonce == "" {
		return ClientFinalRequest{}, ErrSCRAMInvalidState
	}

	// The server must echo our nonce and extend it with its own. Otherwise
	// the challenge is not bound to this handshake (possible replay).
	prefixLen := len(c.clientNonce)
	if len(msg.FullNonce) <= prefixLen || msg.FullNonce[:prefixLen] != c.clientNonce {
		return ClientFinalRequest{}, ErrSCRAMInvalidNonce
	}

	salt, err := base64.StdEncoding.DecodeString(msg.Salt)
	if err != nil {
		return ClientFinalRequest{}, ErrSCRAMInvalidSalt
	}

	if msg.ArgonTime == 0 || msg.ArgonMemory == 0 || msg.ArgonThreads == 0 {
		return ClientFinalRequest{}, ErrSCRAMZeroParams
	}

	// Must use the same key length as DeriveCredential. A different length
	// changes every derived key, and authentication always fails.
	saltedPassword := argon2.IDKey([]byte(c.Password), salt, msg.ArgonTime, msg.ArgonMemory, msg.ArgonThreads, DefaultArgonKeyLen)

	clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
	serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
	storedKey := sha256.Sum256(clientKey)

	// AuthMessage = client-first-bare "," server-first "," client-final-without-proof.
	clientFirstBare := fmt.Sprintf("u=%s,n=%s", c.Username, c.clientNonce)
	clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce)
	authMessage := clientFirstBare + "," + msg.Marshal() + "," + clientFinalBare

	// ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage).
	clientSignature := computeHMAC(storedKey[:], []byte(authMessage))
	clientProof := xorBytes(clientKey, clientSignature)

	// Commit handshake state only now that the challenge has been accepted.
	c.serverFirst = &msg
	c.authMessage = authMessage
	c.serverKey = serverKey

	return ClientFinalRequest{
		FullNonce:   msg.FullNonce,
		ClientProof: base64.StdEncoding.EncodeToString(clientProof),
	}, nil
}
|
|
|
|
// VerifyServerFinalMessage completes mutual authentication. It recomputes
// HMAC(ServerKey, AuthMessage) from the state saved by
// ProcessServerFirstMessage and compares it, in constant time, with the
// base64 signature the server sent.
//
// Errors:
//   - ErrSCRAMTimeout: more than 30 s since StartAuthentication
//   - ErrSCRAMInvalidState: no challenge has been processed
//   - ErrSCRAMServerAuthFailed: the signature is not valid base64 or does
//     not match
func (c *ScramClient) VerifyServerFinalMessage(msg ServerFinalMessage) error {
	switch {
	case !c.startTime.IsZero() && time.Since(c.startTime) > 30*time.Second:
		return ErrSCRAMTimeout
	case c.authMessage == "" || c.serverKey == nil:
		return ErrSCRAMInvalidState
	}

	want := computeHMAC(c.serverKey, []byte(c.authMessage))

	got, err := base64.StdEncoding.DecodeString(msg.ServerSignature)
	if err != nil {
		return ErrSCRAMServerAuthFailed
	}

	if subtle.ConstantTimeCompare(want, got) == 1 {
		return nil
	}
	return ErrSCRAMServerAuthFailed
}
|
|
|
|
// Reset clears client state for retry
|
|
func (c *ScramClient) Reset() {
|
|
c.clientNonce = ""
|
|
c.serverFirst = nil
|
|
c.authMessage = ""
|
|
c.serverKey = nil
|
|
c.startTime = time.Time{}
|
|
}
|
|
|
|
// SCRAM message types.

// ClientFirstRequest is the client-first message. It carries the username
// and the client's base64 nonce, as produced by
// ScramClient.StartAuthentication.
type ClientFirstRequest struct {
	Username    string `json:"username"`
	ClientNonce string `json:"client_nonce"`
}
|
|
|
|
// ServerFirstMessage is the server's challenge. It carries the combined nonce
// (client nonce followed by server nonce), the base64 salt, and the Argon2id
// parameters the client must use to derive its salted password. Its Marshal
// form is part of the signed AuthMessage on both sides.
type ServerFirstMessage struct {
	FullNonce    string `json:"full_nonce"`
	Salt         string `json:"salt"`
	ArgonTime    uint32 `json:"argon_time"`
	ArgonMemory  uint32 `json:"argon_memory"`
	ArgonThreads uint8  `json:"argon_threads"`
}
|
|
|
|
// Marshal renders the challenge in its canonical text form,
// "r=<nonce>,s=<salt>,t=<time>,m=<memory>,p=<threads>". Client and server
// both embed this string in the signed AuthMessage, so the format must stay
// identical on both sides.
func (s ServerFirstMessage) Marshal() string {
	params := fmt.Sprintf("t=%d,m=%d,p=%d", s.ArgonTime, s.ArgonMemory, s.ArgonThreads)
	return "r=" + s.FullNonce + ",s=" + s.Salt + "," + params
}
|
|
|
|
// ClientFinalRequest is the client-final message. It carries the combined
// nonce identifying the server-side handshake and the base64 client proof
// (ClientKey XOR ClientSignature).
type ClientFinalRequest struct {
	FullNonce   string `json:"full_nonce"`
	ClientProof string `json:"client_proof"`
}
|
|
|
|
// ServerFinalMessage is the server's reply to a valid proof. ServerSignature
// is base64 HMAC(ServerKey, AuthMessage), which the client checks for mutual
// authentication. Username names the authenticated user.
type ServerFinalMessage struct {
	ServerSignature string `json:"server_signature"`
	Username        string `json:"username,omitempty"`
}
|
|
|
|
// Helper functions.

// computeHMAC returns the 32-byte HMAC-SHA-256 of message under key.
func computeHMAC(key, message []byte) []byte {
	h := hmac.New(sha256.New, key)
	h.Write(message) // hash.Hash.Write never returns an error
	digest := make([]byte, 0, sha256.Size)
	return h.Sum(digest)
}
|
|
|
|
// xorBytes returns a new slice whose i-th byte is a[i] ^ b[i]. It panics
// with "xor length mismatch" if the lengths differ. The call sites in this
// file pass equal-length inputs (ProcessClientFinalMessage checks the length
// first), so a mismatch indicates a programming error.
func xorBytes(a, b []byte) []byte {
	n := len(a)
	if len(b) != n {
		panic("xor length mismatch")
	}
	out := make([]byte, n)
	copy(out, a)
	for i, v := range b {
		out[i] ^= v
	}
	return out
}
|
|
|
|
// generateNonce returns 32 bytes from crypto/rand as a standard base64
// string (44 characters).
//
// It panics if the system random source fails. Ignoring that error can leave
// a partly or wholly zero, predictable nonce, which defeats the replay
// protection nonces exist for. Failing loudly matches Go 1.24+, where
// crypto/rand.Read itself aborts the program on failure.
func generateNonce() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("auth: crypto/rand failed: %v", err))
	}
	return base64.StdEncoding.EncodeToString(b)
}