Files
logwisp/src/internal/sink/tcp_client.go

537 lines
14 KiB
Go

// FILE: logwisp/src/internal/sink/tcp_client.go
package sink
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"logwisp/src/internal/config"
"logwisp/src/internal/core"
"logwisp/src/internal/format"
tlspkg "logwisp/src/internal/tls"
"github.com/lixenwraith/log"
)
// TCPClientSink forwards log entries to a remote TCP endpoint
type TCPClientSink struct {
input chan core.LogEntry
config TCPClientConfig
conn net.Conn
connMu sync.RWMutex
done chan struct{}
wg sync.WaitGroup
startTime time.Time
logger *log.Logger
formatter format.Formatter
// TLS support
tlsManager *tlspkg.Manager
tlsConfig *tls.Config
// Reconnection state
reconnecting atomic.Bool
lastConnectErr error
connectTime time.Time
// Statistics
totalProcessed atomic.Uint64
totalFailed atomic.Uint64
totalReconnects atomic.Uint64
lastProcessed atomic.Value // time.Time
connectionUptime atomic.Value // time.Duration
}
// TCPClientConfig holds TCP client sink configuration
type TCPClientConfig struct {
Address string
BufferSize int64
DialTimeout time.Duration
WriteTimeout time.Duration
ReadTimeout time.Duration
KeepAlive time.Duration
// Reconnection settings
ReconnectDelay time.Duration
MaxReconnectDelay time.Duration
ReconnectBackoff float64
// TLS config
SSL *config.SSLConfig
}
// NewTCPClientSink creates a new TCP client sink
func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*TCPClientSink, error) {
cfg := TCPClientConfig{
BufferSize: int64(1000),
DialTimeout: 10 * time.Second,
WriteTimeout: 30 * time.Second,
ReadTimeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
ReconnectDelay: time.Second,
MaxReconnectDelay: 30 * time.Second,
ReconnectBackoff: float64(1.5),
}
// Extract address
address, ok := options["address"].(string)
if !ok || address == "" {
return nil, fmt.Errorf("tcp_client sink requires 'address' option")
}
// Validate address format
_, _, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("invalid address format (expected host:port): %w", err)
}
cfg.Address = address
// Extract other options
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
cfg.BufferSize = bufSize
}
if dialTimeout, ok := options["dial_timeout_seconds"].(int64); ok && dialTimeout > 0 {
cfg.DialTimeout = time.Duration(dialTimeout) * time.Second
}
if writeTimeout, ok := options["write_timeout_seconds"].(int64); ok && writeTimeout > 0 {
cfg.WriteTimeout = time.Duration(writeTimeout) * time.Second
}
if readTimeout, ok := options["read_timeout_seconds"].(int64); ok && readTimeout > 0 {
cfg.ReadTimeout = time.Duration(readTimeout) * time.Second
}
if keepAlive, ok := options["keep_alive_seconds"].(int64); ok && keepAlive > 0 {
cfg.KeepAlive = time.Duration(keepAlive) * time.Second
}
if reconnectDelay, ok := options["reconnect_delay_ms"].(int64); ok && reconnectDelay > 0 {
cfg.ReconnectDelay = time.Duration(reconnectDelay) * time.Millisecond
}
if maxReconnectDelay, ok := options["max_reconnect_delay_seconds"].(int64); ok && maxReconnectDelay > 0 {
cfg.MaxReconnectDelay = time.Duration(maxReconnectDelay) * time.Second
}
if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 {
cfg.ReconnectBackoff = backoff
}
// Extract SSL config
if ssl, ok := options["ssl"].(map[string]any); ok {
cfg.SSL = &config.SSLConfig{}
cfg.SSL.Enabled, _ = ssl["enabled"].(bool)
if certFile, ok := ssl["cert_file"].(string); ok {
cfg.SSL.CertFile = certFile
}
if keyFile, ok := ssl["key_file"].(string); ok {
cfg.SSL.KeyFile = keyFile
}
cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool)
if caFile, ok := ssl["client_ca_file"].(string); ok {
cfg.SSL.ClientCAFile = caFile
}
if insecure, ok := ssl["insecure_skip_verify"].(bool); ok {
cfg.SSL.InsecureSkipVerify = insecure
}
if caFile, ok := ssl["ca_file"].(string); ok {
cfg.SSL.CAFile = caFile
}
}
t := &TCPClientSink{
input: make(chan core.LogEntry, cfg.BufferSize),
config: cfg,
done: make(chan struct{}),
startTime: time.Now(),
logger: logger,
formatter: formatter,
}
t.lastProcessed.Store(time.Time{})
t.connectionUptime.Store(time.Duration(0))
// Initialize TLS manager if SSL is configured
if cfg.SSL != nil && cfg.SSL.Enabled {
// Build custom TLS config for client
t.tlsConfig = &tls.Config{
InsecureSkipVerify: cfg.SSL.InsecureSkipVerify,
}
// Extract server name from address for SNI
host, _, err := net.SplitHostPort(cfg.Address)
if err != nil {
return nil, fmt.Errorf("failed to parse address for SNI: %w", err)
}
t.tlsConfig.ServerName = host
// Load custom CA for server verification
if cfg.SSL.CAFile != "" {
caCert, err := os.ReadFile(cfg.SSL.CAFile)
if err != nil {
return nil, fmt.Errorf("failed to read CA file '%s': %w", cfg.SSL.CAFile, err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("failed to parse CA certificate from '%s'", cfg.SSL.CAFile)
}
t.tlsConfig.RootCAs = caCertPool
logger.Debug("msg", "Custom CA loaded for server verification",
"component", "tcp_client_sink",
"ca_file", cfg.SSL.CAFile)
}
// Load client certificate for mTLS
if cfg.SSL.CertFile != "" && cfg.SSL.KeyFile != "" {
cert, err := tls.LoadX509KeyPair(cfg.SSL.CertFile, cfg.SSL.KeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate: %w", err)
}
t.tlsConfig.Certificates = []tls.Certificate{cert}
logger.Info("msg", "Client certificate loaded for mTLS",
"component", "tcp_client_sink",
"cert_file", cfg.SSL.CertFile)
}
// Set minimum TLS version if configured
if cfg.SSL.MinVersion != "" {
t.tlsConfig.MinVersion = parseTLSVersion(cfg.SSL.MinVersion, tls.VersionTLS12)
} else {
t.tlsConfig.MinVersion = tls.VersionTLS12 // Default minimum
}
logger.Info("msg", "TLS enabled for TCP client",
"component", "tcp_client_sink",
"address", cfg.Address,
"server_name", host,
"insecure", cfg.SSL.InsecureSkipVerify,
"mtls", cfg.SSL.CertFile != "")
}
return t, nil
}
func (t *TCPClientSink) Input() chan<- core.LogEntry {
return t.input
}
func (t *TCPClientSink) Start(ctx context.Context) error {
// Start connection manager
t.wg.Add(1)
go t.connectionManager(ctx)
// Start processing loop
t.wg.Add(1)
go t.processLoop(ctx)
t.logger.Info("msg", "TCP client sink started",
"component", "tcp_client_sink",
"address", t.config.Address)
return nil
}
func (t *TCPClientSink) Stop() {
t.logger.Info("msg", "Stopping TCP client sink")
close(t.done)
t.wg.Wait()
// Close connection
t.connMu.Lock()
if t.conn != nil {
_ = t.conn.Close()
}
t.connMu.Unlock()
t.logger.Info("msg", "TCP client sink stopped",
"total_processed", t.totalProcessed.Load(),
"total_failed", t.totalFailed.Load(),
"total_reconnects", t.totalReconnects.Load())
}
func (t *TCPClientSink) GetStats() SinkStats {
lastProc, _ := t.lastProcessed.Load().(time.Time)
uptime, _ := t.connectionUptime.Load().(time.Duration)
t.connMu.RLock()
connected := t.conn != nil
t.connMu.RUnlock()
activeConns := int64(0)
if connected {
activeConns = 1
}
return SinkStats{
Type: "tcp_client",
TotalProcessed: t.totalProcessed.Load(),
ActiveConnections: activeConns,
StartTime: t.startTime,
LastProcessed: lastProc,
Details: map[string]any{
"address": t.config.Address,
"connected": connected,
"reconnecting": t.reconnecting.Load(),
"total_failed": t.totalFailed.Load(),
"total_reconnects": t.totalReconnects.Load(),
"connection_uptime": uptime.Seconds(),
"last_error": fmt.Sprintf("%v", t.lastConnectErr),
},
}
}
func (t *TCPClientSink) connectionManager(ctx context.Context) {
defer t.wg.Done()
reconnectDelay := t.config.ReconnectDelay
for {
select {
case <-ctx.Done():
return
case <-t.done:
return
default:
}
// Attempt to connect
t.reconnecting.Store(true)
conn, err := t.connect()
t.reconnecting.Store(false)
if err != nil {
t.lastConnectErr = err
t.logger.Warn("msg", "Failed to connect to TCP server",
"component", "tcp_client_sink",
"address", t.config.Address,
"error", err,
"retry_delay", reconnectDelay)
// Wait before retry
select {
case <-ctx.Done():
return
case <-t.done:
return
case <-time.After(reconnectDelay):
}
// Exponential backoff
reconnectDelay = time.Duration(float64(reconnectDelay) * t.config.ReconnectBackoff)
if reconnectDelay > t.config.MaxReconnectDelay {
reconnectDelay = t.config.MaxReconnectDelay
}
continue
}
// Connection successful
t.lastConnectErr = nil
reconnectDelay = t.config.ReconnectDelay // Reset backoff
t.connectTime = time.Now()
t.totalReconnects.Add(1)
t.connMu.Lock()
t.conn = conn
t.connMu.Unlock()
t.logger.Info("msg", "Connected to TCP server",
"component", "tcp_client_sink",
"address", t.config.Address,
"local_addr", conn.LocalAddr())
// Monitor connection
t.monitorConnection(conn)
// Connection lost, clear it
t.connMu.Lock()
t.conn = nil
t.connMu.Unlock()
// Update connection uptime
uptime := time.Since(t.connectTime)
t.connectionUptime.Store(uptime)
t.logger.Warn("msg", "Lost connection to TCP server",
"component", "tcp_client_sink",
"address", t.config.Address,
"uptime", uptime)
}
}
func (t *TCPClientSink) connect() (net.Conn, error) {
dialer := &net.Dialer{
Timeout: t.config.DialTimeout,
KeepAlive: t.config.KeepAlive,
}
conn, err := dialer.Dial("tcp", t.config.Address)
if err != nil {
return nil, err
}
// Set TCP keep-alive
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
}
// Wrap with TLS if configured
if t.tlsConfig != nil {
t.logger.Debug("msg", "Initiating TLS handshake",
"component", "tcp_client_sink",
"address", t.config.Address)
tlsConn := tls.Client(conn, t.tlsConfig)
// Perform handshake with timeout
handshakeCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := tlsConn.HandshakeContext(handshakeCtx); err != nil {
conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
}
// Log connection details
state := tlsConn.ConnectionState()
t.logger.Info("msg", "TLS connection established",
"component", "tcp_client_sink",
"address", t.config.Address,
"tls_version", tlsVersionString(state.Version),
"cipher_suite", tls.CipherSuiteName(state.CipherSuite),
"server_name", state.ServerName)
return tlsConn, nil
}
return conn, nil
}
func (t *TCPClientSink) monitorConnection(conn net.Conn) {
// Simple connection monitoring by periodic zero-byte reads
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
buf := make([]byte, 1)
for {
select {
case <-t.done:
return
case <-ticker.C:
// Set read deadline
if err := conn.SetReadDeadline(time.Now().Add(t.config.ReadTimeout)); err != nil {
t.logger.Debug("msg", "Failed to set read deadline", "error", err)
return
}
// Try to read (we don't expect any data)
_, err := conn.Read(buf)
if err != nil {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
// Timeout is expected, connection is still alive
continue
}
// Real error, connection is dead
return
}
}
}
}
func (t *TCPClientSink) processLoop(ctx context.Context) {
defer t.wg.Done()
for {
select {
case entry, ok := <-t.input:
if !ok {
return
}
t.totalProcessed.Add(1)
t.lastProcessed.Store(time.Now())
// Send entry
if err := t.sendEntry(entry); err != nil {
t.totalFailed.Add(1)
t.logger.Debug("msg", "Failed to send log entry",
"component", "tcp_client_sink",
"error", err)
}
case <-ctx.Done():
return
case <-t.done:
return
}
}
}
func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
// Get current connection
t.connMu.RLock()
conn := t.conn
t.connMu.RUnlock()
if conn == nil {
return fmt.Errorf("not connected")
}
// Format data
data, err := t.formatter.Format(entry)
if err != nil {
return fmt.Errorf("failed to marshal entry: %w", err)
}
// Set write deadline
if err := conn.SetWriteDeadline(time.Now().Add(t.config.WriteTimeout)); err != nil {
return fmt.Errorf("failed to set write deadline: %w", err)
}
// Write data
n, err := conn.Write(data)
if err != nil {
// Connection error, it will be reconnected
return fmt.Errorf("write failed: %w", err)
}
if n != len(data) {
return fmt.Errorf("partial write: %d/%d bytes", n, len(data))
}
return nil
}
// tlsVersionString returns human-readable TLS version
func tlsVersionString(version uint16) string {
switch version {
case tls.VersionTLS10:
return "TLS1.0"
case tls.VersionTLS11:
return "TLS1.1"
case tls.VersionTLS12:
return "TLS1.2"
case tls.VersionTLS13:
return "TLS1.3"
default:
return fmt.Sprintf("0x%04x", version)
}
}
// parseTLSVersion converts string to TLS version constant
func parseTLSVersion(version string, defaultVersion uint16) uint16 {
switch strings.ToUpper(version) {
case "TLS1.0", "TLS10":
return tls.VersionTLS10
case "TLS1.1", "TLS11":
return tls.VersionTLS11
case "TLS1.2", "TLS12":
return tls.VersionTLS12
case "TLS1.3", "TLS13":
return tls.VersionTLS13
default:
return defaultVersion
}
}