v0.4.1 authentication impelemented, not tested and docs not updated
This commit is contained in:
@ -3,6 +3,7 @@ package sink
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -10,8 +11,10 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
tlspkg "logwisp/src/internal/tls"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
@ -28,6 +31,10 @@ type TCPClientSink struct {
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
|
||||
// TLS support
|
||||
tlsManager *tlspkg.Manager
|
||||
tlsConfig *tls.Config
|
||||
|
||||
// Reconnection state
|
||||
reconnecting atomic.Bool
|
||||
lastConnectErr error
|
||||
@ -53,6 +60,9 @@ type TCPClientConfig struct {
|
||||
ReconnectDelay time.Duration
|
||||
MaxReconnectDelay time.Duration
|
||||
ReconnectBackoff float64
|
||||
|
||||
// TLS config
|
||||
SSL *config.SSLConfig
|
||||
}
|
||||
|
||||
// NewTCPClientSink creates a new TCP client sink
|
||||
@ -103,6 +113,25 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
|
||||
cfg.ReconnectBackoff = backoff
|
||||
}
|
||||
|
||||
// Extract SSL config
|
||||
if ssl, ok := options["ssl"].(map[string]any); ok {
|
||||
cfg.SSL = &config.SSLConfig{}
|
||||
cfg.SSL.Enabled, _ = ssl["enabled"].(bool)
|
||||
if certFile, ok := ssl["cert_file"].(string); ok {
|
||||
cfg.SSL.CertFile = certFile
|
||||
}
|
||||
if keyFile, ok := ssl["key_file"].(string); ok {
|
||||
cfg.SSL.KeyFile = keyFile
|
||||
}
|
||||
cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool)
|
||||
if caFile, ok := ssl["client_ca_file"].(string); ok {
|
||||
cfg.SSL.ClientCAFile = caFile
|
||||
}
|
||||
if insecure, ok := ssl["insecure_skip_verify"].(bool); ok {
|
||||
cfg.SSL.InsecureSkipVerify = insecure
|
||||
}
|
||||
}
|
||||
|
||||
t := &TCPClientSink{
|
||||
input: make(chan core.LogEntry, cfg.BufferSize),
|
||||
config: cfg,
|
||||
@ -114,6 +143,34 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
|
||||
t.lastProcessed.Store(time.Time{})
|
||||
t.connectionUptime.Store(time.Duration(0))
|
||||
|
||||
// Initialize TLS manager if SSL is configured
|
||||
if cfg.SSL != nil && cfg.SSL.Enabled {
|
||||
tlsManager, err := tlspkg.NewManager(cfg.SSL, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
|
||||
}
|
||||
t.tlsManager = tlsManager
|
||||
|
||||
// Get client TLS config
|
||||
t.tlsConfig = tlsManager.GetTCPConfig()
|
||||
|
||||
// ADDED: Client-specific TLS config adjustments
|
||||
t.tlsConfig.InsecureSkipVerify = cfg.SSL.InsecureSkipVerify
|
||||
|
||||
// Extract server name from address for SNI
|
||||
host, _, err := net.SplitHostPort(cfg.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse address for SNI: %w", err)
|
||||
}
|
||||
t.tlsConfig.ServerName = host
|
||||
|
||||
logger.Info("msg", "TLS enabled for TCP client",
|
||||
"component", "tcp_client_sink",
|
||||
"address", cfg.Address,
|
||||
"server_name", host,
|
||||
"insecure", cfg.SSL.InsecureSkipVerify)
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
@ -280,6 +337,35 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
|
||||
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
|
||||
}
|
||||
|
||||
// Wrap with TLS if configured
|
||||
if t.tlsConfig != nil {
|
||||
t.logger.Debug("msg", "Initiating TLS handshake",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address)
|
||||
|
||||
tlsConn := tls.Client(conn, t.tlsConfig)
|
||||
|
||||
// Perform handshake with timeout
|
||||
handshakeCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := tlsConn.HandshakeContext(handshakeCtx); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
// Log connection details
|
||||
state := tlsConn.ConnectionState()
|
||||
t.logger.Info("msg", "TLS connection established",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address,
|
||||
"tls_version", tlsVersionString(state.Version),
|
||||
"cipher_suite", tls.CipherSuiteName(state.CipherSuite),
|
||||
"server_name", state.ServerName)
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@ -295,7 +381,7 @@ func (t *TCPClientSink) monitorConnection(conn net.Conn) {
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Set read deadline
|
||||
// TODO: Add t.config.ReadTimeout instead of static value
|
||||
// TODO: Add t.config.ReadTimeout and after addition use it instead of static value
|
||||
if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
|
||||
t.logger.Debug("msg", "Failed to set read deadline", "error", err)
|
||||
return
|
||||
@ -378,4 +464,20 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tlsVersionString returns human-readable TLS version
|
||||
func tlsVersionString(version uint16) string {
|
||||
switch version {
|
||||
case tls.VersionTLS10:
|
||||
return "TLS1.0"
|
||||
case tls.VersionTLS11:
|
||||
return "TLS1.1"
|
||||
case tls.VersionTLS12:
|
||||
return "TLS1.2"
|
||||
case tls.VersionTLS13:
|
||||
return "TLS1.3"
|
||||
default:
|
||||
return fmt.Sprintf("0x%04x", version)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user