v0.6.0 auth restructuring, scram auth added, more tests added
This commit is contained in:
@ -4,7 +4,7 @@ package sink
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -13,26 +13,25 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
"logwisp/src/internal/scram"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Forwards log entries to a remote TCP endpoint
|
||||
type TCPClientSink struct {
|
||||
input chan core.LogEntry
|
||||
config TCPClientConfig
|
||||
conn net.Conn
|
||||
connMu sync.RWMutex
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
startTime time.Time
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
authenticator *auth.Authenticator
|
||||
input chan core.LogEntry
|
||||
config TCPClientConfig
|
||||
conn net.Conn
|
||||
connMu sync.RWMutex
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
startTime time.Time
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
|
||||
// Reconnection state
|
||||
reconnecting atomic.Bool
|
||||
@ -49,24 +48,22 @@ type TCPClientSink struct {
|
||||
|
||||
// Holds TCP client sink configuration
|
||||
type TCPClientConfig struct {
|
||||
Address string
|
||||
BufferSize int64
|
||||
DialTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
KeepAlive time.Duration
|
||||
Address string `toml:"address"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
DialTimeout time.Duration `toml:"dial_timeout_seconds"`
|
||||
WriteTimeout time.Duration `toml:"write_timeout_seconds"`
|
||||
ReadTimeout time.Duration `toml:"read_timeout_seconds"`
|
||||
KeepAlive time.Duration `toml:"keep_alive_seconds"`
|
||||
|
||||
// Security
|
||||
Username string
|
||||
Password string
|
||||
AuthType string `toml:"auth_type"`
|
||||
Username string `toml:"username"`
|
||||
Password string `toml:"password"`
|
||||
|
||||
// Reconnection settings
|
||||
ReconnectDelay time.Duration
|
||||
MaxReconnectDelay time.Duration
|
||||
ReconnectBackoff float64
|
||||
|
||||
// TLS config
|
||||
TLS *config.TLSConfig
|
||||
ReconnectDelay time.Duration `toml:"reconnect_delay_ms"`
|
||||
MaxReconnectDelay time.Duration `toml:"max_reconnect_delay_seconds"`
|
||||
ReconnectBackoff float64 `toml:"reconnect_backoff"`
|
||||
}
|
||||
|
||||
// Creates a new TCP client sink
|
||||
@ -120,11 +117,25 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
|
||||
if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 {
|
||||
cfg.ReconnectBackoff = backoff
|
||||
}
|
||||
if username, ok := options["username"].(string); ok {
|
||||
cfg.Username = username
|
||||
}
|
||||
if password, ok := options["password"].(string); ok {
|
||||
cfg.Password = password
|
||||
if authType, ok := options["auth_type"].(string); ok {
|
||||
switch authType {
|
||||
case "none":
|
||||
cfg.AuthType = authType
|
||||
case "scram":
|
||||
cfg.AuthType = authType
|
||||
if username, ok := options["username"].(string); ok && username != "" {
|
||||
cfg.Username = username
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid scram username")
|
||||
}
|
||||
if password, ok := options["password"].(string); ok && password != "" {
|
||||
cfg.Password = password
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid scram password")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("tcp_client sink: invalid auth_type '%s' (must be 'none' or 'scram')", authType)
|
||||
}
|
||||
}
|
||||
|
||||
t := &TCPClientSink{
|
||||
@ -304,49 +315,115 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
|
||||
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
|
||||
}
|
||||
|
||||
// Handle authentication if credentials configured
|
||||
if t.config.Username != "" && t.config.Password != "" {
|
||||
// Read auth challenge
|
||||
reader := bufio.NewReader(conn)
|
||||
challenge, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
// SCRAM authentication if credentials configured
|
||||
if t.config.AuthType == "scram" {
|
||||
if err := t.performSCRAMAuth(conn); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("failed to read auth challenge: %w", err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(challenge) == "AUTH_REQUIRED" {
|
||||
// Send credentials
|
||||
creds := t.config.Username + ":" + t.config.Password
|
||||
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
|
||||
authCmd := fmt.Sprintf("AUTH basic %s\n", encodedCreds)
|
||||
|
||||
if _, err := conn.Write([]byte(authCmd)); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("failed to send auth: %w", err)
|
||||
}
|
||||
|
||||
// Read response
|
||||
response, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("failed to read auth response: %w", err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(response) != "AUTH_OK" {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("authentication failed: %s", response)
|
||||
}
|
||||
|
||||
t.logger.Debug("msg", "TCP authentication successful",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address,
|
||||
"username", t.config.Username)
|
||||
return nil, fmt.Errorf("SCRAM authentication failed: %w", err)
|
||||
}
|
||||
t.logger.Debug("msg", "SCRAM authentication completed",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
|
||||
reader := bufio.NewReader(conn)
|
||||
|
||||
// Create SCRAM client
|
||||
scramClient := scram.NewClient(t.config.Username, t.config.Password)
|
||||
|
||||
// Step 1: Send ClientFirst
|
||||
clientFirst, err := scramClient.StartAuthentication()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start SCRAM: %w", err)
|
||||
}
|
||||
|
||||
clientFirstJSON, _ := json.Marshal(clientFirst)
|
||||
msg := fmt.Sprintf("SCRAM-FIRST %s\n", clientFirstJSON)
|
||||
|
||||
if _, err := conn.Write([]byte(msg)); err != nil {
|
||||
return fmt.Errorf("failed to send SCRAM-FIRST: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Receive ServerFirst challenge
|
||||
response, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read SCRAM challenge: %w", err)
|
||||
}
|
||||
|
||||
parts := strings.Fields(strings.TrimSpace(response))
|
||||
if len(parts) != 2 || parts[0] != "SCRAM-CHALLENGE" {
|
||||
return fmt.Errorf("unexpected server response: %s", response)
|
||||
}
|
||||
|
||||
var serverFirst scram.ServerFirst
|
||||
if err := json.Unmarshal([]byte(parts[1]), &serverFirst); err != nil {
|
||||
return fmt.Errorf("failed to parse server challenge: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Process challenge and send proof
|
||||
clientFinal, err := scramClient.ProcessServerFirst(&serverFirst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process challenge: %w", err)
|
||||
}
|
||||
|
||||
clientFinalJSON, _ := json.Marshal(clientFinal)
|
||||
msg = fmt.Sprintf("SCRAM-PROOF %s\n", clientFinalJSON)
|
||||
|
||||
if _, err := conn.Write([]byte(msg)); err != nil {
|
||||
return fmt.Errorf("failed to send SCRAM-PROOF: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Receive ServerFinal
|
||||
response, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read SCRAM result: %w", err)
|
||||
}
|
||||
|
||||
parts = strings.Fields(strings.TrimSpace(response))
|
||||
if len(parts) < 1 {
|
||||
return fmt.Errorf("empty server response")
|
||||
}
|
||||
|
||||
switch parts[0] {
|
||||
case "SCRAM-OK":
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("invalid SCRAM-OK response")
|
||||
}
|
||||
|
||||
var serverFinal scram.ServerFinal
|
||||
if err := json.Unmarshal([]byte(parts[1]), &serverFinal); err != nil {
|
||||
return fmt.Errorf("failed to parse server signature: %w", err)
|
||||
}
|
||||
|
||||
// Verify server signature
|
||||
if err := scramClient.VerifyServerFinal(&serverFinal); err != nil {
|
||||
return fmt.Errorf("server signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
t.logger.Info("msg", "SCRAM authentication successful",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address,
|
||||
"username", t.config.Username,
|
||||
"session_id", serverFinal.SessionID)
|
||||
|
||||
return nil
|
||||
|
||||
case "SCRAM-FAIL":
|
||||
reason := "unknown"
|
||||
if len(parts) > 1 {
|
||||
reason = strings.Join(parts[1:], " ")
|
||||
}
|
||||
return fmt.Errorf("authentication failed: %s", reason)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unexpected response: %s", response)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TCPClientSink) monitorConnection(conn net.Conn) {
|
||||
// Simple connection monitoring by periodic zero-byte reads
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
|
||||
Reference in New Issue
Block a user