v0.6.0 auth restructuring, scram auth added, more tests added

This commit is contained in:
2025-10-02 17:16:43 -04:00
parent 3047e556f7
commit 490fb777ab
37 changed files with 2283 additions and 888 deletions

View File

@ -4,7 +4,7 @@ package sink
import (
"bufio"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net"
@ -13,26 +13,25 @@ import (
"sync/atomic"
"time"
"logwisp/src/internal/auth"
"logwisp/src/internal/config"
"logwisp/src/internal/core"
"logwisp/src/internal/format"
"logwisp/src/internal/scram"
"github.com/lixenwraith/log"
)
// Forwards log entries to a remote TCP endpoint
type TCPClientSink struct {
input chan core.LogEntry
config TCPClientConfig
conn net.Conn
connMu sync.RWMutex
done chan struct{}
wg sync.WaitGroup
startTime time.Time
logger *log.Logger
formatter format.Formatter
authenticator *auth.Authenticator
input chan core.LogEntry
config TCPClientConfig
conn net.Conn
connMu sync.RWMutex
done chan struct{}
wg sync.WaitGroup
startTime time.Time
logger *log.Logger
formatter format.Formatter
// Reconnection state
reconnecting atomic.Bool
@ -49,24 +48,22 @@ type TCPClientSink struct {
// Holds TCP client sink configuration
type TCPClientConfig struct {
Address string
BufferSize int64
DialTimeout time.Duration
WriteTimeout time.Duration
ReadTimeout time.Duration
KeepAlive time.Duration
Address string `toml:"address"`
BufferSize int64 `toml:"buffer_size"`
DialTimeout time.Duration `toml:"dial_timeout_seconds"`
WriteTimeout time.Duration `toml:"write_timeout_seconds"`
ReadTimeout time.Duration `toml:"read_timeout_seconds"`
KeepAlive time.Duration `toml:"keep_alive_seconds"`
// Security
Username string
Password string
AuthType string `toml:"auth_type"`
Username string `toml:"username"`
Password string `toml:"password"`
// Reconnection settings
ReconnectDelay time.Duration
MaxReconnectDelay time.Duration
ReconnectBackoff float64
// TLS config
TLS *config.TLSConfig
ReconnectDelay time.Duration `toml:"reconnect_delay_ms"`
MaxReconnectDelay time.Duration `toml:"max_reconnect_delay_seconds"`
ReconnectBackoff float64 `toml:"reconnect_backoff"`
}
// Creates a new TCP client sink
@ -120,11 +117,25 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 {
cfg.ReconnectBackoff = backoff
}
if username, ok := options["username"].(string); ok {
cfg.Username = username
}
if password, ok := options["password"].(string); ok {
cfg.Password = password
if authType, ok := options["auth_type"].(string); ok {
switch authType {
case "none":
cfg.AuthType = authType
case "scram":
cfg.AuthType = authType
if username, ok := options["username"].(string); ok && username != "" {
cfg.Username = username
} else {
return nil, fmt.Errorf("invalid scram username")
}
if password, ok := options["password"].(string); ok && password != "" {
cfg.Password = password
} else {
return nil, fmt.Errorf("invalid scram password")
}
default:
return nil, fmt.Errorf("tcp_client sink: invalid auth_type '%s' (must be 'none' or 'scram')", authType)
}
}
t := &TCPClientSink{
@ -304,49 +315,115 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
}
// Handle authentication if credentials configured
if t.config.Username != "" && t.config.Password != "" {
// Read auth challenge
reader := bufio.NewReader(conn)
challenge, err := reader.ReadString('\n')
if err != nil {
// SCRAM authentication if credentials configured
if t.config.AuthType == "scram" {
if err := t.performSCRAMAuth(conn); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to read auth challenge: %w", err)
}
if strings.TrimSpace(challenge) == "AUTH_REQUIRED" {
// Send credentials
creds := t.config.Username + ":" + t.config.Password
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
authCmd := fmt.Sprintf("AUTH basic %s\n", encodedCreds)
if _, err := conn.Write([]byte(authCmd)); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to send auth: %w", err)
}
// Read response
response, err := reader.ReadString('\n')
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to read auth response: %w", err)
}
if strings.TrimSpace(response) != "AUTH_OK" {
conn.Close()
return nil, fmt.Errorf("authentication failed: %s", response)
}
t.logger.Debug("msg", "TCP authentication successful",
"component", "tcp_client_sink",
"address", t.config.Address,
"username", t.config.Username)
return nil, fmt.Errorf("SCRAM authentication failed: %w", err)
}
t.logger.Debug("msg", "SCRAM authentication completed",
"component", "tcp_client_sink",
"address", t.config.Address)
}
return conn, nil
}
func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
reader := bufio.NewReader(conn)
// Create SCRAM client
scramClient := scram.NewClient(t.config.Username, t.config.Password)
// Step 1: Send ClientFirst
clientFirst, err := scramClient.StartAuthentication()
if err != nil {
return fmt.Errorf("failed to start SCRAM: %w", err)
}
clientFirstJSON, _ := json.Marshal(clientFirst)
msg := fmt.Sprintf("SCRAM-FIRST %s\n", clientFirstJSON)
if _, err := conn.Write([]byte(msg)); err != nil {
return fmt.Errorf("failed to send SCRAM-FIRST: %w", err)
}
// Step 2: Receive ServerFirst challenge
response, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read SCRAM challenge: %w", err)
}
parts := strings.Fields(strings.TrimSpace(response))
if len(parts) != 2 || parts[0] != "SCRAM-CHALLENGE" {
return fmt.Errorf("unexpected server response: %s", response)
}
var serverFirst scram.ServerFirst
if err := json.Unmarshal([]byte(parts[1]), &serverFirst); err != nil {
return fmt.Errorf("failed to parse server challenge: %w", err)
}
// Step 3: Process challenge and send proof
clientFinal, err := scramClient.ProcessServerFirst(&serverFirst)
if err != nil {
return fmt.Errorf("failed to process challenge: %w", err)
}
clientFinalJSON, _ := json.Marshal(clientFinal)
msg = fmt.Sprintf("SCRAM-PROOF %s\n", clientFinalJSON)
if _, err := conn.Write([]byte(msg)); err != nil {
return fmt.Errorf("failed to send SCRAM-PROOF: %w", err)
}
// Step 4: Receive ServerFinal
response, err = reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read SCRAM result: %w", err)
}
parts = strings.Fields(strings.TrimSpace(response))
if len(parts) < 1 {
return fmt.Errorf("empty server response")
}
switch parts[0] {
case "SCRAM-OK":
if len(parts) != 2 {
return fmt.Errorf("invalid SCRAM-OK response")
}
var serverFinal scram.ServerFinal
if err := json.Unmarshal([]byte(parts[1]), &serverFinal); err != nil {
return fmt.Errorf("failed to parse server signature: %w", err)
}
// Verify server signature
if err := scramClient.VerifyServerFinal(&serverFinal); err != nil {
return fmt.Errorf("server signature verification failed: %w", err)
}
t.logger.Info("msg", "SCRAM authentication successful",
"component", "tcp_client_sink",
"address", t.config.Address,
"username", t.config.Username,
"session_id", serverFinal.SessionID)
return nil
case "SCRAM-FAIL":
reason := "unknown"
if len(parts) > 1 {
reason = strings.Join(parts[1:], " ")
}
return fmt.Errorf("authentication failed: %s", reason)
default:
return fmt.Errorf("unexpected response: %s", response)
}
}
func (t *TCPClientSink) monitorConnection(conn net.Conn) {
// Simple connection monitoring by periodic zero-byte reads
ticker := time.NewTicker(5 * time.Second)