v0.4.1 authentication impelemented, not tested and docs not updated
This commit is contained in:
@ -36,7 +36,6 @@ type TCPSink struct {
|
||||
engineMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
netLimiter *limit.NetLimiter
|
||||
ipChecker *limit.IPChecker
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
|
||||
@ -50,6 +49,11 @@ type TCPSink struct {
|
||||
lastProcessed atomic.Value // time.Time
|
||||
authFailures atomic.Uint64
|
||||
authSuccesses atomic.Uint64
|
||||
|
||||
// Write error tracking
|
||||
writeErrors atomic.Uint64
|
||||
consecutiveWriteErrors map[gnet.Conn]int
|
||||
errorMu sync.Mutex
|
||||
}
|
||||
|
||||
// TCPConfig holds TCP sink configuration
|
||||
@ -141,6 +145,22 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
|
||||
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
||||
cfg.NetLimit.MaxTotalConnections = maxTotal
|
||||
}
|
||||
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
||||
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
||||
for _, entry := range ipWhitelist {
|
||||
if str, ok := entry.(string); ok {
|
||||
cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
||||
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
||||
for _, entry := range ipBlacklist {
|
||||
if str, ok := entry.(string); ok {
|
||||
cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t := &TCPSink{
|
||||
@ -191,17 +211,6 @@ func (t *TCPSink) Start(ctx context.Context) error {
|
||||
gnet.WithReusePort(true),
|
||||
)
|
||||
|
||||
// Add TLS if configured
|
||||
if t.tlsManager != nil {
|
||||
// tlsConfig := t.tlsManager.GetTCPConfig()
|
||||
// TODO: tlsConfig is not used, wrapper to be implemented, non-TLS stream to be available without wrapper
|
||||
// ☢ SECURITY: gnet doesn't support TLS natively - would need wrapper
|
||||
// This is a limitation that requires implementing TLS at application layer
|
||||
t.logger.Warn("msg", "TLS configured but gnet doesn't support native TLS",
|
||||
"component", "tcp_sink",
|
||||
"workaround", "Use stunnel or nginx TCP proxy for TLS termination")
|
||||
}
|
||||
|
||||
// Start gnet server
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
@ -338,7 +347,29 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
t.server.mu.RLock()
|
||||
for conn, client := range t.server.clients {
|
||||
if client.authenticated {
|
||||
conn.AsyncWrite(data, nil)
|
||||
// Send through TLS bridge if present
|
||||
if client.tlsBridge != nil {
|
||||
if _, err := client.tlsBridge.Write(data); err != nil {
|
||||
// TLS write failed, connection likely dead
|
||||
t.logger.Debug("msg", "TLS write failed",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
conn.Close()
|
||||
}
|
||||
} else {
|
||||
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||
if err != nil {
|
||||
t.writeErrors.Add(1)
|
||||
t.handleWriteError(c, err)
|
||||
} else {
|
||||
// Reset consecutive error count on success
|
||||
t.errorMu.Lock()
|
||||
delete(t.consecutiveWriteErrors, c)
|
||||
t.errorMu.Unlock()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
t.server.mu.RUnlock()
|
||||
@ -364,7 +395,22 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
conn.AsyncWrite(data, nil)
|
||||
if client.tlsBridge != nil {
|
||||
if _, err := client.tlsBridge.Write(data); err != nil {
|
||||
t.logger.Debug("msg", "TLS heartbeat write failed",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
conn.Close()
|
||||
}
|
||||
} else {
|
||||
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||
if err != nil {
|
||||
t.writeErrors.Add(1)
|
||||
t.handleWriteError(c, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
t.server.mu.RUnlock()
|
||||
@ -375,6 +421,36 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle write errors with threshold-based connection termination
|
||||
func (t *TCPSink) handleWriteError(c gnet.Conn, err error) {
|
||||
t.errorMu.Lock()
|
||||
defer t.errorMu.Unlock()
|
||||
|
||||
// Track consecutive errors per connection
|
||||
if t.consecutiveWriteErrors == nil {
|
||||
t.consecutiveWriteErrors = make(map[gnet.Conn]int)
|
||||
}
|
||||
|
||||
t.consecutiveWriteErrors[c]++
|
||||
errorCount := t.consecutiveWriteErrors[c]
|
||||
|
||||
t.logger.Debug("msg", "AsyncWrite error",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr(),
|
||||
"error", err,
|
||||
"consecutive_errors", errorCount)
|
||||
|
||||
// Close connection after 3 consecutive write errors
|
||||
if errorCount >= 3 {
|
||||
t.logger.Warn("msg", "Closing connection due to repeated write errors",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr(),
|
||||
"error_count", errorCount)
|
||||
delete(t.consecutiveWriteErrors, c)
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Create heartbeat as a proper LogEntry
|
||||
func (t *TCPSink) createHeartbeatEntry() core.LogEntry {
|
||||
message := "heartbeat"
|
||||
@ -406,11 +482,13 @@ func (t *TCPSink) GetActiveConnections() int64 {
|
||||
|
||||
// tcpClient represents a connected TCP client with auth state
|
||||
type tcpClient struct {
|
||||
conn gnet.Conn
|
||||
buffer bytes.Buffer
|
||||
authenticated bool
|
||||
session *auth.Session
|
||||
authTimeout time.Time
|
||||
conn gnet.Conn
|
||||
buffer bytes.Buffer
|
||||
authenticated bool
|
||||
session *auth.Session
|
||||
authTimeout time.Time
|
||||
tlsBridge *tls.GNetTLSConn
|
||||
authTimeoutSet bool
|
||||
}
|
||||
|
||||
// tcpServer handles gnet events with authentication
|
||||
@ -434,15 +512,13 @@ func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action {
|
||||
}
|
||||
|
||||
func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
remoteAddr := c.RemoteAddr().String()
|
||||
remoteAddr := c.RemoteAddr()
|
||||
s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr)
|
||||
|
||||
// Check IP access control first
|
||||
if s.sink.ipChecker != nil {
|
||||
if !s.sink.ipChecker.IsAllowed(c.RemoteAddr()) {
|
||||
s.sink.logger.Warn("msg", "TCP connection denied by IP filter",
|
||||
"remote_addr", remoteAddr)
|
||||
return nil, gnet.Close
|
||||
// Reject IPv6 connections immediately
|
||||
if tcpAddr, ok := remoteAddr.(*net.TCPAddr); ok {
|
||||
if tcpAddr.IP.To4() == nil {
|
||||
return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close
|
||||
}
|
||||
}
|
||||
|
||||
@ -467,11 +543,26 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
s.sink.netLimiter.AddConnection(remoteStr)
|
||||
}
|
||||
|
||||
// Create client state
|
||||
// Create client state without auth timeout initially
|
||||
client := &tcpClient{
|
||||
conn: c,
|
||||
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
|
||||
authTimeout: time.Now().Add(30 * time.Second), // 30s to authenticate
|
||||
conn: c,
|
||||
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
|
||||
authTimeoutSet: false, // Auth timeout not started yet
|
||||
}
|
||||
|
||||
// Initialize TLS bridge if enabled
|
||||
if s.sink.tlsManager != nil {
|
||||
tlsConfig := s.sink.tlsManager.GetTCPConfig()
|
||||
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
|
||||
client.tlsBridge.Handshake() // Start async handshake
|
||||
|
||||
s.sink.logger.Debug("msg", "TLS handshake initiated",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", remoteAddr)
|
||||
} else if s.sink.authenticator != nil {
|
||||
// Only set auth timeout if no TLS (plain connection)
|
||||
client.authTimeout = time.Now().Add(30 * time.Second) // TODO: configurable or non-hardcoded timer
|
||||
client.authTimeoutSet = true
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
@ -485,7 +576,7 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
"requires_auth", s.sink.authenticator != nil)
|
||||
|
||||
// Send auth prompt if authentication is required
|
||||
if s.sink.authenticator != nil {
|
||||
if s.sink.authenticator != nil && s.sink.tlsManager == nil {
|
||||
authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH <method> <credentials>\nMethods: basic, token\n")
|
||||
return authPrompt, gnet.None
|
||||
}
|
||||
@ -498,9 +589,22 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
|
||||
// Remove client state
|
||||
s.mu.Lock()
|
||||
client := s.clients[c]
|
||||
delete(s.clients, c)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Clean up TLS bridge if present
|
||||
if client != nil && client.tlsBridge != nil {
|
||||
client.tlsBridge.Close()
|
||||
s.sink.logger.Debug("msg", "TLS connection closed",
|
||||
"remote_addr", remoteAddr)
|
||||
}
|
||||
|
||||
// Clean up write error tracking
|
||||
s.sink.errorMu.Lock()
|
||||
delete(s.sink.consecutiveWriteErrors, c)
|
||||
s.sink.errorMu.Unlock()
|
||||
|
||||
// Remove connection tracking
|
||||
if s.sink.netLimiter != nil {
|
||||
s.sink.netLimiter.RemoveConnection(remoteAddr)
|
||||
@ -523,13 +627,18 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Check auth timeout
|
||||
if !client.authenticated && time.Now().After(client.authTimeout) {
|
||||
s.sink.logger.Warn("msg", "Authentication timeout",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
// // Check auth timeout
|
||||
// if !client.authenticated && time.Now().After(client.authTimeout) {
|
||||
// s.sink.logger.Warn("msg", "Authentication timeout",
|
||||
// "component", "tcp_sink",
|
||||
// "remote_addr", c.RemoteAddr().String())
|
||||
// if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() {
|
||||
// client.tlsBridge.Write([]byte("AUTH TIMEOUT\n"))
|
||||
// } else if client.tlsBridge == nil {
|
||||
// c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
||||
// }
|
||||
// return gnet.Close
|
||||
// }
|
||||
|
||||
// Read all available data
|
||||
data, err := c.Next(-1)
|
||||
@ -540,6 +649,70 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Process through TLS bridge if present
|
||||
if client.tlsBridge != nil {
|
||||
// Feed encrypted data into TLS engine
|
||||
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
|
||||
s.sink.logger.Error("msg", "TLS processing error",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"error", err)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Check if handshake is complete
|
||||
if !client.tlsBridge.IsHandshakeDone() {
|
||||
// Still handshaking, wait for more data
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// Check handshake result
|
||||
_, hsErr := client.tlsBridge.HandshakeComplete()
|
||||
if hsErr != nil {
|
||||
s.sink.logger.Error("msg", "TLS handshake failed",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"error", hsErr)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Set auth timeout only after TLS handshake completes
|
||||
if !client.authTimeoutSet && s.sink.authenticator != nil && !client.authenticated {
|
||||
client.authTimeout = time.Now().Add(30 * time.Second)
|
||||
client.authTimeoutSet = true
|
||||
s.sink.logger.Debug("msg", "Auth timeout started after TLS handshake",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
}
|
||||
|
||||
// Read decrypted plaintext
|
||||
data = client.tlsBridge.Read()
|
||||
if data == nil || len(data) == 0 {
|
||||
// No plaintext available yet
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// First data after TLS handshake - send auth prompt if needed
|
||||
if s.sink.authenticator != nil && !client.authenticated &&
|
||||
len(client.buffer.Bytes()) == 0 {
|
||||
authPrompt := []byte("AUTH REQUIRED\n")
|
||||
client.tlsBridge.Write(authPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// Only check auth timeout if it has been set
|
||||
if !client.authenticated && client.authTimeoutSet && time.Now().After(client.authTimeout) {
|
||||
s.sink.logger.Warn("msg", "Authentication timeout",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() {
|
||||
client.tlsBridge.Write([]byte("AUTH TIMEOUT\n"))
|
||||
} else if client.tlsBridge == nil {
|
||||
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
||||
}
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// If not authenticated, expect auth command
|
||||
if !client.authenticated {
|
||||
client.buffer.Write(data)
|
||||
@ -551,7 +724,13 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
// Parse AUTH command: AUTH <method> <credentials>
|
||||
parts := strings.SplitN(string(line), " ", 3)
|
||||
if len(parts) != 3 || parts[0] != "AUTH" {
|
||||
c.AsyncWrite([]byte("ERROR: Invalid auth format\n"), nil)
|
||||
// Send error through TLS if enabled
|
||||
errMsg := []byte("AUTH FAILED\n")
|
||||
if client.tlsBridge != nil {
|
||||
client.tlsBridge.Write(errMsg)
|
||||
} else {
|
||||
c.AsyncWrite(errMsg, nil)
|
||||
}
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
@ -563,7 +742,13 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"method", parts[1],
|
||||
"error", err)
|
||||
c.AsyncWrite([]byte(fmt.Sprintf("AUTH FAILED: %v\n", err)), nil)
|
||||
// Send error through TLS if enabled
|
||||
errMsg := []byte("AUTH FAILED\n")
|
||||
if client.tlsBridge != nil {
|
||||
client.tlsBridge.Write(errMsg)
|
||||
} else {
|
||||
c.AsyncWrite(errMsg, nil)
|
||||
}
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
@ -575,11 +760,19 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
s.mu.Unlock()
|
||||
|
||||
s.sink.logger.Info("msg", "TCP client authenticated",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"username", session.Username,
|
||||
"method", session.Method)
|
||||
"method", session.Method,
|
||||
"tls", client.tlsBridge != nil)
|
||||
|
||||
c.AsyncWrite([]byte("AUTH OK\n"), nil)
|
||||
// Send success through TLS if enabled
|
||||
successMsg := []byte("AUTH OK\n")
|
||||
if client.tlsBridge != nil {
|
||||
client.tlsBridge.Write(successMsg)
|
||||
} else {
|
||||
c.AsyncWrite(successMsg, nil)
|
||||
}
|
||||
|
||||
// Clear buffer after auth
|
||||
client.buffer.Reset()
|
||||
@ -610,7 +803,7 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
||||
|
||||
// Initialize TLS manager if SSL is configured
|
||||
if t.config.SSL != nil && t.config.SSL.Enabled {
|
||||
tlsManager, err := tls.New(t.config.SSL, t.logger)
|
||||
tlsManager, err := tls.NewManager(t.config.SSL, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "Failed to create TLS manager",
|
||||
"component", "tcp_sink",
|
||||
@ -624,5 +817,6 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
||||
t.logger.Info("msg", "Authentication configured for TCP sink",
|
||||
"component", "tcp_sink",
|
||||
"auth_type", authCfg.Type,
|
||||
"tls_enabled", t.tlsManager != nil)
|
||||
"tls_enabled", t.tlsManager != nil,
|
||||
"tls_bridge", t.tlsManager != nil)
|
||||
}
|
||||
Reference in New Issue
Block a user