v0.4.0 authentication added and router mode removed
This commit is contained in:
@ -11,10 +11,12 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/tls"
|
||||
"logwisp/src/internal/version"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
@ -35,19 +37,24 @@ type HTTPSink struct {
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
|
||||
// Security components
|
||||
authenticator *auth.Authenticator
|
||||
tlsManager *tls.Manager
|
||||
authConfig *config.AuthConfig
|
||||
|
||||
// Path configuration
|
||||
streamPath string
|
||||
statusPath string
|
||||
|
||||
// For router integration
|
||||
standalone bool
|
||||
|
||||
// Net limiting
|
||||
netLimiter *limit.NetLimiter
|
||||
ipChecker *limit.IPChecker
|
||||
|
||||
// Statistics
|
||||
totalProcessed atomic.Uint64
|
||||
lastProcessed atomic.Value // time.Time
|
||||
authFailures atomic.Uint64
|
||||
authSuccesses atomic.Uint64
|
||||
}
|
||||
|
||||
// HTTPConfig holds HTTP sink configuration
|
||||
@ -98,6 +105,32 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
|
||||
}
|
||||
}
|
||||
|
||||
// Extract SSL config
|
||||
if ssl, ok := options["ssl"].(map[string]any); ok {
|
||||
cfg.SSL = &config.SSLConfig{}
|
||||
cfg.SSL.Enabled, _ = ssl["enabled"].(bool)
|
||||
if certFile, ok := ssl["cert_file"].(string); ok {
|
||||
cfg.SSL.CertFile = certFile
|
||||
}
|
||||
if keyFile, ok := ssl["key_file"].(string); ok {
|
||||
cfg.SSL.KeyFile = keyFile
|
||||
}
|
||||
cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool)
|
||||
if caFile, ok := ssl["client_ca_file"].(string); ok {
|
||||
cfg.SSL.ClientCAFile = caFile
|
||||
}
|
||||
cfg.SSL.VerifyClientCert, _ = ssl["verify_client_cert"].(bool)
|
||||
if minVer, ok := ssl["min_version"].(string); ok {
|
||||
cfg.SSL.MinVersion = minVer
|
||||
}
|
||||
if maxVer, ok := ssl["max_version"].(string); ok {
|
||||
cfg.SSL.MaxVersion = maxVer
|
||||
}
|
||||
if ciphers, ok := ssl["cipher_suites"].(string); ok {
|
||||
cfg.SSL.CipherSuites = ciphers
|
||||
}
|
||||
}
|
||||
|
||||
// Extract net limit config
|
||||
if rl, ok := options["net_limit"].(map[string]any); ok {
|
||||
cfg.NetLimit = &config.NetLimitConfig{}
|
||||
@ -132,7 +165,6 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
|
||||
done: make(chan struct{}),
|
||||
streamPath: cfg.StreamPath,
|
||||
statusPath: cfg.StatusPath,
|
||||
standalone: true,
|
||||
logger: logger,
|
||||
formatter: formatter,
|
||||
}
|
||||
@ -151,13 +183,6 @@ func (h *HTTPSink) Input() chan<- core.LogEntry {
|
||||
}
|
||||
|
||||
func (h *HTTPSink) Start(ctx context.Context) error {
|
||||
if !h.standalone {
|
||||
// In router mode, don't start our own server
|
||||
h.logger.Debug("msg", "HTTP sink in router mode, skipping server start",
|
||||
"component", "http_sink")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create fasthttp adapter for logging
|
||||
fasthttpLogger := compat.NewFastHTTPAdapter(h.logger)
|
||||
|
||||
@ -168,6 +193,12 @@ func (h *HTTPSink) Start(ctx context.Context) error {
|
||||
Logger: fasthttpLogger,
|
||||
}
|
||||
|
||||
// Configure TLS if enabled
|
||||
if h.tlsManager != nil {
|
||||
tlsConfig := h.tlsManager.GetHTTPConfig()
|
||||
h.server.TLSConfig = tlsConfig
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%d", h.config.Port)
|
||||
|
||||
// Run server in separate goroutine to avoid blocking
|
||||
@ -178,7 +209,16 @@ func (h *HTTPSink) Start(ctx context.Context) error {
|
||||
"port", h.config.Port,
|
||||
"stream_path", h.streamPath,
|
||||
"status_path", h.statusPath)
|
||||
err := h.server.ListenAndServe(addr)
|
||||
|
||||
var err error
|
||||
if h.tlsManager != nil {
|
||||
// HTTPS server
|
||||
err = h.server.ListenAndServeTLS(addr, "", "")
|
||||
} else {
|
||||
// HTTP server
|
||||
err = h.server.ListenAndServe(addr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
@ -210,8 +250,8 @@ func (h *HTTPSink) Stop() {
|
||||
// Signal all client handlers to stop
|
||||
close(h.done)
|
||||
|
||||
// Shutdown HTTP server if in standalone mode
|
||||
if h.standalone && h.server != nil {
|
||||
// Shutdown HTTP server
|
||||
if h.server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
h.server.ShutdownWithContext(ctx)
|
||||
@ -231,6 +271,18 @@ func (h *HTTPSink) GetStats() SinkStats {
|
||||
netLimitStats = h.netLimiter.GetStats()
|
||||
}
|
||||
|
||||
var authStats map[string]any
|
||||
if h.authenticator != nil {
|
||||
authStats = h.authenticator.GetStats()
|
||||
authStats["failures"] = h.authFailures.Load()
|
||||
authStats["successes"] = h.authSuccesses.Load()
|
||||
}
|
||||
|
||||
var tlsStats map[string]any
|
||||
if h.tlsManager != nil {
|
||||
tlsStats = h.tlsManager.GetStats()
|
||||
}
|
||||
|
||||
return SinkStats{
|
||||
Type: "http",
|
||||
TotalProcessed: h.totalProcessed.Load(),
|
||||
@ -245,42 +297,83 @@ func (h *HTTPSink) GetStats() SinkStats {
|
||||
"status": h.statusPath,
|
||||
},
|
||||
"net_limit": netLimitStats,
|
||||
"auth": authStats,
|
||||
"tls": tlsStats,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetRouterMode configures the sink for use with a router
|
||||
func (h *HTTPSink) SetRouterMode() {
|
||||
h.standalone = false
|
||||
h.logger.Debug("msg", "HTTP sink set to router mode",
|
||||
"component", "http_sink")
|
||||
}
|
||||
|
||||
// RouteRequest handles a request from the router
|
||||
func (h *HTTPSink) RouteRequest(ctx *fasthttp.RequestCtx) {
|
||||
h.requestHandler(ctx)
|
||||
}
|
||||
|
||||
func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
// Check net limit first
|
||||
remoteAddr := ctx.RemoteAddr().String()
|
||||
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
|
||||
ctx.SetStatusCode(int(statusCode))
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]any{
|
||||
"error": message,
|
||||
"retry_after": "60", // seconds
|
||||
})
|
||||
return
|
||||
|
||||
// Check IP access control
|
||||
if h.ipChecker != nil {
|
||||
if !h.ipChecker.IsAllowed(ctx.RemoteAddr()) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.SetContentType("text/plain")
|
||||
ctx.SetBodyString("Forbidden")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check net limit
|
||||
if h.netLimiter != nil {
|
||||
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
|
||||
ctx.SetStatusCode(int(statusCode))
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]any{
|
||||
"error": message,
|
||||
"retry_after": "60", // seconds
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
path := string(ctx.Path())
|
||||
|
||||
// Status endpoint doesn't require auth
|
||||
if path == h.statusPath {
|
||||
h.handleStatus(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate request
|
||||
var session *auth.Session
|
||||
if h.authenticator != nil {
|
||||
authHeader := string(ctx.Request.Header.Peek("Authorization"))
|
||||
var err error
|
||||
session, err = h.authenticator.AuthenticateHTTP(authHeader, remoteAddr)
|
||||
if err != nil {
|
||||
h.authFailures.Add(1)
|
||||
h.logger.Warn("msg", "Authentication failed",
|
||||
"component", "http_sink",
|
||||
"remote_addr", remoteAddr,
|
||||
"error", err)
|
||||
|
||||
// Return 401 with WWW-Authenticate header
|
||||
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
|
||||
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
|
||||
realm := h.authConfig.BasicAuth.Realm
|
||||
if realm == "" {
|
||||
realm = "LogWisp"
|
||||
}
|
||||
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm))
|
||||
} else if h.authConfig.Type == "bearer" {
|
||||
ctx.Response.Header.Set("WWW-Authenticate", "Bearer")
|
||||
}
|
||||
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Authentication required",
|
||||
})
|
||||
return
|
||||
}
|
||||
h.authSuccesses.Add(1)
|
||||
}
|
||||
|
||||
switch path {
|
||||
case h.streamPath:
|
||||
h.handleStream(ctx)
|
||||
case h.statusPath:
|
||||
h.handleStatus(ctx)
|
||||
h.handleStream(ctx, session)
|
||||
default:
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetContentType("application/json")
|
||||
@ -292,7 +385,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) {
|
||||
func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) {
|
||||
// Track connection for net limiting
|
||||
remoteAddr := ctx.RemoteAddr().String()
|
||||
if h.netLimiter != nil {
|
||||
@ -330,7 +423,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) {
|
||||
case <-h.done:
|
||||
return
|
||||
default:
|
||||
// Drop if client buffer full, may flood logging for slow client
|
||||
// Drop if client buffer full
|
||||
h.logger.Debug("msg", "Dropped entry for slow client",
|
||||
"component", "http_sink",
|
||||
"remote_addr", remoteAddr)
|
||||
@ -348,6 +441,8 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) {
|
||||
newCount := h.activeClients.Add(1)
|
||||
h.logger.Debug("msg", "HTTP client connected",
|
||||
"remote_addr", remoteAddr,
|
||||
"username", session.Username,
|
||||
"auth_method", session.Method,
|
||||
"active_clients", newCount)
|
||||
|
||||
h.wg.Add(1)
|
||||
@ -356,6 +451,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) {
|
||||
newCount := h.activeClients.Add(-1)
|
||||
h.logger.Debug("msg", "HTTP client disconnected",
|
||||
"remote_addr", remoteAddr,
|
||||
"username", session.Username,
|
||||
"active_clients", newCount)
|
||||
h.wg.Done()
|
||||
}()
|
||||
@ -364,12 +460,15 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) {
|
||||
clientID := fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
connectionInfo := map[string]any{
|
||||
"client_id": clientID,
|
||||
"username": session.Username,
|
||||
"auth_method": session.Method,
|
||||
"stream_path": h.streamPath,
|
||||
"status_path": h.statusPath,
|
||||
"buffer_size": h.config.BufferSize,
|
||||
"tls": h.tlsManager != nil,
|
||||
}
|
||||
data, _ := json.Marshal(connectionInfo)
|
||||
fmt.Fprintf(w, "event: connected\ndata: %s\n", data)
|
||||
fmt.Fprintf(w, "event: connected\ndata: %s\n\n", data)
|
||||
w.Flush()
|
||||
|
||||
var ticker *time.Ticker
|
||||
@ -402,6 +501,13 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
|
||||
case <-tickerChan:
|
||||
// Validate session is still active
|
||||
if h.authenticator != nil && !h.authenticator.ValidateSession(session.ID) {
|
||||
fmt.Fprintf(w, "event: disconnect\ndata: {\"reason\":\"session_expired\"}\n\n")
|
||||
w.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
heartbeatEntry := h.createHeartbeatEntry()
|
||||
if err := h.formatEntryForSSE(w, heartbeatEntry); err != nil {
|
||||
h.logger.Error("msg", "Failed to format heartbeat",
|
||||
@ -437,8 +543,10 @@ func (h *HTTPSink) formatEntryForSSE(w *bufio.Writer, entry core.LogEntry) error
|
||||
lines := bytes.Split(formatted, []byte{'\n'})
|
||||
for _, line := range lines {
|
||||
// SSE needs "data: " prefix for each line
|
||||
// TODO: validate above, is 'data: ' really necessary? make it optional if it works without it?
|
||||
fmt.Fprintf(w, "data: %s\n", line)
|
||||
}
|
||||
fmt.Fprintf(w, "\n") // Empty line to terminate event
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -478,6 +586,26 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
var authStats any
|
||||
if h.authenticator != nil {
|
||||
authStats = h.authenticator.GetStats()
|
||||
authStats.(map[string]any)["failures"] = h.authFailures.Load()
|
||||
authStats.(map[string]any)["successes"] = h.authSuccesses.Load()
|
||||
} else {
|
||||
authStats = map[string]any{
|
||||
"enabled": false,
|
||||
}
|
||||
}
|
||||
|
||||
var tlsStats any
|
||||
if h.tlsManager != nil {
|
||||
tlsStats = h.tlsManager.GetStats()
|
||||
} else {
|
||||
tlsStats = map[string]any{
|
||||
"enabled": false,
|
||||
}
|
||||
}
|
||||
|
||||
status := map[string]any{
|
||||
"service": "LogWisp",
|
||||
"version": version.Short(),
|
||||
@ -487,7 +615,6 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
|
||||
"active_clients": h.activeClients.Load(),
|
||||
"buffer_size": h.config.BufferSize,
|
||||
"uptime_seconds": int(time.Since(h.startTime).Seconds()),
|
||||
"mode": map[string]bool{"standalone": h.standalone, "router": !h.standalone},
|
||||
},
|
||||
"endpoints": map[string]string{
|
||||
"transport": h.streamPath,
|
||||
@ -499,11 +626,15 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
|
||||
"interval": h.config.Heartbeat.IntervalSeconds,
|
||||
"format": h.config.Heartbeat.Format,
|
||||
},
|
||||
"ssl": map[string]bool{
|
||||
"enabled": h.config.SSL != nil && h.config.SSL.Enabled,
|
||||
},
|
||||
"tls": tlsStats,
|
||||
"auth": authStats,
|
||||
"net_limit": netLimitStats,
|
||||
},
|
||||
"statistics": map[string]any{
|
||||
"total_processed": h.totalProcessed.Load(),
|
||||
"auth_failures": h.authFailures.Load(),
|
||||
"auth_successes": h.authSuccesses.Load(),
|
||||
},
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(status)
|
||||
@ -523,4 +654,34 @@ func (h *HTTPSink) GetStreamPath() string {
|
||||
// GetStatusPath returns the configured status endpoint path
|
||||
func (h *HTTPSink) GetStatusPath() string {
|
||||
return h.statusPath
|
||||
}
|
||||
|
||||
func (h *HTTPSink) SetNetAccessConfig(cfg *config.NetAccessConfig) {
|
||||
h.ipChecker = limit.NewIPChecker(cfg, h.logger)
|
||||
if h.ipChecker != nil {
|
||||
h.logger.Info("msg", "IP access control configured for HTTP sink",
|
||||
"component", "http_sink")
|
||||
}
|
||||
}
|
||||
|
||||
// SetAuthConfig configures http sink authentication
|
||||
func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type == "none" {
|
||||
return
|
||||
}
|
||||
|
||||
h.authConfig = authCfg
|
||||
authenticator, err := auth.New(authCfg, h.logger)
|
||||
if err != nil {
|
||||
h.logger.Error("msg", "Failed to initialize authenticator for HTTP sink",
|
||||
"component", "http_sink",
|
||||
"error", err)
|
||||
// Continue without auth
|
||||
return
|
||||
}
|
||||
h.authenticator = authenticator
|
||||
|
||||
h.logger.Info("msg", "Authentication configured for HTTP sink",
|
||||
"component", "http_sink",
|
||||
"auth_type", authCfg.Type)
|
||||
}
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
)
|
||||
|
||||
@ -31,4 +32,14 @@ type SinkStats struct {
|
||||
StartTime time.Time
|
||||
LastProcessed time.Time
|
||||
Details map[string]any
|
||||
}
|
||||
|
||||
// NetAccessSetter is an interface for sinks that can accept network access configuration
|
||||
type NetAccessSetter interface {
|
||||
SetNetAccessConfig(cfg *config.NetAccessConfig)
|
||||
}
|
||||
|
||||
// AuthSetter is an interface for sinks that can accept an AuthConfig.
|
||||
type AuthSetter interface {
|
||||
SetAuthConfig(auth *config.AuthConfig)
|
||||
}
|
||||
@ -2,18 +2,22 @@
|
||||
package sink
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/tls"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/lixenwraith/log/compat"
|
||||
@ -32,12 +36,20 @@ type TCPSink struct {
|
||||
engineMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
netLimiter *limit.NetLimiter
|
||||
ipChecker *limit.IPChecker
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
|
||||
// Security components
|
||||
authenticator *auth.Authenticator
|
||||
tlsManager *tls.Manager
|
||||
authConfig *config.AuthConfig
|
||||
|
||||
// Statistics
|
||||
totalProcessed atomic.Uint64
|
||||
lastProcessed atomic.Value // time.Time
|
||||
authFailures atomic.Uint64
|
||||
authSuccesses atomic.Uint64
|
||||
}
|
||||
|
||||
// TCPConfig holds TCP sink configuration
|
||||
@ -78,6 +90,32 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
|
||||
}
|
||||
}
|
||||
|
||||
// Extract SSL config
|
||||
if ssl, ok := options["ssl"].(map[string]any); ok {
|
||||
cfg.SSL = &config.SSLConfig{}
|
||||
cfg.SSL.Enabled, _ = ssl["enabled"].(bool)
|
||||
if certFile, ok := ssl["cert_file"].(string); ok {
|
||||
cfg.SSL.CertFile = certFile
|
||||
}
|
||||
if keyFile, ok := ssl["key_file"].(string); ok {
|
||||
cfg.SSL.KeyFile = keyFile
|
||||
}
|
||||
cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool)
|
||||
if caFile, ok := ssl["client_ca_file"].(string); ok {
|
||||
cfg.SSL.ClientCAFile = caFile
|
||||
}
|
||||
cfg.SSL.VerifyClientCert, _ = ssl["verify_client_cert"].(bool)
|
||||
if minVer, ok := ssl["min_version"].(string); ok {
|
||||
cfg.SSL.MinVersion = minVer
|
||||
}
|
||||
if maxVer, ok := ssl["max_version"].(string); ok {
|
||||
cfg.SSL.MaxVersion = maxVer
|
||||
}
|
||||
if ciphers, ok := ssl["cipher_suites"].(string); ok {
|
||||
cfg.SSL.CipherSuites = ciphers
|
||||
}
|
||||
}
|
||||
|
||||
// Extract net limit config
|
||||
if rl, ok := options["net_limit"].(map[string]any); ok {
|
||||
cfg.NetLimit = &config.NetLimitConfig{}
|
||||
@ -115,6 +153,7 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
|
||||
}
|
||||
t.lastProcessed.Store(time.Time{})
|
||||
|
||||
// Initialize net limiter
|
||||
if cfg.NetLimit != nil && cfg.NetLimit.Enabled {
|
||||
t.netLimiter = limit.NewNetLimiter(*cfg.NetLimit, logger)
|
||||
}
|
||||
@ -127,7 +166,10 @@ func (t *TCPSink) Input() chan<- core.LogEntry {
|
||||
}
|
||||
|
||||
func (t *TCPSink) Start(ctx context.Context) error {
|
||||
t.server = &tcpServer{sink: t}
|
||||
t.server = &tcpServer{
|
||||
sink: t,
|
||||
clients: make(map[gnet.Conn]*tcpClient),
|
||||
}
|
||||
|
||||
// Start log broadcast loop
|
||||
t.wg.Add(1)
|
||||
@ -136,24 +178,39 @@ func (t *TCPSink) Start(ctx context.Context) error {
|
||||
t.broadcastLoop(ctx)
|
||||
}()
|
||||
|
||||
// Configure gnet
|
||||
// Configure gnet options
|
||||
addr := fmt.Sprintf("tcp://:%d", t.config.Port)
|
||||
|
||||
// Create a gnet adapter using the existing logger instance
|
||||
gnetLogger := compat.NewGnetAdapter(t.logger)
|
||||
|
||||
var opts []gnet.Option
|
||||
opts = append(opts,
|
||||
gnet.WithLogger(gnetLogger),
|
||||
gnet.WithMulticore(true),
|
||||
gnet.WithReusePort(true),
|
||||
)
|
||||
|
||||
// Add TLS if configured
|
||||
if t.tlsManager != nil {
|
||||
// tlsConfig := t.tlsManager.GetTCPConfig()
|
||||
// TODO: tlsConfig is not used, wrapper to be implemented, non-TLS stream to be available without wrapper
|
||||
// ☢ SECURITY: gnet doesn't support TLS natively - would need wrapper
|
||||
// This is a limitation that requires implementing TLS at application layer
|
||||
t.logger.Warn("msg", "TLS configured but gnet doesn't support native TLS",
|
||||
"component", "tcp_sink",
|
||||
"workaround", "Use stunnel or nginx TCP proxy for TLS termination")
|
||||
}
|
||||
|
||||
// Start gnet server
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
t.logger.Info("msg", "Starting TCP server",
|
||||
"component", "tcp_sink",
|
||||
"port", t.config.Port)
|
||||
"port", t.config.Port,
|
||||
"auth", t.authenticator != nil)
|
||||
|
||||
err := gnet.Run(t.server, addr,
|
||||
gnet.WithLogger(gnetLogger),
|
||||
gnet.WithMulticore(true),
|
||||
gnet.WithReusePort(true),
|
||||
)
|
||||
err := gnet.Run(t.server, addr, opts...)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "TCP server failed",
|
||||
"component", "tcp_sink",
|
||||
@ -219,6 +276,18 @@ func (t *TCPSink) GetStats() SinkStats {
|
||||
netLimitStats = t.netLimiter.GetStats()
|
||||
}
|
||||
|
||||
var authStats map[string]any
|
||||
if t.authenticator != nil {
|
||||
authStats = t.authenticator.GetStats()
|
||||
authStats["failures"] = t.authFailures.Load()
|
||||
authStats["successes"] = t.authSuccesses.Load()
|
||||
}
|
||||
|
||||
var tlsStats map[string]any
|
||||
if t.tlsManager != nil {
|
||||
tlsStats = t.tlsManager.GetStats()
|
||||
}
|
||||
|
||||
return SinkStats{
|
||||
Type: "tcp",
|
||||
TotalProcessed: t.totalProcessed.Load(),
|
||||
@ -229,6 +298,8 @@ func (t *TCPSink) GetStats() SinkStats {
|
||||
"port": t.config.Port,
|
||||
"buffer_size": t.config.BufferSize,
|
||||
"net_limit": netLimitStats,
|
||||
"auth": authStats,
|
||||
"tls": tlsStats,
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -263,11 +334,14 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
t.server.connections.Range(func(key, value any) bool {
|
||||
conn := key.(gnet.Conn)
|
||||
conn.AsyncWrite(data, nil)
|
||||
return true
|
||||
})
|
||||
// Broadcast only to authenticated clients
|
||||
t.server.mu.RLock()
|
||||
for conn, client := range t.server.clients {
|
||||
if client.authenticated {
|
||||
conn.AsyncWrite(data, nil)
|
||||
}
|
||||
}
|
||||
t.server.mu.RUnlock()
|
||||
|
||||
case <-tickerChan:
|
||||
heartbeatEntry := t.createHeartbeatEntry()
|
||||
@ -279,11 +353,21 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
t.server.connections.Range(func(key, value any) bool {
|
||||
conn := key.(gnet.Conn)
|
||||
conn.AsyncWrite(data, nil)
|
||||
return true
|
||||
})
|
||||
t.server.mu.RLock()
|
||||
for conn, client := range t.server.clients {
|
||||
if client.authenticated {
|
||||
// Validate session is still active
|
||||
if t.authenticator != nil && client.session != nil {
|
||||
if !t.authenticator.ValidateSession(client.session.ID) {
|
||||
// Session expired, close connection
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
}
|
||||
conn.AsyncWrite(data, nil)
|
||||
}
|
||||
}
|
||||
t.server.mu.RUnlock()
|
||||
|
||||
case <-t.done:
|
||||
return
|
||||
@ -320,11 +404,21 @@ func (t *TCPSink) GetActiveConnections() int64 {
|
||||
return t.activeConns.Load()
|
||||
}
|
||||
|
||||
// tcpServer handles gnet events
|
||||
// tcpClient represents a connected TCP client with auth state
|
||||
type tcpClient struct {
|
||||
conn gnet.Conn
|
||||
buffer bytes.Buffer
|
||||
authenticated bool
|
||||
session *auth.Session
|
||||
authTimeout time.Time
|
||||
}
|
||||
|
||||
// tcpServer handles gnet events with authentication
|
||||
type tcpServer struct {
|
||||
gnet.BuiltinEventEngine
|
||||
sink *TCPSink
|
||||
connections sync.Map
|
||||
sink *TCPSink
|
||||
clients map[gnet.Conn]*tcpClient
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action {
|
||||
@ -343,9 +437,17 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
remoteAddr := c.RemoteAddr().String()
|
||||
s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr)
|
||||
|
||||
// Check IP access control first
|
||||
if s.sink.ipChecker != nil {
|
||||
if !s.sink.ipChecker.IsAllowed(c.RemoteAddr()) {
|
||||
s.sink.logger.Warn("msg", "TCP connection denied by IP filter",
|
||||
"remote_addr", remoteAddr)
|
||||
return nil, gnet.Close
|
||||
}
|
||||
}
|
||||
|
||||
// Check net limit
|
||||
if s.sink.netLimiter != nil {
|
||||
// Parse the remote address to get proper net.Addr
|
||||
remoteStr := c.RemoteAddr().String()
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", remoteStr)
|
||||
if err != nil {
|
||||
@ -358,7 +460,6 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
if !s.sink.netLimiter.CheckTCP(tcpAddr) {
|
||||
s.sink.logger.Warn("msg", "TCP connection net limited",
|
||||
"remote_addr", remoteAddr)
|
||||
// Silently close connection when net limited
|
||||
return nil, gnet.Close
|
||||
}
|
||||
|
||||
@ -366,24 +467,43 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
s.sink.netLimiter.AddConnection(remoteStr)
|
||||
}
|
||||
|
||||
s.connections.Store(c, struct{}{})
|
||||
// Create client state
|
||||
client := &tcpClient{
|
||||
conn: c,
|
||||
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
|
||||
authTimeout: time.Now().Add(30 * time.Second), // 30s to authenticate
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.clients[c] = client
|
||||
s.mu.Unlock()
|
||||
|
||||
newCount := s.sink.activeConns.Add(1)
|
||||
s.sink.logger.Debug("msg", "TCP connection opened",
|
||||
"remote_addr", remoteAddr,
|
||||
"active_connections", newCount)
|
||||
"active_connections", newCount,
|
||||
"requires_auth", s.sink.authenticator != nil)
|
||||
|
||||
// Send auth prompt if authentication is required
|
||||
if s.sink.authenticator != nil {
|
||||
authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH <method> <credentials>\nMethods: basic, token\n")
|
||||
return authPrompt, gnet.None
|
||||
}
|
||||
|
||||
return nil, gnet.None
|
||||
}
|
||||
|
||||
func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
s.connections.Delete(c)
|
||||
|
||||
remoteAddr := c.RemoteAddr().String()
|
||||
|
||||
// Remove client state
|
||||
s.mu.Lock()
|
||||
delete(s.clients, c)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Remove connection tracking
|
||||
if s.sink.netLimiter != nil {
|
||||
s.sink.netLimiter.RemoveConnection(c.RemoteAddr().String())
|
||||
s.sink.netLimiter.RemoveConnection(remoteAddr)
|
||||
}
|
||||
|
||||
newCount := s.sink.activeConns.Add(-1)
|
||||
@ -395,7 +515,114 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
}
|
||||
|
||||
func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
// We don't expect input from clients, just discard
|
||||
s.mu.RLock()
|
||||
client, exists := s.clients[c]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Check auth timeout
|
||||
if !client.authenticated && time.Now().After(client.authTimeout) {
|
||||
s.sink.logger.Warn("msg", "Authentication timeout",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Read all available data
|
||||
data, err := c.Next(-1)
|
||||
if err != nil {
|
||||
s.sink.logger.Error("msg", "Error reading from connection",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// If not authenticated, expect auth command
|
||||
if !client.authenticated {
|
||||
client.buffer.Write(data)
|
||||
|
||||
// Look for complete auth line
|
||||
if line, err := client.buffer.ReadBytes('\n'); err == nil {
|
||||
line = bytes.TrimSpace(line)
|
||||
|
||||
// Parse AUTH command: AUTH <method> <credentials>
|
||||
parts := strings.SplitN(string(line), " ", 3)
|
||||
if len(parts) != 3 || parts[0] != "AUTH" {
|
||||
c.AsyncWrite([]byte("ERROR: Invalid auth format\n"), nil)
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// Authenticate
|
||||
session, err := s.sink.authenticator.AuthenticateTCP(parts[1], parts[2], c.RemoteAddr().String())
|
||||
if err != nil {
|
||||
s.sink.authFailures.Add(1)
|
||||
s.sink.logger.Warn("msg", "TCP authentication failed",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"method", parts[1],
|
||||
"error", err)
|
||||
c.AsyncWrite([]byte(fmt.Sprintf("AUTH FAILED: %v\n", err)), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Authentication successful
|
||||
s.sink.authSuccesses.Add(1)
|
||||
s.mu.Lock()
|
||||
client.authenticated = true
|
||||
client.session = session
|
||||
s.mu.Unlock()
|
||||
|
||||
s.sink.logger.Info("msg", "TCP client authenticated",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"username", session.Username,
|
||||
"method", session.Method)
|
||||
|
||||
c.AsyncWrite([]byte("AUTH OK\n"), nil)
|
||||
|
||||
// Clear buffer after auth
|
||||
client.buffer.Reset()
|
||||
}
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// Authenticated clients shouldn't send data, just discard
|
||||
c.Discard(-1)
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// SetAuthConfig configures tcp sink authentication
|
||||
func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type == "none" {
|
||||
return
|
||||
}
|
||||
|
||||
t.authConfig = authCfg
|
||||
authenticator, err := auth.New(authCfg, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "Failed to initialize authenticator for TCP sink",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
t.authenticator = authenticator
|
||||
|
||||
// Initialize TLS manager if SSL is configured
|
||||
if t.config.SSL != nil && t.config.SSL.Enabled {
|
||||
tlsManager, err := tls.New(t.config.SSL, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "Failed to create TLS manager",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
// Continue without TLS
|
||||
return
|
||||
}
|
||||
t.tlsManager = tlsManager
|
||||
}
|
||||
|
||||
t.logger.Info("msg", "Authentication configured for TCP sink",
|
||||
"component", "tcp_sink",
|
||||
"auth_type", authCfg.Type,
|
||||
"tls_enabled", t.tlsManager != nil)
|
||||
}
|
||||
Reference in New Issue
Block a user