v0.4.1 authentication impelemented, not tested and docs not updated
This commit is contained in:
@ -48,7 +48,6 @@ type HTTPSink struct {
|
||||
|
||||
// Net limiting
|
||||
netLimiter *limit.NetLimiter
|
||||
ipChecker *limit.IPChecker
|
||||
|
||||
// Statistics
|
||||
totalProcessed atomic.Uint64
|
||||
@ -156,6 +155,22 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
|
||||
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
||||
cfg.NetLimit.MaxTotalConnections = maxTotal
|
||||
}
|
||||
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
||||
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
||||
for _, entry := range ipWhitelist {
|
||||
if str, ok := entry.(string); ok {
|
||||
cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
||||
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
||||
for _, entry := range ipBlacklist {
|
||||
if str, ok := entry.(string); ok {
|
||||
cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
h := &HTTPSink{
|
||||
@ -195,8 +210,10 @@ func (h *HTTPSink) Start(ctx context.Context) error {
|
||||
|
||||
// Configure TLS if enabled
|
||||
if h.tlsManager != nil {
|
||||
tlsConfig := h.tlsManager.GetHTTPConfig()
|
||||
h.server.TLSConfig = tlsConfig
|
||||
h.server.TLSConfig = h.tlsManager.GetHTTPConfig()
|
||||
h.logger.Info("msg", "TLS enabled for HTTP sink",
|
||||
"component", "http_sink",
|
||||
"port", h.config.Port)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%d", h.config.Port)
|
||||
@ -208,12 +225,13 @@ func (h *HTTPSink) Start(ctx context.Context) error {
|
||||
"component", "http_sink",
|
||||
"port", h.config.Port,
|
||||
"stream_path", h.streamPath,
|
||||
"status_path", h.statusPath)
|
||||
"status_path", h.statusPath,
|
||||
"tls_enabled", h.tlsManager != nil)
|
||||
|
||||
var err error
|
||||
if h.tlsManager != nil {
|
||||
// HTTPS server
|
||||
err = h.server.ListenAndServeTLS(addr, "", "")
|
||||
err = h.server.ListenAndServeTLS(addr, h.config.SSL.CertFile, h.config.SSL.KeyFile)
|
||||
} else {
|
||||
// HTTP server
|
||||
err = h.server.ListenAndServe(addr)
|
||||
@ -306,24 +324,18 @@ func (h *HTTPSink) GetStats() SinkStats {
|
||||
func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
remoteAddr := ctx.RemoteAddr().String()
|
||||
|
||||
// Check IP access control
|
||||
if h.ipChecker != nil {
|
||||
if !h.ipChecker.IsAllowed(ctx.RemoteAddr()) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.SetContentType("text/plain")
|
||||
ctx.SetBodyString("Forbidden")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check net limit
|
||||
if h.netLimiter != nil {
|
||||
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
|
||||
ctx.SetStatusCode(int(statusCode))
|
||||
ctx.SetContentType("application/json")
|
||||
h.logger.Warn("msg", "Net limited",
|
||||
"component", "http_sink",
|
||||
"remote_addr", remoteAddr,
|
||||
"status_code", statusCode,
|
||||
"error", message)
|
||||
json.NewEncoder(ctx).Encode(map[string]any{
|
||||
"error": message,
|
||||
"retry_after": "60", // seconds
|
||||
"error": "Too many requests",
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -355,7 +367,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
|
||||
realm := h.authConfig.BasicAuth.Realm
|
||||
if realm == "" {
|
||||
realm = "LogWisp"
|
||||
realm = "Restricted"
|
||||
}
|
||||
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm))
|
||||
} else if h.authConfig.Type == "bearer" {
|
||||
@ -364,7 +376,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Authentication required",
|
||||
"error": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -379,8 +391,6 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]any{
|
||||
"error": "Not Found",
|
||||
"message": fmt.Sprintf("Available endpoints: %s (SSE transport), %s (status)",
|
||||
h.streamPath, h.statusPath),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -656,14 +666,6 @@ func (h *HTTPSink) GetStatusPath() string {
|
||||
return h.statusPath
|
||||
}
|
||||
|
||||
func (h *HTTPSink) SetNetAccessConfig(cfg *config.NetAccessConfig) {
|
||||
h.ipChecker = limit.NewIPChecker(cfg, h.logger)
|
||||
if h.ipChecker != nil {
|
||||
h.logger.Info("msg", "IP access control configured for HTTP sink",
|
||||
"component", "http_sink")
|
||||
}
|
||||
}
|
||||
|
||||
// SetAuthConfig configures http sink authentication
|
||||
func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type == "none" {
|
||||
|
||||
@ -4,8 +4,12 @@ package sink
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -55,6 +59,7 @@ type HTTPClientConfig struct {
|
||||
|
||||
// TLS configuration
|
||||
InsecureSkipVerify bool
|
||||
CAFile string
|
||||
}
|
||||
|
||||
// NewHTTPClientSink creates a new HTTP client sink
|
||||
@ -126,6 +131,11 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
|
||||
cfg.Headers["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
// Extract TLS options
|
||||
if caFile, ok := options["ca_file"].(string); ok && caFile != "" {
|
||||
cfg.CAFile = caFile
|
||||
}
|
||||
|
||||
h := &HTTPClientSink{
|
||||
input: make(chan core.LogEntry, cfg.BufferSize),
|
||||
config: cfg,
|
||||
@ -147,17 +157,28 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
|
||||
DisableHeaderNamesNormalizing: true,
|
||||
}
|
||||
|
||||
// TODO: Implement custom TLS configuration, including InsecureSkipVerify,
|
||||
// by setting a custom dialer on the fasthttp.Client.
|
||||
// For example:
|
||||
// if cfg.InsecureSkipVerify {
|
||||
// h.client.Dial = func(addr string) (net.Conn, error) {
|
||||
// return fasthttp.DialDualStackTimeout(addr, cfg.Timeout, &tls.Config{
|
||||
// InsecureSkipVerify: true,
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
// FIXED: Removed incorrect TLS configuration that referenced non-existent field
|
||||
// Configure TLS if using HTTPS
|
||||
if strings.HasPrefix(cfg.URL, "https://") {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
// Load custom CA if provided
|
||||
if cfg.CAFile != "" {
|
||||
caCert, err := os.ReadFile(cfg.CAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA file: %w", err)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate")
|
||||
}
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
}
|
||||
|
||||
// Set TLS config directly on the client
|
||||
h.client.TLSConfig = tlsConfig
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
@ -341,15 +362,22 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
if attempt > 0 {
|
||||
// Wait before retry
|
||||
time.Sleep(retryDelay)
|
||||
retryDelay = time.Duration(float64(retryDelay) * h.config.RetryBackoff)
|
||||
|
||||
// Calculate new delay with overflow protection
|
||||
newDelay := time.Duration(float64(retryDelay) * h.config.RetryBackoff)
|
||||
|
||||
// Cap at maximum to prevent integer overflow
|
||||
if newDelay > h.config.Timeout || newDelay < retryDelay {
|
||||
// Either exceeded max or overflowed (negative/wrapped)
|
||||
retryDelay = h.config.Timeout
|
||||
} else {
|
||||
retryDelay = newDelay
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: defer placement issue
|
||||
// Create request
|
||||
// Acquire resources inside loop, release immediately after use
|
||||
req := fasthttp.AcquireRequest()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(h.config.URL)
|
||||
req.Header.SetMethod("POST")
|
||||
@ -362,35 +390,50 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
|
||||
// Send request
|
||||
err := h.client.DoTimeout(req, resp, h.config.Timeout)
|
||||
|
||||
// Capture response before releasing
|
||||
statusCode := resp.StatusCode()
|
||||
var responseBody []byte
|
||||
if len(resp.Body()) > 0 {
|
||||
responseBody = make([]byte, len(resp.Body()))
|
||||
copy(responseBody, resp.Body())
|
||||
}
|
||||
|
||||
// Release immediately, not deferred
|
||||
fasthttp.ReleaseRequest(req)
|
||||
fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Handle errors
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("request failed: %w", err)
|
||||
h.logger.Warn("msg", "HTTP request failed",
|
||||
"component", "http_client_sink",
|
||||
"attempt", attempt+1,
|
||||
"max_retries", h.config.MaxRetries,
|
||||
"error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check response status
|
||||
statusCode := resp.StatusCode()
|
||||
if statusCode >= 200 && statusCode < 300 {
|
||||
// Success
|
||||
h.logger.Debug("msg", "Batch sent successfully",
|
||||
"component", "http_client_sink",
|
||||
"batch_size", len(batch),
|
||||
"status_code", statusCode)
|
||||
"status_code", statusCode,
|
||||
"attempt", attempt+1)
|
||||
return
|
||||
}
|
||||
|
||||
// Non-2xx status
|
||||
lastErr = fmt.Errorf("server returned status %d: %s", statusCode, resp.Body())
|
||||
lastErr = fmt.Errorf("server returned status %d: %s", statusCode, responseBody)
|
||||
|
||||
// Don't retry on 4xx errors (client errors)
|
||||
if statusCode >= 400 && statusCode < 500 {
|
||||
h.logger.Error("msg", "Batch rejected by server",
|
||||
"component", "http_client_sink",
|
||||
"status_code", statusCode,
|
||||
"response", string(resp.Body()),
|
||||
"response", string(responseBody),
|
||||
"batch_size", len(batch))
|
||||
h.failedBatches.Add(1)
|
||||
return
|
||||
@ -400,13 +443,14 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
"component", "http_client_sink",
|
||||
"attempt", attempt+1,
|
||||
"status_code", statusCode,
|
||||
"response", string(resp.Body()))
|
||||
"response", string(responseBody))
|
||||
}
|
||||
|
||||
// All retries failed
|
||||
h.logger.Error("msg", "Failed to send batch after retries",
|
||||
// All retries exhausted
|
||||
h.logger.Error("msg", "Failed to send batch after all retries",
|
||||
"component", "http_client_sink",
|
||||
"batch_size", len(batch),
|
||||
"retries", h.config.MaxRetries,
|
||||
"last_error", lastErr)
|
||||
h.failedBatches.Add(1)
|
||||
}
|
||||
@ -34,11 +34,6 @@ type SinkStats struct {
|
||||
Details map[string]any
|
||||
}
|
||||
|
||||
// NetAccessSetter is an interface for sinks that can accept network access configuration
|
||||
type NetAccessSetter interface {
|
||||
SetNetAccessConfig(cfg *config.NetAccessConfig)
|
||||
}
|
||||
|
||||
// AuthSetter is an interface for sinks that can accept an AuthConfig.
|
||||
type AuthSetter interface {
|
||||
SetAuthConfig(auth *config.AuthConfig)
|
||||
|
||||
@ -36,7 +36,6 @@ type TCPSink struct {
|
||||
engineMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
netLimiter *limit.NetLimiter
|
||||
ipChecker *limit.IPChecker
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
|
||||
@ -50,6 +49,11 @@ type TCPSink struct {
|
||||
lastProcessed atomic.Value // time.Time
|
||||
authFailures atomic.Uint64
|
||||
authSuccesses atomic.Uint64
|
||||
|
||||
// Write error tracking
|
||||
writeErrors atomic.Uint64
|
||||
consecutiveWriteErrors map[gnet.Conn]int
|
||||
errorMu sync.Mutex
|
||||
}
|
||||
|
||||
// TCPConfig holds TCP sink configuration
|
||||
@ -141,6 +145,22 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
|
||||
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
||||
cfg.NetLimit.MaxTotalConnections = maxTotal
|
||||
}
|
||||
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
||||
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
||||
for _, entry := range ipWhitelist {
|
||||
if str, ok := entry.(string); ok {
|
||||
cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
||||
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
||||
for _, entry := range ipBlacklist {
|
||||
if str, ok := entry.(string); ok {
|
||||
cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t := &TCPSink{
|
||||
@ -191,17 +211,6 @@ func (t *TCPSink) Start(ctx context.Context) error {
|
||||
gnet.WithReusePort(true),
|
||||
)
|
||||
|
||||
// Add TLS if configured
|
||||
if t.tlsManager != nil {
|
||||
// tlsConfig := t.tlsManager.GetTCPConfig()
|
||||
// TODO: tlsConfig is not used, wrapper to be implemented, non-TLS stream to be available without wrapper
|
||||
// ☢ SECURITY: gnet doesn't support TLS natively - would need wrapper
|
||||
// This is a limitation that requires implementing TLS at application layer
|
||||
t.logger.Warn("msg", "TLS configured but gnet doesn't support native TLS",
|
||||
"component", "tcp_sink",
|
||||
"workaround", "Use stunnel or nginx TCP proxy for TLS termination")
|
||||
}
|
||||
|
||||
// Start gnet server
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
@ -338,7 +347,29 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
t.server.mu.RLock()
|
||||
for conn, client := range t.server.clients {
|
||||
if client.authenticated {
|
||||
conn.AsyncWrite(data, nil)
|
||||
// Send through TLS bridge if present
|
||||
if client.tlsBridge != nil {
|
||||
if _, err := client.tlsBridge.Write(data); err != nil {
|
||||
// TLS write failed, connection likely dead
|
||||
t.logger.Debug("msg", "TLS write failed",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
conn.Close()
|
||||
}
|
||||
} else {
|
||||
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||
if err != nil {
|
||||
t.writeErrors.Add(1)
|
||||
t.handleWriteError(c, err)
|
||||
} else {
|
||||
// Reset consecutive error count on success
|
||||
t.errorMu.Lock()
|
||||
delete(t.consecutiveWriteErrors, c)
|
||||
t.errorMu.Unlock()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
t.server.mu.RUnlock()
|
||||
@ -364,7 +395,22 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
conn.AsyncWrite(data, nil)
|
||||
if client.tlsBridge != nil {
|
||||
if _, err := client.tlsBridge.Write(data); err != nil {
|
||||
t.logger.Debug("msg", "TLS heartbeat write failed",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
conn.Close()
|
||||
}
|
||||
} else {
|
||||
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||
if err != nil {
|
||||
t.writeErrors.Add(1)
|
||||
t.handleWriteError(c, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
t.server.mu.RUnlock()
|
||||
@ -375,6 +421,36 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle write errors with threshold-based connection termination
|
||||
func (t *TCPSink) handleWriteError(c gnet.Conn, err error) {
|
||||
t.errorMu.Lock()
|
||||
defer t.errorMu.Unlock()
|
||||
|
||||
// Track consecutive errors per connection
|
||||
if t.consecutiveWriteErrors == nil {
|
||||
t.consecutiveWriteErrors = make(map[gnet.Conn]int)
|
||||
}
|
||||
|
||||
t.consecutiveWriteErrors[c]++
|
||||
errorCount := t.consecutiveWriteErrors[c]
|
||||
|
||||
t.logger.Debug("msg", "AsyncWrite error",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr(),
|
||||
"error", err,
|
||||
"consecutive_errors", errorCount)
|
||||
|
||||
// Close connection after 3 consecutive write errors
|
||||
if errorCount >= 3 {
|
||||
t.logger.Warn("msg", "Closing connection due to repeated write errors",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr(),
|
||||
"error_count", errorCount)
|
||||
delete(t.consecutiveWriteErrors, c)
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Create heartbeat as a proper LogEntry
|
||||
func (t *TCPSink) createHeartbeatEntry() core.LogEntry {
|
||||
message := "heartbeat"
|
||||
@ -406,11 +482,13 @@ func (t *TCPSink) GetActiveConnections() int64 {
|
||||
|
||||
// tcpClient represents a connected TCP client with auth state
|
||||
type tcpClient struct {
|
||||
conn gnet.Conn
|
||||
buffer bytes.Buffer
|
||||
authenticated bool
|
||||
session *auth.Session
|
||||
authTimeout time.Time
|
||||
conn gnet.Conn
|
||||
buffer bytes.Buffer
|
||||
authenticated bool
|
||||
session *auth.Session
|
||||
authTimeout time.Time
|
||||
tlsBridge *tls.GNetTLSConn
|
||||
authTimeoutSet bool
|
||||
}
|
||||
|
||||
// tcpServer handles gnet events with authentication
|
||||
@ -434,15 +512,13 @@ func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action {
|
||||
}
|
||||
|
||||
func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
remoteAddr := c.RemoteAddr().String()
|
||||
remoteAddr := c.RemoteAddr()
|
||||
s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr)
|
||||
|
||||
// Check IP access control first
|
||||
if s.sink.ipChecker != nil {
|
||||
if !s.sink.ipChecker.IsAllowed(c.RemoteAddr()) {
|
||||
s.sink.logger.Warn("msg", "TCP connection denied by IP filter",
|
||||
"remote_addr", remoteAddr)
|
||||
return nil, gnet.Close
|
||||
// Reject IPv6 connections immediately
|
||||
if tcpAddr, ok := remoteAddr.(*net.TCPAddr); ok {
|
||||
if tcpAddr.IP.To4() == nil {
|
||||
return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close
|
||||
}
|
||||
}
|
||||
|
||||
@ -467,11 +543,26 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
s.sink.netLimiter.AddConnection(remoteStr)
|
||||
}
|
||||
|
||||
// Create client state
|
||||
// Create client state without auth timeout initially
|
||||
client := &tcpClient{
|
||||
conn: c,
|
||||
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
|
||||
authTimeout: time.Now().Add(30 * time.Second), // 30s to authenticate
|
||||
conn: c,
|
||||
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
|
||||
authTimeoutSet: false, // Auth timeout not started yet
|
||||
}
|
||||
|
||||
// Initialize TLS bridge if enabled
|
||||
if s.sink.tlsManager != nil {
|
||||
tlsConfig := s.sink.tlsManager.GetTCPConfig()
|
||||
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
|
||||
client.tlsBridge.Handshake() // Start async handshake
|
||||
|
||||
s.sink.logger.Debug("msg", "TLS handshake initiated",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", remoteAddr)
|
||||
} else if s.sink.authenticator != nil {
|
||||
// Only set auth timeout if no TLS (plain connection)
|
||||
client.authTimeout = time.Now().Add(30 * time.Second) // TODO: configurable or non-hardcoded timer
|
||||
client.authTimeoutSet = true
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
@ -485,7 +576,7 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
"requires_auth", s.sink.authenticator != nil)
|
||||
|
||||
// Send auth prompt if authentication is required
|
||||
if s.sink.authenticator != nil {
|
||||
if s.sink.authenticator != nil && s.sink.tlsManager == nil {
|
||||
authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH <method> <credentials>\nMethods: basic, token\n")
|
||||
return authPrompt, gnet.None
|
||||
}
|
||||
@ -498,9 +589,22 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
|
||||
// Remove client state
|
||||
s.mu.Lock()
|
||||
client := s.clients[c]
|
||||
delete(s.clients, c)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Clean up TLS bridge if present
|
||||
if client != nil && client.tlsBridge != nil {
|
||||
client.tlsBridge.Close()
|
||||
s.sink.logger.Debug("msg", "TLS connection closed",
|
||||
"remote_addr", remoteAddr)
|
||||
}
|
||||
|
||||
// Clean up write error tracking
|
||||
s.sink.errorMu.Lock()
|
||||
delete(s.sink.consecutiveWriteErrors, c)
|
||||
s.sink.errorMu.Unlock()
|
||||
|
||||
// Remove connection tracking
|
||||
if s.sink.netLimiter != nil {
|
||||
s.sink.netLimiter.RemoveConnection(remoteAddr)
|
||||
@ -523,13 +627,18 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Check auth timeout
|
||||
if !client.authenticated && time.Now().After(client.authTimeout) {
|
||||
s.sink.logger.Warn("msg", "Authentication timeout",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
// // Check auth timeout
|
||||
// if !client.authenticated && time.Now().After(client.authTimeout) {
|
||||
// s.sink.logger.Warn("msg", "Authentication timeout",
|
||||
// "component", "tcp_sink",
|
||||
// "remote_addr", c.RemoteAddr().String())
|
||||
// if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() {
|
||||
// client.tlsBridge.Write([]byte("AUTH TIMEOUT\n"))
|
||||
// } else if client.tlsBridge == nil {
|
||||
// c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
||||
// }
|
||||
// return gnet.Close
|
||||
// }
|
||||
|
||||
// Read all available data
|
||||
data, err := c.Next(-1)
|
||||
@ -540,6 +649,70 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Process through TLS bridge if present
|
||||
if client.tlsBridge != nil {
|
||||
// Feed encrypted data into TLS engine
|
||||
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
|
||||
s.sink.logger.Error("msg", "TLS processing error",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"error", err)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Check if handshake is complete
|
||||
if !client.tlsBridge.IsHandshakeDone() {
|
||||
// Still handshaking, wait for more data
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// Check handshake result
|
||||
_, hsErr := client.tlsBridge.HandshakeComplete()
|
||||
if hsErr != nil {
|
||||
s.sink.logger.Error("msg", "TLS handshake failed",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"error", hsErr)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Set auth timeout only after TLS handshake completes
|
||||
if !client.authTimeoutSet && s.sink.authenticator != nil && !client.authenticated {
|
||||
client.authTimeout = time.Now().Add(30 * time.Second)
|
||||
client.authTimeoutSet = true
|
||||
s.sink.logger.Debug("msg", "Auth timeout started after TLS handshake",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
}
|
||||
|
||||
// Read decrypted plaintext
|
||||
data = client.tlsBridge.Read()
|
||||
if data == nil || len(data) == 0 {
|
||||
// No plaintext available yet
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// First data after TLS handshake - send auth prompt if needed
|
||||
if s.sink.authenticator != nil && !client.authenticated &&
|
||||
len(client.buffer.Bytes()) == 0 {
|
||||
authPrompt := []byte("AUTH REQUIRED\n")
|
||||
client.tlsBridge.Write(authPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// Only check auth timeout if it has been set
|
||||
if !client.authenticated && client.authTimeoutSet && time.Now().After(client.authTimeout) {
|
||||
s.sink.logger.Warn("msg", "Authentication timeout",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() {
|
||||
client.tlsBridge.Write([]byte("AUTH TIMEOUT\n"))
|
||||
} else if client.tlsBridge == nil {
|
||||
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
||||
}
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// If not authenticated, expect auth command
|
||||
if !client.authenticated {
|
||||
client.buffer.Write(data)
|
||||
@ -551,7 +724,13 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
// Parse AUTH command: AUTH <method> <credentials>
|
||||
parts := strings.SplitN(string(line), " ", 3)
|
||||
if len(parts) != 3 || parts[0] != "AUTH" {
|
||||
c.AsyncWrite([]byte("ERROR: Invalid auth format\n"), nil)
|
||||
// Send error through TLS if enabled
|
||||
errMsg := []byte("AUTH FAILED\n")
|
||||
if client.tlsBridge != nil {
|
||||
client.tlsBridge.Write(errMsg)
|
||||
} else {
|
||||
c.AsyncWrite(errMsg, nil)
|
||||
}
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
@ -563,7 +742,13 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"method", parts[1],
|
||||
"error", err)
|
||||
c.AsyncWrite([]byte(fmt.Sprintf("AUTH FAILED: %v\n", err)), nil)
|
||||
// Send error through TLS if enabled
|
||||
errMsg := []byte("AUTH FAILED\n")
|
||||
if client.tlsBridge != nil {
|
||||
client.tlsBridge.Write(errMsg)
|
||||
} else {
|
||||
c.AsyncWrite(errMsg, nil)
|
||||
}
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
@ -575,11 +760,19 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
s.mu.Unlock()
|
||||
|
||||
s.sink.logger.Info("msg", "TCP client authenticated",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"username", session.Username,
|
||||
"method", session.Method)
|
||||
"method", session.Method,
|
||||
"tls", client.tlsBridge != nil)
|
||||
|
||||
c.AsyncWrite([]byte("AUTH OK\n"), nil)
|
||||
// Send success through TLS if enabled
|
||||
successMsg := []byte("AUTH OK\n")
|
||||
if client.tlsBridge != nil {
|
||||
client.tlsBridge.Write(successMsg)
|
||||
} else {
|
||||
c.AsyncWrite(successMsg, nil)
|
||||
}
|
||||
|
||||
// Clear buffer after auth
|
||||
client.buffer.Reset()
|
||||
@ -610,7 +803,7 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
||||
|
||||
// Initialize TLS manager if SSL is configured
|
||||
if t.config.SSL != nil && t.config.SSL.Enabled {
|
||||
tlsManager, err := tls.New(t.config.SSL, t.logger)
|
||||
tlsManager, err := tls.NewManager(t.config.SSL, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "Failed to create TLS manager",
|
||||
"component", "tcp_sink",
|
||||
@ -624,5 +817,6 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
||||
t.logger.Info("msg", "Authentication configured for TCP sink",
|
||||
"component", "tcp_sink",
|
||||
"auth_type", authCfg.Type,
|
||||
"tls_enabled", t.tlsManager != nil)
|
||||
"tls_enabled", t.tlsManager != nil,
|
||||
"tls_bridge", t.tlsManager != nil)
|
||||
}
|
||||
@ -3,6 +3,7 @@ package sink
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -10,8 +11,10 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
tlspkg "logwisp/src/internal/tls"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
@ -28,6 +31,10 @@ type TCPClientSink struct {
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
|
||||
// TLS support
|
||||
tlsManager *tlspkg.Manager
|
||||
tlsConfig *tls.Config
|
||||
|
||||
// Reconnection state
|
||||
reconnecting atomic.Bool
|
||||
lastConnectErr error
|
||||
@ -53,6 +60,9 @@ type TCPClientConfig struct {
|
||||
ReconnectDelay time.Duration
|
||||
MaxReconnectDelay time.Duration
|
||||
ReconnectBackoff float64
|
||||
|
||||
// TLS config
|
||||
SSL *config.SSLConfig
|
||||
}
|
||||
|
||||
// NewTCPClientSink creates a new TCP client sink
|
||||
@ -103,6 +113,25 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
|
||||
cfg.ReconnectBackoff = backoff
|
||||
}
|
||||
|
||||
// Extract SSL config
|
||||
if ssl, ok := options["ssl"].(map[string]any); ok {
|
||||
cfg.SSL = &config.SSLConfig{}
|
||||
cfg.SSL.Enabled, _ = ssl["enabled"].(bool)
|
||||
if certFile, ok := ssl["cert_file"].(string); ok {
|
||||
cfg.SSL.CertFile = certFile
|
||||
}
|
||||
if keyFile, ok := ssl["key_file"].(string); ok {
|
||||
cfg.SSL.KeyFile = keyFile
|
||||
}
|
||||
cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool)
|
||||
if caFile, ok := ssl["client_ca_file"].(string); ok {
|
||||
cfg.SSL.ClientCAFile = caFile
|
||||
}
|
||||
if insecure, ok := ssl["insecure_skip_verify"].(bool); ok {
|
||||
cfg.SSL.InsecureSkipVerify = insecure
|
||||
}
|
||||
}
|
||||
|
||||
t := &TCPClientSink{
|
||||
input: make(chan core.LogEntry, cfg.BufferSize),
|
||||
config: cfg,
|
||||
@ -114,6 +143,34 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
|
||||
t.lastProcessed.Store(time.Time{})
|
||||
t.connectionUptime.Store(time.Duration(0))
|
||||
|
||||
// Initialize TLS manager if SSL is configured
|
||||
if cfg.SSL != nil && cfg.SSL.Enabled {
|
||||
tlsManager, err := tlspkg.NewManager(cfg.SSL, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
|
||||
}
|
||||
t.tlsManager = tlsManager
|
||||
|
||||
// Get client TLS config
|
||||
t.tlsConfig = tlsManager.GetTCPConfig()
|
||||
|
||||
// ADDED: Client-specific TLS config adjustments
|
||||
t.tlsConfig.InsecureSkipVerify = cfg.SSL.InsecureSkipVerify
|
||||
|
||||
// Extract server name from address for SNI
|
||||
host, _, err := net.SplitHostPort(cfg.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse address for SNI: %w", err)
|
||||
}
|
||||
t.tlsConfig.ServerName = host
|
||||
|
||||
logger.Info("msg", "TLS enabled for TCP client",
|
||||
"component", "tcp_client_sink",
|
||||
"address", cfg.Address,
|
||||
"server_name", host,
|
||||
"insecure", cfg.SSL.InsecureSkipVerify)
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
@ -280,6 +337,35 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
|
||||
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
|
||||
}
|
||||
|
||||
// Wrap with TLS if configured
|
||||
if t.tlsConfig != nil {
|
||||
t.logger.Debug("msg", "Initiating TLS handshake",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address)
|
||||
|
||||
tlsConn := tls.Client(conn, t.tlsConfig)
|
||||
|
||||
// Perform handshake with timeout
|
||||
handshakeCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := tlsConn.HandshakeContext(handshakeCtx); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
// Log connection details
|
||||
state := tlsConn.ConnectionState()
|
||||
t.logger.Info("msg", "TLS connection established",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address,
|
||||
"tls_version", tlsVersionString(state.Version),
|
||||
"cipher_suite", tls.CipherSuiteName(state.CipherSuite),
|
||||
"server_name", state.ServerName)
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@ -295,7 +381,7 @@ func (t *TCPClientSink) monitorConnection(conn net.Conn) {
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Set read deadline
|
||||
// TODO: Add t.config.ReadTimeout instead of static value
|
||||
// TODO: Add t.config.ReadTimeout and after addition use it instead of static value
|
||||
if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
|
||||
t.logger.Debug("msg", "Failed to set read deadline", "error", err)
|
||||
return
|
||||
@ -378,4 +464,20 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tlsVersionString returns human-readable TLS version
|
||||
func tlsVersionString(version uint16) string {
|
||||
switch version {
|
||||
case tls.VersionTLS10:
|
||||
return "TLS1.0"
|
||||
case tls.VersionTLS11:
|
||||
return "TLS1.1"
|
||||
case tls.VersionTLS12:
|
||||
return "TLS1.2"
|
||||
case tls.VersionTLS13:
|
||||
return "TLS1.3"
|
||||
default:
|
||||
return fmt.Sprintf("0x%04x", version)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user