From 45b2093569b02562a6c1385c12ff6028b33ca52b9614bdb6508976248a6379ac Mon Sep 17 00:00:00 2001 From: Lixen Wraith Date: Tue, 23 Sep 2025 12:03:42 -0400 Subject: [PATCH] v0.4.1 authentication impelemented, not tested and docs not updated --- go.mod | 1 + go.sum | 2 + src/internal/auth/authenticator.go | 289 ++++++++++++++++-- src/internal/config/auth.go | 2 - src/internal/config/limit.go | 32 -- src/internal/config/pipeline.go | 23 +- src/internal/config/server.go | 71 ++++- src/internal/config/ssl.go | 26 +- src/internal/config/validation.go | 5 - src/internal/limit/ip.go | 173 ----------- src/internal/limit/net.go | 470 +++++++++++++++++++++++------ src/internal/service/service.go | 3 - src/internal/sink/http.go | 60 ++-- src/internal/sink/http_client.go | 90 ++++-- src/internal/sink/sink.go | 5 - src/internal/sink/tcp.go | 282 ++++++++++++++--- src/internal/sink/tcp_client.go | 104 ++++++- src/internal/source/http.go | 63 +++- src/internal/source/tcp.go | 186 +++++++++++- src/internal/tls/gnet_bridge.go | 341 +++++++++++++++++++++ src/internal/tls/manager.go | 4 +- 21 files changed, 1779 insertions(+), 453 deletions(-) delete mode 100644 src/internal/limit/ip.go create mode 100644 src/internal/tls/gnet_bridge.go diff --git a/go.mod b/go.mod index bc94d64..115035b 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/valyala/fasthttp v1.65.0 golang.org/x/crypto v0.42.0 golang.org/x/term v0.35.0 + golang.org/x/time v0.13.0 ) require ( diff --git a/go.sum b/go.sum index 9de57e3..8598e9b 100644 --- a/go.sum +++ b/go.sum @@ -42,6 +42,8 @@ golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= +golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI= +golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/src/internal/auth/authenticator.go b/src/internal/auth/authenticator.go index 2e77bb6..3b4fa6b 100644 --- a/src/internal/auth/authenticator.go +++ b/src/internal/auth/authenticator.go @@ -3,8 +3,10 @@ package auth import ( "bufio" + "crypto/rand" "encoding/base64" "fmt" + "net" "os" "strings" "sync" @@ -15,8 +17,12 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/lixenwraith/log" "golang.org/x/crypto/bcrypt" + "golang.org/x/time/rate" ) +// Prevent unbounded map growth +const maxAuthTrackedIPs = 10000 + // Authenticator handles all authentication methods for a pipeline type Authenticator struct { config *config.AuthConfig @@ -30,6 +36,18 @@ type Authenticator struct { // Session tracking sessions map[string]*Session sessionMu sync.RWMutex + + // Brute-force protection + ipAuthAttempts map[string]*ipAuthState + authMu sync.RWMutex +} + +// ADDED: Per-IP auth attempt tracking +type ipAuthState struct { + limiter *rate.Limiter + failCount int + lastAttempt time.Time + blockedUntil time.Time } // Session represents an authenticated connection @@ -50,11 +68,12 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) { } a := &Authenticator{ - config: cfg, - logger: logger, - basicUsers: make(map[string]string), - bearerTokens: make(map[string]bool), - sessions: make(map[string]*Session), + config: cfg, + logger: logger, + basicUsers: make(map[string]string), + bearerTokens: make(map[string]bool), + sessions: make(map[string]*Session), + ipAuthAttempts: make(map[string]*ipAuthState), } // Initialize Basic Auth users @@ -82,6 +101,7 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) { a.jwtParser = jwt.NewParser( jwt.WithValidMethods([]string{"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}), jwt.WithLeeway(5*time.Second), + jwt.WithExpirationRequired(), ) // Setup key function @@ -102,6 +122,9 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) { // Start session cleanup go a.sessionCleanup() + // Start auth attempt cleanup + go a.authAttemptCleanup() + logger.Info("msg", "Authenticator initialized", "component", "auth", "type", cfg.Type) @@ -109,6 +132,129 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) { return a, nil } +// Check and enforce rate limits +func (a *Authenticator) checkRateLimit(remoteAddr string) error { + ip, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + ip = remoteAddr // Fallback for malformed addresses + } + + a.authMu.Lock() + defer a.authMu.Unlock() + + state, exists := a.ipAuthAttempts[ip] + now := time.Now() + + if !exists { + // Check map size limit before creating new entry + if len(a.ipAuthAttempts) >= maxAuthTrackedIPs { + // Evict an old entry using simplified LRU + // Sample 20 random entries and evict the oldest + const sampleSize = 20 + var oldestIP string + oldestTime := now + + // Build sample + sampled := 0 + for sampledIP, sampledState := range a.ipAuthAttempts { + if sampledState.lastAttempt.Before(oldestTime) { + oldestIP = sampledIP + oldestTime = sampledState.lastAttempt + } + sampled++ + if sampled >= sampleSize { + break + } + } + + // Evict the oldest from our sample + if oldestIP != "" { + delete(a.ipAuthAttempts, oldestIP) + a.logger.Debug("msg", "Evicted old auth attempt state", + "component", "auth", + "evicted_ip", oldestIP, + "last_seen", oldestTime) + } + } + + // Create new state for this IP + // 5 attempts per minute, burst of 3 + state = &ipAuthState{ + limiter: rate.NewLimiter(rate.Every(12*time.Second), 3), + lastAttempt: now, + } + a.ipAuthAttempts[ip] = state + } + + // Check if IP is temporarily blocked + if now.Before(state.blockedUntil) { + remaining := state.blockedUntil.Sub(now) + a.logger.Warn("msg", "IP temporarily blocked", + "component", "auth", + "ip", ip, + "remaining", remaining) + // Sleep to slow down even blocked attempts + time.Sleep(2 * time.Second) + return fmt.Errorf("temporarily blocked, try again in %v", remaining.Round(time.Second)) + } + + // Check rate limit + if !state.limiter.Allow() { + state.failCount++ + + // Only set new blockedUntil if not already blocked + // This prevents indefinite block extension + if state.blockedUntil.IsZero() || now.After(state.blockedUntil) { + // Progressive blocking: 2^failCount minutes + blockMinutes := 1 << min(state.failCount, 6) // Cap at 64 minutes + state.blockedUntil = now.Add(time.Duration(blockMinutes) * time.Minute) + + a.logger.Warn("msg", "Rate limit exceeded, blocking IP", + "component", "auth", + "ip", ip, + "fail_count", state.failCount, + "block_duration", time.Duration(blockMinutes)*time.Minute) + } + + return fmt.Errorf("rate limit exceeded") + } + + state.lastAttempt = now + return nil +} + +// Record failed attempt +func (a *Authenticator) recordFailure(remoteAddr string) { + ip, _, _ := net.SplitHostPort(remoteAddr) + if ip == "" { + ip = remoteAddr + } + + a.authMu.Lock() + defer a.authMu.Unlock() + + if state, exists := a.ipAuthAttempts[ip]; exists { + state.failCount++ + state.lastAttempt = time.Now() + } +} + +// Reset failure count on success +func (a *Authenticator) recordSuccess(remoteAddr string) { + ip, _, _ := net.SplitHostPort(remoteAddr) + if ip == "" { + ip = remoteAddr + } + + a.authMu.Lock() + defer a.authMu.Unlock() + + if state, exists := a.ipAuthAttempts[ip]; exists { + state.failCount = 0 + state.blockedUntil = time.Time{} + } +} + // AuthenticateHTTP handles HTTP authentication headers func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) { if a == nil || a.config.Type == "none" { @@ -120,14 +266,31 @@ func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Sessio }, nil } + // Check rate limit + if err := a.checkRateLimit(remoteAddr); err != nil { + return nil, err + } + + var session *Session + var err error + switch a.config.Type { case "basic": - return a.authenticateBasic(authHeader, remoteAddr) + session, err = a.authenticateBasic(authHeader, remoteAddr) case "bearer": - return a.authenticateBearer(authHeader, remoteAddr) + session, err = a.authenticateBearer(authHeader, remoteAddr) default: - return nil, fmt.Errorf("unsupported auth type: %s", a.config.Type) + err = fmt.Errorf("unsupported auth type: %s", a.config.Type) } + + if err != nil { + a.recordFailure(remoteAddr) + time.Sleep(500 * time.Millisecond) + return nil, err + } + + a.recordSuccess(remoteAddr) + return session, nil } // AuthenticateTCP handles TCP connection authentication @@ -141,32 +304,54 @@ func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string) }, nil } + // Check rate limit first + if err := a.checkRateLimit(remoteAddr); err != nil { + return nil, err + } + + var session *Session + var err error + // TCP auth protocol: AUTH switch strings.ToLower(method) { case "token": if a.config.Type != "bearer" { - return nil, fmt.Errorf("token auth not configured") + err = fmt.Errorf("token auth not configured") + } else { + session, err = a.validateToken(credentials, remoteAddr) } - return a.validateToken(credentials, remoteAddr) case "basic": if a.config.Type != "basic" { - return nil, fmt.Errorf("basic auth not configured") + err = fmt.Errorf("basic auth not configured") + } else { + // Expect base64(username:password) + decoded, decErr := base64.StdEncoding.DecodeString(credentials) + if decErr != nil { + err = fmt.Errorf("invalid credentials encoding") + } else { + parts := strings.SplitN(string(decoded), ":", 2) + if len(parts) != 2 { + err = fmt.Errorf("invalid credentials format") + } else { + session, err = a.validateBasicAuth(parts[0], parts[1], remoteAddr) + } + } } - // Expect base64(username:password) - decoded, err := base64.StdEncoding.DecodeString(credentials) - if err != nil { - return nil, fmt.Errorf("invalid credentials encoding") - } - parts := strings.SplitN(string(decoded), ":", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid credentials format") - } - return a.validateBasicAuth(parts[0], parts[1], remoteAddr) default: - return nil, fmt.Errorf("unsupported auth method: %s", method) + err = fmt.Errorf("unsupported auth method: %s", method) } + + if err != nil { + a.recordFailure(remoteAddr) + // Add delay on failure + time.Sleep(500 * time.Millisecond) + return nil, err + } + + a.recordSuccess(remoteAddr) + return session, nil } func (a *Authenticator) authenticateBasic(authHeader, remoteAddr string) (*Session, error) { @@ -255,6 +440,23 @@ func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error return nil, fmt.Errorf("invalid JWT token") } + // Explicit expiration check + if exp, ok := claims["exp"].(float64); ok { + if time.Now().Unix() > int64(exp) { + return nil, fmt.Errorf("token expired") + } + } else { + // Reject tokens without expiration + return nil, fmt.Errorf("token missing expiration claim") + } + + // Check not-before claim + if nbf, ok := claims["nbf"].(float64); ok { + if time.Now().Unix() < int64(nbf) { + return nil, fmt.Errorf("token not yet valid") + } + } + // Check issuer if configured if a.config.BearerAuth.JWT.Issuer != "" { if iss, ok := claims["iss"].(string); !ok || iss != a.config.BearerAuth.JWT.Issuer { @@ -264,7 +466,20 @@ func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error // Check audience if configured if a.config.BearerAuth.JWT.Audience != "" { - if aud, ok := claims["aud"].(string); !ok || aud != a.config.BearerAuth.JWT.Audience { + // Handle both string and []string audience formats + audValid := false + switch aud := claims["aud"].(type) { + case string: + audValid = aud == a.config.BearerAuth.JWT.Audience + case []interface{}: + for _, aa := range aud { + if audStr, ok := aa.(string); ok && audStr == a.config.BearerAuth.JWT.Audience { + audValid = true + break + } + } + } + if !audValid { return nil, fmt.Errorf("invalid token audience") } } @@ -322,6 +537,27 @@ func (a *Authenticator) sessionCleanup() { } } +// Cleanup old auth attempts +func (a *Authenticator) authAttemptCleanup() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + a.authMu.Lock() + now := time.Now() + for ip, state := range a.ipAuthAttempts { + // Remove entries older than 1 hour with no recent activity + if now.Sub(state.lastAttempt) > time.Hour { + delete(a.ipAuthAttempts, ip) + a.logger.Debug("msg", "Cleaned up auth attempt state", + "component", "auth", + "ip", ip) + } + } + a.authMu.Unlock() + } +} + func (a *Authenticator) loadUsersFile(path string) error { file, err := os.Open(path) if err != nil { @@ -366,7 +602,12 @@ func (a *Authenticator) loadUsersFile(path string) error { } func generateSessionID() string { - return fmt.Sprintf("%d-%d", time.Now().UnixNano(), time.Now().Unix()) + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + // Fallback to a less secure method if crypto/rand fails + return fmt.Sprintf("fallback-%d", time.Now().UnixNano()) + } + return base64.URLEncoding.EncodeToString(b) } // ValidateSession checks if a session is still valid diff --git a/src/internal/config/auth.go b/src/internal/config/auth.go index 84f7544..012e68b 100644 --- a/src/internal/config/auth.go +++ b/src/internal/config/auth.go @@ -3,8 +3,6 @@ package config import ( "fmt" - "net" - "strings" ) type AuthConfig struct { diff --git a/src/internal/config/limit.go b/src/internal/config/limit.go index 05c8654..8507031 100644 --- a/src/internal/config/limit.go +++ b/src/internal/config/limit.go @@ -3,7 +3,6 @@ package config import ( "fmt" - "net" "strings" ) @@ -29,37 +28,6 @@ type RateLimitConfig struct { MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"` } -func validateNetAccess(pipelineName string, cfg *NetAccessConfig) error { - if cfg == nil { - return nil - } - - // Validate CIDR notation - for _, cidr := range cfg.IPWhitelist { - if !strings.Contains(cidr, "/") { - cidr = cidr + "/32" - } - if _, _, err := net.ParseCIDR(cidr); err != nil { - if net.ParseIP(cidr) == nil { - return fmt.Errorf("pipeline '%s': invalid IP whitelist entry: %s", pipelineName, cidr) - } - } - } - - for _, cidr := range cfg.IPBlacklist { - if !strings.Contains(cidr, "/") { - cidr = cidr + "/32" - } - if _, _, err := net.ParseCIDR(cidr); err != nil { - if net.ParseIP(cidr) == nil { - return fmt.Errorf("pipeline '%s': invalid IP blacklist entry: %s", pipelineName, cidr) - } - } - } - - return nil -} - func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error { if cfg == nil { return nil diff --git a/src/internal/config/pipeline.go b/src/internal/config/pipeline.go index 9638838..6b26373 100644 --- a/src/internal/config/pipeline.go +++ b/src/internal/config/pipeline.go @@ -20,9 +20,6 @@ type PipelineConfig struct { // Rate limiting RateLimit *RateLimitConfig `toml:"rate_limit"` - // Network access control (IP filtering) - NetAccess *NetAccessConfig `toml:"net_access"` - // Filter configuration Filters []FilterConfig `toml:"filters"` @@ -37,12 +34,6 @@ type PipelineConfig struct { Auth *AuthConfig `toml:"auth"` } -// NetAccessConfig defines IP-based access control lists -type NetAccessConfig struct { - IPWhitelist []string `toml:"ip_whitelist"` - IPBlacklist []string `toml:"ip_blacklist"` -} - // SourceConfig represents an input data source type SourceConfig struct { // Source type: "directory", "file", "stdin", etc. @@ -132,6 +123,13 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err } } + // CHANGED: Validate SSL if present + if ssl, ok := cfg.Options["ssl"].(map[string]any); ok { + if err := validateSSLOptions("HTTP source", pipelineName, sourceIndex, ssl); err != nil { + return err + } + } + case "tcp": // Validate TCP source options port, ok := cfg.Options["port"].(int64) @@ -147,6 +145,13 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err } } + // CHANGED: Validate SSL if present + if ssl, ok := cfg.Options["ssl"].(map[string]any); ok { + if err := validateSSLOptions("TCP source", pipelineName, sourceIndex, ssl); err != nil { + return err + } + } + default: return fmt.Errorf("pipeline '%s' source[%d]: unknown source type '%s'", pipelineName, sourceIndex, cfg.Type) diff --git a/src/internal/config/server.go b/src/internal/config/server.go index eb16c9e..fc9e9d0 100644 --- a/src/internal/config/server.go +++ b/src/internal/config/server.go @@ -1,7 +1,11 @@ // FILE: logwisp/src/internal/config/server.go package config -import "fmt" +import ( + "fmt" + "net" + "strings" +) type TCPConfig struct { Enabled bool `toml:"enabled"` @@ -49,6 +53,10 @@ type NetLimitConfig struct { // Enable net limiting Enabled bool `toml:"enabled"` + // IP Access Control Lists + IPWhitelist []string `toml:"ip_whitelist"` + IPBlacklist []string `toml:"ip_blacklist"` + // Requests per second per client RequestsPerSecond float64 `toml:"requests_per_second"` @@ -90,6 +98,33 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl return nil } + // Validate IP lists if present + if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok { + for i, entry := range ipWhitelist { + entryStr, ok := entry.(string) + if !ok { + continue + } + if err := validateIPv4Entry(entryStr); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d] %s: whitelist[%d] %v", + pipelineName, sinkIndex, serverType, i, err) + } + } + } + + if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok { + for i, entry := range ipBlacklist { + entryStr, ok := entry.(string) + if !ok { + continue + } + if err := validateIPv4Entry(entryStr); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d] %s: blacklist[%d] %v", + pipelineName, sinkIndex, serverType, i, err) + } + } + } + // Validate requests per second rps, ok := rl["requests_per_second"].(float64) if !ok || rps <= 0 { @@ -132,5 +167,39 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl } } + return nil +} + +// validateIPv4Entry ensures an IP or CIDR is IPv4 +func validateIPv4Entry(entry string) error { + // Handle single IP + if !strings.Contains(entry, "/") { + ip := net.ParseIP(entry) + if ip == nil { + return fmt.Errorf("invalid IP address: %s", entry) + } + if ip.To4() == nil { + return fmt.Errorf("IPv6 not supported (IPv4-only): %s", entry) + } + return nil + } + + // Handle CIDR + ipAddr, ipNet, err := net.ParseCIDR(entry) + if err != nil { + return fmt.Errorf("invalid CIDR: %s", entry) + } + + // Check if the IP is IPv4 + if ipAddr.To4() == nil { + return fmt.Errorf("IPv6 CIDR not supported (IPv4-only): %s", entry) + } + + // Verify the network mask is appropriate for IPv4 + _, bits := ipNet.Mask.Size() + if bits != 32 { + return fmt.Errorf("invalid IPv4 CIDR mask (got %d bits, expected 32): %s", bits, entry) + } + return nil } \ No newline at end of file diff --git a/src/internal/config/ssl.go b/src/internal/config/ssl.go index cfd1b0d..5bdc61d 100644 --- a/src/internal/config/ssl.go +++ b/src/internal/config/ssl.go @@ -1,7 +1,10 @@ // FILE: logwisp/src/internal/config/ssl.go package config -import "fmt" +import ( + "fmt" + "os" +) type SSLConfig struct { Enabled bool `toml:"enabled"` @@ -13,6 +16,9 @@ type SSLConfig struct { ClientCAFile string `toml:"client_ca_file"` VerifyClientCert bool `toml:"verify_client_cert"` + // Option to skip verification for clients + InsecureSkipVerify bool `toml:"insecure_skip_verify"` + // TLS version constraints MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3" MaxVersion string `toml:"max_version"` @@ -31,11 +37,27 @@ func validateSSLOptions(serverType, pipelineName string, sinkIndex int, ssl map[ pipelineName, sinkIndex, serverType) } + // Validate that certificate files exist and are readable + if _, err := os.Stat(certFile); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d] %s: cert_file is not accessible: %w", + pipelineName, sinkIndex, serverType, err) + } + if _, err := os.Stat(keyFile); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d] %s: key_file is not accessible: %w", + pipelineName, sinkIndex, serverType, err) + } + if clientAuth, ok := ssl["client_auth"].(bool); ok && clientAuth { - if caFile, ok := ssl["client_ca_file"].(string); !ok || caFile == "" { + caFile, caOk := ssl["client_ca_file"].(string) + if !caOk || caFile == "" { return fmt.Errorf("pipeline '%s' sink[%d] %s: client auth enabled but CA file not specified", pipelineName, sinkIndex, serverType) } + // Validate that the client CA file exists and is readable + if _, err := os.Stat(caFile); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d] %s: client_ca_file is not accessible: %w", + pipelineName, sinkIndex, serverType, err) + } } // Validate TLS versions diff --git a/src/internal/config/validation.go b/src/internal/config/validation.go index 010ff82..b343683 100644 --- a/src/internal/config/validation.go +++ b/src/internal/config/validation.go @@ -72,11 +72,6 @@ func (c *Config) validate() error { } } - // Validate net access if present - if err := validateNetAccess(pipeline.Name, pipeline.NetAccess); err != nil { - return err - } - // Validate auth if present if err := validateAuth(pipeline.Name, pipeline.Auth); err != nil { return err diff --git a/src/internal/limit/ip.go b/src/internal/limit/ip.go deleted file mode 100644 index b76afa0..0000000 --- a/src/internal/limit/ip.go +++ /dev/null @@ -1,173 +0,0 @@ -// FILE: src/internal/limit/ip.go -package limit - -import ( - "net" - "strings" - - "logwisp/src/internal/config" - - "github.com/lixenwraith/log" -) - -// IPChecker handles IP-based access control lists -type IPChecker struct { - ipWhitelist []*net.IPNet - ipBlacklist []*net.IPNet - logger *log.Logger -} - -// NewIPChecker creates a new IPChecker. Returns nil if no rules are defined. -func NewIPChecker(cfg *config.NetAccessConfig, logger *log.Logger) *IPChecker { - if cfg == nil || (len(cfg.IPWhitelist) == 0 && len(cfg.IPBlacklist) == 0) { - return nil - } - - c := &IPChecker{ - ipWhitelist: make([]*net.IPNet, 0), - ipBlacklist: make([]*net.IPNet, 0), - logger: logger, - } - - // Parse whitelist entries - for _, cidr := range cfg.IPWhitelist { - if !strings.Contains(cidr, "/") { - cidr = cidr + "/32" - } - - _, ipNet, err := net.ParseCIDR(cidr) - if err != nil { - // Try parsing as plain IP - if ip := net.ParseIP(cidr); ip != nil { - if ip.To4() != nil { - ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)} - } else { - ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} - } - } else { - logger.Warn("msg", "Skipping invalid IP whitelist entry", - "component", "ip_checker", - "entry", cidr, - "error", err) - continue - } - } - c.ipWhitelist = append(c.ipWhitelist, ipNet) - } - - // Parse blacklist entries - for _, cidr := range cfg.IPBlacklist { - if !strings.Contains(cidr, "/") { - cidr = cidr + "/32" - } - - _, ipNet, err := net.ParseCIDR(cidr) - if err != nil { - // Try parsing as plain IP - if ip := net.ParseIP(cidr); ip != nil { - if ip.To4() != nil { - ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)} - } else { - ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} - } - } else { - logger.Warn("msg", "Skipping invalid IP blacklist entry", - "component", "ip_checker", - "entry", cidr, - "error", err) - continue - } - } - c.ipBlacklist = append(c.ipBlacklist, ipNet) - } - - logger.Info("msg", "IP checker initialized", - "component", "ip_checker", - "whitelist_rules", len(c.ipWhitelist), - "blacklist_rules", len(c.ipBlacklist)) - - return c -} - -// IsAllowed validates if a remote address is permitted -func (c *IPChecker) IsAllowed(remoteAddr net.Addr) bool { - if c == nil { - return true // No checker = allow all - } - - // No rules = allow all - if len(c.ipWhitelist) == 0 && len(c.ipBlacklist) == 0 { - return true - } - - // Extract IP from address - var ipStr string - switch addr := remoteAddr.(type) { - case *net.TCPAddr: - ipStr = addr.IP.String() - case *net.UDPAddr: - ipStr = addr.IP.String() - default: - // Try string parsing - addrStr := remoteAddr.String() - host, _, err := net.SplitHostPort(addrStr) - if err != nil { - ipStr = addrStr - } else { - ipStr = host - } - } - - ip := net.ParseIP(ipStr) - if ip == nil { - c.logger.Warn("msg", "Could not parse remote address to IP", - "component", "ip_checker", - "remote_addr", remoteAddr.String()) - return false // Deny unparseable addresses - } - - // Check blacklist first (deny takes precedence) - for _, ipNet := range c.ipBlacklist { - if ipNet.Contains(ip) { - c.logger.Warn("msg", "Blacklisted IP denied", - "component", "ip_checker", - "ip", ipStr, - "rule", ipNet.String()) - return false - } - } - - // If whitelist is configured, IP must be in it - if len(c.ipWhitelist) > 0 { - for _, ipNet := range c.ipWhitelist { - if ipNet.Contains(ip) { - c.logger.Debug("msg", "IP allowed by whitelist", - "component", "ip_checker", - "ip", ipStr, - "rule", ipNet.String()) - return true - } - } - // No whitelist match = deny - c.logger.Warn("msg", "IP not in whitelist", - "component", "ip_checker", - "ip", ipStr) - return false - } - - // No blacklist match + no whitelist configured = allow - return true -} - -// GetStats returns IP checker statistics -func (c *IPChecker) GetStats() map[string]any { - if c == nil { - return map[string]any{"enabled": false} - } - - return map[string]any{ - "enabled": true, - "whitelist_rules": len(c.ipWhitelist), - "blacklist_rules": len(c.ipBlacklist), - } -} \ No newline at end of file diff --git a/src/internal/limit/net.go b/src/internal/limit/net.go index 067d3d0..96bd236 100644 --- a/src/internal/limit/net.go +++ b/src/internal/limit/net.go @@ -14,11 +14,32 @@ import ( "github.com/lixenwraith/log" ) +// DenialReason indicates why a request was denied +type DenialReason string + +const ( + // IPv4Only is the enforcement message for IPv6 rejection + IPv4Only = "IPv4-only (IPv6 not supported)" +) + +const ( + ReasonAllowed DenialReason = "" + ReasonBlacklisted DenialReason = "IP denied by blacklist" + ReasonNotWhitelisted DenialReason = "IP not in whitelist" + ReasonRateLimited DenialReason = "Rate limit exceeded" + ReasonConnectionLimited DenialReason = "Connection limit exceeded" + ReasonInvalidIP DenialReason = "Invalid IP address" +) + // NetLimiter manages net limiting for a transport type NetLimiter struct { config config.NetLimitConfig logger *log.Logger + // IP Access Control Lists + ipWhitelist []*net.IPNet + ipBlacklist []*net.IPNet + // Per-IP limiters ipLimiters map[string]*ipLimiter ipMu sync.RWMutex @@ -27,17 +48,22 @@ type NetLimiter struct { globalLimiter *TokenBucket // Connection tracking - ipConnections map[string]*atomic.Int64 + ipConnections map[string]*connTracker connMu sync.RWMutex // Statistics - totalRequests atomic.Uint64 - blockedRequests atomic.Uint64 - uniqueIPs atomic.Uint64 + totalRequests atomic.Uint64 + blockedByBlacklist atomic.Uint64 + blockedByWhitelist atomic.Uint64 + blockedByRateLimit atomic.Uint64 + blockedByConnLimit atomic.Uint64 + blockedByInvalidIP atomic.Uint64 + uniqueIPs atomic.Uint64 // Cleanup - lastCleanup time.Time - cleanupMu sync.Mutex + lastCleanup time.Time + cleanupMu sync.Mutex + cleanupActive atomic.Bool // Lifecycle management ctx context.Context @@ -51,9 +77,20 @@ type ipLimiter struct { connections atomic.Int64 } +// Connection tracking with activity timestamp +type connTracker struct { + connections atomic.Int64 + lastSeen time.Time + mu sync.Mutex +} + // Creates a new net limiter func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter { - if !cfg.Enabled { + // Return nil only if nothing is configured + hasACL := len(cfg.IPWhitelist) > 0 || len(cfg.IPBlacklist) > 0 + hasRateLimit := cfg.Enabled + + if !hasACL && !hasRateLimit { return nil } @@ -65,28 +102,39 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter { l := &NetLimiter{ config: cfg, - ipLimiters: make(map[string]*ipLimiter), - ipConnections: make(map[string]*atomic.Int64), - lastCleanup: time.Now(), logger: logger, + ipWhitelist: make([]*net.IPNet, 0), + ipBlacklist: make([]*net.IPNet, 0), + ipLimiters: make(map[string]*ipLimiter), + ipConnections: make(map[string]*connTracker), + lastCleanup: time.Now(), ctx: ctx, cancel: cancel, cleanupDone: make(chan struct{}), } - // Create global limiter if not using per-IP limiting - if cfg.LimitBy == "global" { + // Parse IP lists + l.parseIPLists(cfg) + + // Create global limiter if configured + if cfg.Enabled && cfg.LimitBy == "global" { l.globalLimiter = NewTokenBucket( float64(cfg.BurstSize), cfg.RequestsPerSecond, ) } - // Start cleanup goroutine - go l.cleanupLoop() + // Start cleanup goroutine only if rate limiting is enabled + if cfg.Enabled { + go l.cleanupLoop() + } - l.logger.Info("msg", "Net limiter initialized", + logger.Info("msg", "Net limiter initialized", "component", "netlimit", + "acl_enabled", hasACL, + "rate_limiting", cfg.Enabled, + "whitelist_rules", len(l.ipWhitelist), + "blacklist_rules", len(l.ipBlacklist), "requests_per_second", cfg.RequestsPerSecond, "burst_size", cfg.BurstSize, "limit_by", cfg.LimitBy) @@ -94,6 +142,120 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter { return l } +// parseIPLists parses and validates IP whitelist/blacklist +func (l *NetLimiter) parseIPLists(cfg config.NetLimitConfig) { + // Parse whitelist + for _, entry := range cfg.IPWhitelist { + if ipNet := l.parseIPEntry(entry, "whitelist"); ipNet != nil { + l.ipWhitelist = append(l.ipWhitelist, ipNet) + } + } + + // Parse blacklist + for _, entry := range cfg.IPBlacklist { + if ipNet := l.parseIPEntry(entry, "blacklist"); ipNet != nil { + l.ipBlacklist = append(l.ipBlacklist, ipNet) + } + } +} + +// parseIPEntry parses a single IP or CIDR entry +func (l *NetLimiter) parseIPEntry(entry, listType string) *net.IPNet { + // Handle single IP + if !strings.Contains(entry, "/") { + ip := net.ParseIP(entry) + if ip == nil { + l.logger.Warn("msg", "Invalid IP entry", + "component", "netlimit", + "list", listType, + "entry", entry) + return nil + } + + // Reject IPv6 + if ip.To4() == nil { + l.logger.Warn("msg", "IPv6 address rejected", + "component", "netlimit", + "list", listType, + "entry", entry, + "reason", IPv4Only) + return nil + } + + return &net.IPNet{IP: ip.To4(), Mask: net.CIDRMask(32, 32)} + } + + // Parse CIDR + ipAddr, ipNet, err := net.ParseCIDR(entry) + if err != nil { + l.logger.Warn("msg", "Invalid CIDR entry", + "component", "netlimit", + "list", listType, + "entry", entry, + "error", err) + return nil + } + + // Reject IPv6 CIDR + if ipAddr.To4() == nil { + l.logger.Warn("msg", "IPv6 CIDR rejected", + "component", "netlimit", + "list", listType, + "entry", entry, + "reason", IPv4Only) + return nil + } + + // Ensure mask is IPv4 + _, bits := ipNet.Mask.Size() + if bits != 32 { + l.logger.Warn("msg", "Non-IPv4 CIDR mask rejected", + "component", "netlimit", + "list", listType, + "entry", entry, + "mask_bits", bits, + "reason", IPv4Only) + return nil + } + + return &net.IPNet{IP: ipAddr.To4(), Mask: ipNet.Mask} +} + +// checkIPAccess checks if an IP is allowed by ACLs +func (l *NetLimiter) checkIPAccess(ip net.IP) DenialReason { + // 1. Check blacklist first (deny takes precedence) + for _, ipNet := range l.ipBlacklist { + if ipNet.Contains(ip) { + l.blockedByBlacklist.Add(1) + l.logger.Debug("msg", "IP denied by blacklist", + "component", "netlimit", + "ip", ip.String(), + "rule", ipNet.String()) + return ReasonBlacklisted + } + } + + // 2. If whitelist is configured, IP must be in it + if len(l.ipWhitelist) > 0 { + for _, ipNet := range l.ipWhitelist { + if ipNet.Contains(ip) { + l.logger.Debug("msg", "IP allowed by whitelist", + "component", "netlimit", + "ip", ip.String(), + "rule", ipNet.String()) + return ReasonAllowed + } + } + l.blockedByWhitelist.Add(1) + l.logger.Debug("msg", "IP not in whitelist", + "component", "netlimit", + "ip", ip.String()) + return ReasonNotWhitelisted + } + + return ReasonAllowed +} + func (l *NetLimiter) Shutdown() { if l == nil { return @@ -121,9 +283,9 @@ func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int6 l.totalRequests.Add(1) - ip, _, err := net.SplitHostPort(remoteAddr) + // Parse IP address + ipStr, _, err := net.SplitHostPort(remoteAddr) if err != nil { - // If we can't parse the IP, allow the request but log l.logger.Warn("msg", "Failed to parse remote addr", "component", "netlimit", "remote_addr", remoteAddr, @@ -131,56 +293,82 @@ func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int6 return true, 0, "" } - // Only supporting ipv4 - if !isIPv4(ip) { - // Block non-IPv4 addresses to prevent complications - l.blockedRequests.Add(1) - l.logger.Warn("msg", "Non-IPv4 address blocked", + ip := net.ParseIP(ipStr) + if ip == nil { + l.blockedByInvalidIP.Add(1) + l.logger.Warn("msg", "Failed to parse IP", "component", "netlimit", - "ip", ip) - return false, 403, "IPv4 only" + "ip", ipStr) + return false, 403, string(ReasonInvalidIP) } - // Check connection limit for streaming endpoint + // Reject IPv6 connections + if !isIPv4(ip) { + l.blockedByInvalidIP.Add(1) + l.logger.Warn("msg", "IPv6 connection rejected", + "component", "netlimit", + "ip", ipStr, + "reason", IPv4Only) + return false, 403, IPv4Only + } + + // Normalize to IPv4 representation + ip = ip.To4() + + // Check IP access control + if reason := l.checkIPAccess(ip); reason != ReasonAllowed { + return false, 403, string(reason) + } + + // If rate limiting is not enabled, allow + if !l.config.Enabled { + return true, 0, "" + } + + // Check connection limits if l.config.MaxConnectionsPerIP > 0 { l.connMu.RLock() - counter, exists := l.ipConnections[ip] + tracker, exists := l.ipConnections[ipStr] l.connMu.RUnlock() - if exists && counter.Load() >= l.config.MaxConnectionsPerIP { - l.blockedRequests.Add(1) + if exists && tracker.connections.Load() >= l.config.MaxConnectionsPerIP { + l.blockedByConnLimit.Add(1) statusCode = l.config.ResponseCode if statusCode == 0 { statusCode = 429 } - message = "Connection limit exceeded" - - l.logger.Warn("msg", "Connection limit exceeded", - "component", "netlimit", - "ip", ip, - "connections", counter.Load(), - "limit", l.config.MaxConnectionsPerIP) - - return false, statusCode, message + return false, statusCode, string(ReasonConnectionLimited) } } - // Check net limit - allowed = l.checkLimit(ip) - if !allowed { - l.blockedRequests.Add(1) + // Check rate limit + if !l.checkLimit(ipStr) { + l.blockedByRateLimit.Add(1) statusCode = l.config.ResponseCode if statusCode == 0 { statusCode = 429 } message = l.config.ResponseMessage if message == "" { - message = "Net limit exceeded" + message = string(ReasonRateLimited) } - l.logger.Debug("msg", "Request net limited", "ip", ip) + return false, statusCode, message } - return allowed, statusCode, message + return true, 0, "" +} + +// Update connection activity +func (l *NetLimiter) updateConnectionActivity(ip string) { + l.connMu.RLock() + tracker, exists := l.ipConnections[ip] + l.connMu.RUnlock() + + if exists { + tracker.mu.Lock() + tracker.lastSeen = time.Now() + tracker.mu.Unlock() + } } // Checks if a TCP connection should be allowed @@ -194,32 +382,45 @@ func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool { // Extract IP from TCP addr tcpAddr, ok := remoteAddr.(*net.TCPAddr) if !ok { - return true - } - - ip := tcpAddr.IP.String() - - // Only supporting ipv4 - if !isIPv4(ip) { - l.blockedRequests.Add(1) - l.logger.Warn("msg", "Non-IPv4 TCP connection blocked", - "component", "netlimit", - "ip", ip) + l.blockedByInvalidIP.Add(1) return false } - allowed := l.checkLimit(ip) - if !allowed { - l.blockedRequests.Add(1) - l.logger.Debug("msg", "TCP connection net limited", "ip", ip) + // Reject IPv6 connections + if !isIPv4(tcpAddr.IP) { + l.blockedByInvalidIP.Add(1) + l.logger.Warn("msg", "IPv6 TCP connection rejected", + "component", "netlimit", + "ip", tcpAddr.IP.String(), + "reason", IPv4Only) + return false } - return allowed + // Normalize to IPv4 representation + ip := tcpAddr.IP.To4() + + // Check IP access control + if reason := l.checkIPAccess(ip); reason != ReasonAllowed { + return false + } + + // If rate limiting is not enabled, allow + if !l.config.Enabled { + return true + } + + // Check rate limit + ipStr := tcpAddr.IP.String() + if !l.checkLimit(ipStr) { + l.blockedByRateLimit.Add(1) + return false + } + + return true } -func isIPv4(ip string) bool { - // Simple check: IPv4 addresses contain dots, IPv6 contain colons - return strings.Contains(ip, ".") && !strings.Contains(ip, ":") +func isIPv4(ip net.IP) bool { + return ip.To4() != nil } // Tracks a new connection for an IP @@ -230,23 +431,44 @@ func (l *NetLimiter) AddConnection(remoteAddr string) { ip, _, err := net.SplitHostPort(remoteAddr) if err != nil { + l.logger.Warn("msg", "Failed to parse remote address in AddConnection", + "component", "netlimit", + "remote_addr", remoteAddr, + "error", err) + return + } + + // IP validation + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + l.logger.Warn("msg", "Failed to parse IP in AddConnection", + "component", "netlimit", + "ip", ip) return } // Only supporting ipv4 - if !isIPv4(ip) { + if !isIPv4(parsedIP) { return } l.connMu.Lock() - counter, exists := l.ipConnections[ip] + tracker, exists := l.ipConnections[ip] if !exists { - counter = &atomic.Int64{} - l.ipConnections[ip] = counter + // Create new tracker with timestamp + tracker = &connTracker{ + lastSeen: time.Now(), + } + l.ipConnections[ip] = tracker } l.connMu.Unlock() - newCount := counter.Add(1) + newCount := tracker.connections.Add(1) + // Update activity timestamp + tracker.mu.Lock() + tracker.lastSeen = time.Now() + tracker.mu.Unlock() + l.logger.Debug("msg", "Connection added", "ip", ip, "connections", newCount) @@ -260,20 +482,33 @@ func (l *NetLimiter) RemoveConnection(remoteAddr string) { ip, _, err := net.SplitHostPort(remoteAddr) if err != nil { + l.logger.Warn("msg", "Failed to parse remote address in RemoveConnection", + "component", "netlimit", + "remote_addr", remoteAddr, + "error", err) + return + } + + // IP validation + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + l.logger.Warn("msg", "Failed to parse IP in RemoveConnection", + "component", "netlimit", + "ip", ip) return } // Only supporting ipv4 - if !isIPv4(ip) { + if !isIPv4(parsedIP) { return } l.connMu.RLock() - counter, exists := l.ipConnections[ip] + tracker, exists := l.ipConnections[ip] l.connMu.RUnlock() if exists { - newCount := counter.Add(-1) + newCount := tracker.connections.Add(-1) l.logger.Debug("msg", "Connection removed", "ip", ip, "connections", newCount) @@ -281,7 +516,7 @@ func (l *NetLimiter) RemoveConnection(remoteAddr string) { if newCount <= 0 { // Clean up if no more connections l.connMu.Lock() - if counter.Load() <= 0 { + if tracker.connections.Load() <= 0 { delete(l.ipConnections, ip) } l.connMu.Unlock() @@ -292,9 +527,7 @@ func (l *NetLimiter) RemoveConnection(remoteAddr string) { // Returns net limiter statistics func (l *NetLimiter) GetStats() map[string]any { if l == nil { - return map[string]any{ - "enabled": false, - } + return map[string]any{"enabled": false} } l.ipMu.RLock() @@ -303,18 +536,36 @@ func (l *NetLimiter) GetStats() map[string]any { l.connMu.RLock() totalConnections := 0 - for _, counter := range l.ipConnections { - totalConnections += int(counter.Load()) + for _, tracker := range l.ipConnections { + totalConnections += int(tracker.connections.Load()) } l.connMu.RUnlock() + totalBlocked := l.blockedByBlacklist.Load() + + l.blockedByWhitelist.Load() + + l.blockedByRateLimit.Load() + + l.blockedByConnLimit.Load() + + l.blockedByInvalidIP.Load() + return map[string]any{ - "enabled": true, - "total_requests": l.totalRequests.Load(), - "blocked_requests": l.blockedRequests.Load(), + "enabled": true, + "total_requests": l.totalRequests.Load(), + "total_blocked": totalBlocked, + "blocked_breakdown": map[string]uint64{ + "blacklist": l.blockedByBlacklist.Load(), + "whitelist": l.blockedByWhitelist.Load(), + "rate_limit": l.blockedByRateLimit.Load(), + "conn_limit": l.blockedByConnLimit.Load(), + "invalid_ip": l.blockedByInvalidIP.Load(), + }, "active_ips": activeIPs, "total_connections": totalConnections, - "config": map[string]any{ + "acl": map[string]int{ + "whitelist_rules": len(l.ipWhitelist), + "blacklist_rules": len(l.ipBlacklist), + }, + "rate_limit": map[string]any{ + "enabled": l.config.Enabled, "requests_per_second": l.config.RequestsPerSecond, "burst_size": l.config.BurstSize, "limit_by": l.config.LimitBy, @@ -324,6 +575,15 @@ func (l *NetLimiter) GetStats() map[string]any { // Performs the actual net limit check func (l *NetLimiter) checkLimit(ip string) bool { + // Validate IP format + parsedIP := net.ParseIP(ip) + if parsedIP == nil || !isIPv4(parsedIP) { + l.logger.Warn("msg", "Invalid or non-IPv4 address in rate limiter", + "component", "netlimit", + "ip", ip) + return false + } + // Maybe run cleanup l.maybeCleanup() @@ -358,10 +618,10 @@ func (l *NetLimiter) checkLimit(ip string) bool { // Check connection limit if configured if l.config.MaxConnectionsPerIP > 0 { l.connMu.RLock() - counter, exists := l.ipConnections[ip] + tracker, exists := l.ipConnections[ip] l.connMu.RUnlock() - if exists && counter.Load() >= l.config.MaxConnectionsPerIP { + if exists && tracker.connections.Load() >= l.config.MaxConnectionsPerIP { return false } } @@ -379,14 +639,27 @@ func (l *NetLimiter) checkLimit(ip string) bool { // Runs cleanup if enough time has passed func (l *NetLimiter) maybeCleanup() { l.cleanupMu.Lock() - defer l.cleanupMu.Unlock() + // Check if enough time has passed if time.Since(l.lastCleanup) < 30*time.Second { + l.cleanupMu.Unlock() + return + } + + // Check if cleanup already running + if !l.cleanupActive.CompareAndSwap(false, true) { + l.cleanupMu.Unlock() return } l.lastCleanup = time.Now() - go l.cleanup() + l.cleanupMu.Unlock() + + // Run cleanup async + go func() { + defer l.cleanupActive.Store(false) + l.cleanup() + }() } // Removes stale IP limiters @@ -397,6 +670,8 @@ func (l *NetLimiter) cleanup() { l.ipMu.Lock() defer l.ipMu.Unlock() + // Clean up rate limiters + l.ipMu.Lock() cleaned := 0 for ip, lim := range l.ipLimiters { if now.Sub(lim.lastSeen) > staleTimeout { @@ -404,12 +679,37 @@ func (l *NetLimiter) cleanup() { cleaned++ } } + l.ipMu.Unlock() if cleaned > 0 { l.logger.Debug("msg", "Cleaned up stale IP limiters", + "component", "netlimit", "cleaned", cleaned, "remaining", len(l.ipLimiters)) } + + // Clean up stale connection trackers + l.connMu.Lock() + connCleaned := 0 + for ip, tracker := range l.ipConnections { + tracker.mu.Lock() + lastSeen := tracker.lastSeen + tracker.mu.Unlock() + + // Remove if no activity for 5 minutes AND no active connections + if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 { + delete(l.ipConnections, ip) + connCleaned++ + } + } + l.connMu.Unlock() + + if connCleaned > 0 { + l.logger.Debug("msg", "Cleaned up stale connection trackers", + "component", "netlimit", + "cleaned", connCleaned, + "remaining", len(l.ipConnections)) + } } // Runs periodic cleanup diff --git a/src/internal/service/service.go b/src/internal/service/service.go index f9b481f..e32c8cd 100644 --- a/src/internal/service/service.go +++ b/src/internal/service/service.go @@ -139,9 +139,6 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error { // Configure authentication for sinks that support it for _, sinkInst := range pipeline.Sinks { - if setter, ok := sinkInst.(sink.NetAccessSetter); ok { - setter.SetNetAccessConfig(cfg.NetAccess) - } if setter, ok := sinkInst.(sink.AuthSetter); ok { setter.SetAuthConfig(cfg.Auth) } diff --git a/src/internal/sink/http.go b/src/internal/sink/http.go index 9aefe76..818e74e 100644 --- a/src/internal/sink/http.go +++ b/src/internal/sink/http.go @@ -48,7 +48,6 @@ type HTTPSink struct { // Net limiting netLimiter *limit.NetLimiter - ipChecker *limit.IPChecker // Statistics totalProcessed atomic.Uint64 @@ -156,6 +155,22 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo if maxTotal, ok := rl["max_total_connections"].(int64); ok { cfg.NetLimit.MaxTotalConnections = maxTotal } + if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok { + cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist)) + for _, entry := range ipWhitelist { + if str, ok := entry.(string); ok { + cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str) + } + } + } + if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok { + cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist)) + for _, entry := range ipBlacklist { + if str, ok := entry.(string); ok { + cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str) + } + } + } } h := &HTTPSink{ @@ -195,8 +210,10 @@ func (h *HTTPSink) Start(ctx context.Context) error { // Configure TLS if enabled if h.tlsManager != nil { - tlsConfig := h.tlsManager.GetHTTPConfig() - h.server.TLSConfig = tlsConfig + h.server.TLSConfig = h.tlsManager.GetHTTPConfig() + h.logger.Info("msg", "TLS enabled for HTTP sink", + "component", "http_sink", + "port", h.config.Port) } addr := fmt.Sprintf(":%d", h.config.Port) @@ -208,12 +225,13 @@ func (h *HTTPSink) Start(ctx context.Context) error { "component", "http_sink", "port", h.config.Port, "stream_path", h.streamPath, - "status_path", h.statusPath) + "status_path", h.statusPath, + "tls_enabled", h.tlsManager != nil) var err error if h.tlsManager != nil { // HTTPS server - err = h.server.ListenAndServeTLS(addr, "", "") + err = h.server.ListenAndServeTLS(addr, h.config.SSL.CertFile, h.config.SSL.KeyFile) } else { // HTTP server err = h.server.ListenAndServe(addr) @@ -306,24 +324,18 @@ func (h *HTTPSink) GetStats() SinkStats { func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { remoteAddr := ctx.RemoteAddr().String() - // Check IP access control - if h.ipChecker != nil { - if !h.ipChecker.IsAllowed(ctx.RemoteAddr()) { - ctx.SetStatusCode(fasthttp.StatusForbidden) - ctx.SetContentType("text/plain") - ctx.SetBodyString("Forbidden") - return - } - } - // Check net limit if h.netLimiter != nil { if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed { ctx.SetStatusCode(int(statusCode)) ctx.SetContentType("application/json") + h.logger.Warn("msg", "Net limited", + "component", "http_sink", + "remote_addr", remoteAddr, + "status_code", statusCode, + "error", message) json.NewEncoder(ctx).Encode(map[string]any{ - "error": message, - "retry_after": "60", // seconds + "error": "Too many requests", }) return } @@ -355,7 +367,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil { realm := h.authConfig.BasicAuth.Realm if realm == "" { - realm = "LogWisp" + realm = "Restricted" } ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm)) } else if h.authConfig.Type == "bearer" { @@ -364,7 +376,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { ctx.SetContentType("application/json") json.NewEncoder(ctx).Encode(map[string]string{ - "error": "Authentication required", + "error": "Unauthorized", }) return } @@ -379,8 +391,6 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { ctx.SetContentType("application/json") json.NewEncoder(ctx).Encode(map[string]any{ "error": "Not Found", - "message": fmt.Sprintf("Available endpoints: %s (SSE transport), %s (status)", - h.streamPath, h.statusPath), }) } } @@ -656,14 +666,6 @@ func (h *HTTPSink) GetStatusPath() string { return h.statusPath } -func (h *HTTPSink) SetNetAccessConfig(cfg *config.NetAccessConfig) { - h.ipChecker = limit.NewIPChecker(cfg, h.logger) - if h.ipChecker != nil { - h.logger.Info("msg", "IP access control configured for HTTP sink", - "component", "http_sink") - } -} - // SetAuthConfig configures http sink authentication func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) { if authCfg == nil || authCfg.Type == "none" { diff --git a/src/internal/sink/http_client.go b/src/internal/sink/http_client.go index cde943b..668c19f 100644 --- a/src/internal/sink/http_client.go +++ b/src/internal/sink/http_client.go @@ -4,8 +4,12 @@ package sink import ( "bytes" "context" + "crypto/tls" + "crypto/x509" "fmt" "net/url" + "os" + "strings" "sync" "sync/atomic" "time" @@ -55,6 +59,7 @@ type HTTPClientConfig struct { // TLS configuration InsecureSkipVerify bool + CAFile string } // NewHTTPClientSink creates a new HTTP client sink @@ -126,6 +131,11 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for cfg.Headers["Content-Type"] = "application/json" } + // Extract TLS options + if caFile, ok := options["ca_file"].(string); ok && caFile != "" { + cfg.CAFile = caFile + } + h := &HTTPClientSink{ input: make(chan core.LogEntry, cfg.BufferSize), config: cfg, @@ -147,17 +157,28 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for DisableHeaderNamesNormalizing: true, } - // TODO: Implement custom TLS configuration, including InsecureSkipVerify, - // by setting a custom dialer on the fasthttp.Client. - // For example: - // if cfg.InsecureSkipVerify { - // h.client.Dial = func(addr string) (net.Conn, error) { - // return fasthttp.DialDualStackTimeout(addr, cfg.Timeout, &tls.Config{ - // InsecureSkipVerify: true, - // }) - // } - // } - // FIXED: Removed incorrect TLS configuration that referenced non-existent field + // Configure TLS if using HTTPS + if strings.HasPrefix(cfg.URL, "https://") { + tlsConfig := &tls.Config{ + InsecureSkipVerify: cfg.InsecureSkipVerify, + } + + // Load custom CA if provided + if cfg.CAFile != "" { + caCert, err := os.ReadFile(cfg.CAFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA file: %w", err) + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse CA certificate") + } + tlsConfig.RootCAs = caCertPool + } + + // Set TLS config directly on the client + h.client.TLSConfig = tlsConfig + } return h, nil } @@ -341,15 +362,22 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { if attempt > 0 { // Wait before retry time.Sleep(retryDelay) - retryDelay = time.Duration(float64(retryDelay) * h.config.RetryBackoff) + + // Calculate new delay with overflow protection + newDelay := time.Duration(float64(retryDelay) * h.config.RetryBackoff) + + // Cap at maximum to prevent integer overflow + if newDelay > h.config.Timeout || newDelay < retryDelay { + // Either exceeded max or overflowed (negative/wrapped) + retryDelay = h.config.Timeout + } else { + retryDelay = newDelay + } } - // TODO: defer placement issue - // Create request + // Acquire resources inside loop, release immediately after use req := fasthttp.AcquireRequest() - defer fasthttp.ReleaseRequest(req) resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseResponse(resp) req.SetRequestURI(h.config.URL) req.Header.SetMethod("POST") @@ -362,35 +390,50 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { // Send request err := h.client.DoTimeout(req, resp, h.config.Timeout) + + // Capture response before releasing + statusCode := resp.StatusCode() + var responseBody []byte + if len(resp.Body()) > 0 { + responseBody = make([]byte, len(resp.Body())) + copy(responseBody, resp.Body()) + } + + // Release immediately, not deferred + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) + + // Handle errors if err != nil { lastErr = fmt.Errorf("request failed: %w", err) h.logger.Warn("msg", "HTTP request failed", "component", "http_client_sink", "attempt", attempt+1, + "max_retries", h.config.MaxRetries, "error", err) continue } // Check response status - statusCode := resp.StatusCode() if statusCode >= 200 && statusCode < 300 { // Success h.logger.Debug("msg", "Batch sent successfully", "component", "http_client_sink", "batch_size", len(batch), - "status_code", statusCode) + "status_code", statusCode, + "attempt", attempt+1) return } // Non-2xx status - lastErr = fmt.Errorf("server returned status %d: %s", statusCode, resp.Body()) + lastErr = fmt.Errorf("server returned status %d: %s", statusCode, responseBody) // Don't retry on 4xx errors (client errors) if statusCode >= 400 && statusCode < 500 { h.logger.Error("msg", "Batch rejected by server", "component", "http_client_sink", "status_code", statusCode, - "response", string(resp.Body()), + "response", string(responseBody), "batch_size", len(batch)) h.failedBatches.Add(1) return @@ -400,13 +443,14 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { "component", "http_client_sink", "attempt", attempt+1, "status_code", statusCode, - "response", string(resp.Body())) + "response", string(responseBody)) } - // All retries failed - h.logger.Error("msg", "Failed to send batch after retries", + // All retries exhausted + h.logger.Error("msg", "Failed to send batch after all retries", "component", "http_client_sink", "batch_size", len(batch), + "retries", h.config.MaxRetries, "last_error", lastErr) h.failedBatches.Add(1) } \ No newline at end of file diff --git a/src/internal/sink/sink.go b/src/internal/sink/sink.go index 7533d89..e08e9b4 100644 --- a/src/internal/sink/sink.go +++ b/src/internal/sink/sink.go @@ -34,11 +34,6 @@ type SinkStats struct { Details map[string]any } -// NetAccessSetter is an interface for sinks that can accept network access configuration -type NetAccessSetter interface { - SetNetAccessConfig(cfg *config.NetAccessConfig) -} - // AuthSetter is an interface for sinks that can accept an AuthConfig. type AuthSetter interface { SetAuthConfig(auth *config.AuthConfig) diff --git a/src/internal/sink/tcp.go b/src/internal/sink/tcp.go index 8808796..a4a3996 100644 --- a/src/internal/sink/tcp.go +++ b/src/internal/sink/tcp.go @@ -36,7 +36,6 @@ type TCPSink struct { engineMu sync.Mutex wg sync.WaitGroup netLimiter *limit.NetLimiter - ipChecker *limit.IPChecker logger *log.Logger formatter format.Formatter @@ -50,6 +49,11 @@ type TCPSink struct { lastProcessed atomic.Value // time.Time authFailures atomic.Uint64 authSuccesses atomic.Uint64 + + // Write error tracking + writeErrors atomic.Uint64 + consecutiveWriteErrors map[gnet.Conn]int + errorMu sync.Mutex } // TCPConfig holds TCP sink configuration @@ -141,6 +145,22 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For if maxTotal, ok := rl["max_total_connections"].(int64); ok { cfg.NetLimit.MaxTotalConnections = maxTotal } + if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok { + cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist)) + for _, entry := range ipWhitelist { + if str, ok := entry.(string); ok { + cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str) + } + } + } + if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok { + cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist)) + for _, entry := range ipBlacklist { + if str, ok := entry.(string); ok { + cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str) + } + } + } } t := &TCPSink{ @@ -191,17 +211,6 @@ func (t *TCPSink) Start(ctx context.Context) error { gnet.WithReusePort(true), ) - // Add TLS if configured - if t.tlsManager != nil { - // tlsConfig := t.tlsManager.GetTCPConfig() - // TODO: tlsConfig is not used, wrapper to be implemented, non-TLS stream to be available without wrapper - // ☢ SECURITY: gnet doesn't support TLS natively - would need wrapper - // This is a limitation that requires implementing TLS at application layer - t.logger.Warn("msg", "TLS configured but gnet doesn't support native TLS", - "component", "tcp_sink", - "workaround", "Use stunnel or nginx TCP proxy for TLS termination") - } - // Start gnet server errChan := make(chan error, 1) go func() { @@ -338,7 +347,29 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) { t.server.mu.RLock() for conn, client := range t.server.clients { if client.authenticated { - conn.AsyncWrite(data, nil) + // Send through TLS bridge if present + if client.tlsBridge != nil { + if _, err := client.tlsBridge.Write(data); err != nil { + // TLS write failed, connection likely dead + t.logger.Debug("msg", "TLS write failed", + "component", "tcp_sink", + "error", err) + conn.Close() + } + } else { + conn.AsyncWrite(data, func(c gnet.Conn, err error) error { + if err != nil { + t.writeErrors.Add(1) + t.handleWriteError(c, err) + } else { + // Reset consecutive error count on success + t.errorMu.Lock() + delete(t.consecutiveWriteErrors, c) + t.errorMu.Unlock() + } + return nil + }) + } } } t.server.mu.RUnlock() @@ -364,7 +395,22 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) { continue } } - conn.AsyncWrite(data, nil) + if client.tlsBridge != nil { + if _, err := client.tlsBridge.Write(data); err != nil { + t.logger.Debug("msg", "TLS heartbeat write failed", + "component", "tcp_sink", + "error", err) + conn.Close() + } + } else { + conn.AsyncWrite(data, func(c gnet.Conn, err error) error { + if err != nil { + t.writeErrors.Add(1) + t.handleWriteError(c, err) + } + return nil + }) + } } } t.server.mu.RUnlock() @@ -375,6 +421,36 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) { } } +// Handle write errors with threshold-based connection termination +func (t *TCPSink) handleWriteError(c gnet.Conn, err error) { + t.errorMu.Lock() + defer t.errorMu.Unlock() + + // Track consecutive errors per connection + if t.consecutiveWriteErrors == nil { + t.consecutiveWriteErrors = make(map[gnet.Conn]int) + } + + t.consecutiveWriteErrors[c]++ + errorCount := t.consecutiveWriteErrors[c] + + t.logger.Debug("msg", "AsyncWrite error", + "component", "tcp_sink", + "remote_addr", c.RemoteAddr(), + "error", err, + "consecutive_errors", errorCount) + + // Close connection after 3 consecutive write errors + if errorCount >= 3 { + t.logger.Warn("msg", "Closing connection due to repeated write errors", + "component", "tcp_sink", + "remote_addr", c.RemoteAddr(), + "error_count", errorCount) + delete(t.consecutiveWriteErrors, c) + c.Close() + } +} + // Create heartbeat as a proper LogEntry func (t *TCPSink) createHeartbeatEntry() core.LogEntry { message := "heartbeat" @@ -406,11 +482,13 @@ func (t *TCPSink) GetActiveConnections() int64 { // tcpClient represents a connected TCP client with auth state type tcpClient struct { - conn gnet.Conn - buffer bytes.Buffer - authenticated bool - session *auth.Session - authTimeout time.Time + conn gnet.Conn + buffer bytes.Buffer + authenticated bool + session *auth.Session + authTimeout time.Time + tlsBridge *tls.GNetTLSConn + authTimeoutSet bool } // tcpServer handles gnet events with authentication @@ -434,15 +512,13 @@ func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action { } func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { - remoteAddr := c.RemoteAddr().String() + remoteAddr := c.RemoteAddr() s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr) - // Check IP access control first - if s.sink.ipChecker != nil { - if !s.sink.ipChecker.IsAllowed(c.RemoteAddr()) { - s.sink.logger.Warn("msg", "TCP connection denied by IP filter", - "remote_addr", remoteAddr) - return nil, gnet.Close + // Reject IPv6 connections immediately + if tcpAddr, ok := remoteAddr.(*net.TCPAddr); ok { + if tcpAddr.IP.To4() == nil { + return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close } } @@ -467,11 +543,26 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { s.sink.netLimiter.AddConnection(remoteStr) } - // Create client state + // Create client state without auth timeout initially client := &tcpClient{ - conn: c, - authenticated: s.sink.authenticator == nil, // No auth = auto authenticated - authTimeout: time.Now().Add(30 * time.Second), // 30s to authenticate + conn: c, + authenticated: s.sink.authenticator == nil, // No auth = auto authenticated + authTimeoutSet: false, // Auth timeout not started yet + } + + // Initialize TLS bridge if enabled + if s.sink.tlsManager != nil { + tlsConfig := s.sink.tlsManager.GetTCPConfig() + client.tlsBridge = tls.NewServerConn(c, tlsConfig) + client.tlsBridge.Handshake() // Start async handshake + + s.sink.logger.Debug("msg", "TLS handshake initiated", + "component", "tcp_sink", + "remote_addr", remoteAddr) + } else if s.sink.authenticator != nil { + // Only set auth timeout if no TLS (plain connection) + client.authTimeout = time.Now().Add(30 * time.Second) // TODO: configurable or non-hardcoded timer + client.authTimeoutSet = true } s.mu.Lock() @@ -485,7 +576,7 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { "requires_auth", s.sink.authenticator != nil) // Send auth prompt if authentication is required - if s.sink.authenticator != nil { + if s.sink.authenticator != nil && s.sink.tlsManager == nil { authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH \nMethods: basic, token\n") return authPrompt, gnet.None } @@ -498,9 +589,22 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action { // Remove client state s.mu.Lock() + client := s.clients[c] delete(s.clients, c) s.mu.Unlock() + // Clean up TLS bridge if present + if client != nil && client.tlsBridge != nil { + client.tlsBridge.Close() + s.sink.logger.Debug("msg", "TLS connection closed", + "remote_addr", remoteAddr) + } + + // Clean up write error tracking + s.sink.errorMu.Lock() + delete(s.sink.consecutiveWriteErrors, c) + s.sink.errorMu.Unlock() + // Remove connection tracking if s.sink.netLimiter != nil { s.sink.netLimiter.RemoveConnection(remoteAddr) @@ -523,13 +627,18 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action { return gnet.Close } - // Check auth timeout - if !client.authenticated && time.Now().After(client.authTimeout) { - s.sink.logger.Warn("msg", "Authentication timeout", - "remote_addr", c.RemoteAddr().String()) - c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil) - return gnet.Close - } + // // Check auth timeout + // if !client.authenticated && time.Now().After(client.authTimeout) { + // s.sink.logger.Warn("msg", "Authentication timeout", + // "component", "tcp_sink", + // "remote_addr", c.RemoteAddr().String()) + // if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() { + // client.tlsBridge.Write([]byte("AUTH TIMEOUT\n")) + // } else if client.tlsBridge == nil { + // c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil) + // } + // return gnet.Close + // } // Read all available data data, err := c.Next(-1) @@ -540,6 +649,70 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action { return gnet.Close } + // Process through TLS bridge if present + if client.tlsBridge != nil { + // Feed encrypted data into TLS engine + if err := client.tlsBridge.ProcessIncoming(data); err != nil { + s.sink.logger.Error("msg", "TLS processing error", + "component", "tcp_sink", + "remote_addr", c.RemoteAddr().String(), + "error", err) + return gnet.Close + } + + // Check if handshake is complete + if !client.tlsBridge.IsHandshakeDone() { + // Still handshaking, wait for more data + return gnet.None + } + + // Check handshake result + _, hsErr := client.tlsBridge.HandshakeComplete() + if hsErr != nil { + s.sink.logger.Error("msg", "TLS handshake failed", + "component", "tcp_sink", + "remote_addr", c.RemoteAddr().String(), + "error", hsErr) + return gnet.Close + } + + // Set auth timeout only after TLS handshake completes + if !client.authTimeoutSet && s.sink.authenticator != nil && !client.authenticated { + client.authTimeout = time.Now().Add(30 * time.Second) + client.authTimeoutSet = true + s.sink.logger.Debug("msg", "Auth timeout started after TLS handshake", + "component", "tcp_sink", + "remote_addr", c.RemoteAddr().String()) + } + + // Read decrypted plaintext + data = client.tlsBridge.Read() + if data == nil || len(data) == 0 { + // No plaintext available yet + return gnet.None + } + + // First data after TLS handshake - send auth prompt if needed + if s.sink.authenticator != nil && !client.authenticated && + len(client.buffer.Bytes()) == 0 { + authPrompt := []byte("AUTH REQUIRED\n") + client.tlsBridge.Write(authPrompt) + } + } + + // Only check auth timeout if it has been set + if !client.authenticated && client.authTimeoutSet && time.Now().After(client.authTimeout) { + s.sink.logger.Warn("msg", "Authentication timeout", + "component", "tcp_sink", + "remote_addr", c.RemoteAddr().String()) + if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() { + client.tlsBridge.Write([]byte("AUTH TIMEOUT\n")) + } else if client.tlsBridge == nil { + c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil) + } + return gnet.Close + } + // If not authenticated, expect auth command if !client.authenticated { client.buffer.Write(data) @@ -551,7 +724,13 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action { // Parse AUTH command: AUTH parts := strings.SplitN(string(line), " ", 3) if len(parts) != 3 || parts[0] != "AUTH" { - c.AsyncWrite([]byte("ERROR: Invalid auth format\n"), nil) + // Send error through TLS if enabled + errMsg := []byte("AUTH FAILED\n") + if client.tlsBridge != nil { + client.tlsBridge.Write(errMsg) + } else { + c.AsyncWrite(errMsg, nil) + } return gnet.None } @@ -563,7 +742,13 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action { "remote_addr", c.RemoteAddr().String(), "method", parts[1], "error", err) - c.AsyncWrite([]byte(fmt.Sprintf("AUTH FAILED: %v\n", err)), nil) + // Send error through TLS if enabled + errMsg := []byte("AUTH FAILED\n") + if client.tlsBridge != nil { + client.tlsBridge.Write(errMsg) + } else { + c.AsyncWrite(errMsg, nil) + } return gnet.Close } @@ -575,11 +760,19 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action { s.mu.Unlock() s.sink.logger.Info("msg", "TCP client authenticated", + "component", "tcp_sink", "remote_addr", c.RemoteAddr().String(), "username", session.Username, - "method", session.Method) + "method", session.Method, + "tls", client.tlsBridge != nil) - c.AsyncWrite([]byte("AUTH OK\n"), nil) + // Send success through TLS if enabled + successMsg := []byte("AUTH OK\n") + if client.tlsBridge != nil { + client.tlsBridge.Write(successMsg) + } else { + c.AsyncWrite(successMsg, nil) + } // Clear buffer after auth client.buffer.Reset() @@ -610,7 +803,7 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) { // Initialize TLS manager if SSL is configured if t.config.SSL != nil && t.config.SSL.Enabled { - tlsManager, err := tls.New(t.config.SSL, t.logger) + tlsManager, err := tls.NewManager(t.config.SSL, t.logger) if err != nil { t.logger.Error("msg", "Failed to create TLS manager", "component", "tcp_sink", @@ -624,5 +817,6 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) { t.logger.Info("msg", "Authentication configured for TCP sink", "component", "tcp_sink", "auth_type", authCfg.Type, - "tls_enabled", t.tlsManager != nil) + "tls_enabled", t.tlsManager != nil, + "tls_bridge", t.tlsManager != nil) } \ No newline at end of file diff --git a/src/internal/sink/tcp_client.go b/src/internal/sink/tcp_client.go index c94d11d..fa56113 100644 --- a/src/internal/sink/tcp_client.go +++ b/src/internal/sink/tcp_client.go @@ -3,6 +3,7 @@ package sink import ( "context" + "crypto/tls" "errors" "fmt" "net" @@ -10,8 +11,10 @@ import ( "sync/atomic" "time" + "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/format" + tlspkg "logwisp/src/internal/tls" "github.com/lixenwraith/log" ) @@ -28,6 +31,10 @@ type TCPClientSink struct { logger *log.Logger formatter format.Formatter + // TLS support + tlsManager *tlspkg.Manager + tlsConfig *tls.Config + // Reconnection state reconnecting atomic.Bool lastConnectErr error @@ -53,6 +60,9 @@ type TCPClientConfig struct { ReconnectDelay time.Duration MaxReconnectDelay time.Duration ReconnectBackoff float64 + + // TLS config + SSL *config.SSLConfig } // NewTCPClientSink creates a new TCP client sink @@ -103,6 +113,25 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form cfg.ReconnectBackoff = backoff } + // Extract SSL config + if ssl, ok := options["ssl"].(map[string]any); ok { + cfg.SSL = &config.SSLConfig{} + cfg.SSL.Enabled, _ = ssl["enabled"].(bool) + if certFile, ok := ssl["cert_file"].(string); ok { + cfg.SSL.CertFile = certFile + } + if keyFile, ok := ssl["key_file"].(string); ok { + cfg.SSL.KeyFile = keyFile + } + cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool) + if caFile, ok := ssl["client_ca_file"].(string); ok { + cfg.SSL.ClientCAFile = caFile + } + if insecure, ok := ssl["insecure_skip_verify"].(bool); ok { + cfg.SSL.InsecureSkipVerify = insecure + } + } + t := &TCPClientSink{ input: make(chan core.LogEntry, cfg.BufferSize), config: cfg, @@ -114,6 +143,34 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form t.lastProcessed.Store(time.Time{}) t.connectionUptime.Store(time.Duration(0)) + // Initialize TLS manager if SSL is configured + if cfg.SSL != nil && cfg.SSL.Enabled { + tlsManager, err := tlspkg.NewManager(cfg.SSL, logger) + if err != nil { + return nil, fmt.Errorf("failed to create TLS manager: %w", err) + } + t.tlsManager = tlsManager + + // Get client TLS config + t.tlsConfig = tlsManager.GetTCPConfig() + + // ADDED: Client-specific TLS config adjustments + t.tlsConfig.InsecureSkipVerify = cfg.SSL.InsecureSkipVerify + + // Extract server name from address for SNI + host, _, err := net.SplitHostPort(cfg.Address) + if err != nil { + return nil, fmt.Errorf("failed to parse address for SNI: %w", err) + } + t.tlsConfig.ServerName = host + + logger.Info("msg", "TLS enabled for TCP client", + "component", "tcp_client_sink", + "address", cfg.Address, + "server_name", host, + "insecure", cfg.SSL.InsecureSkipVerify) + } + return t, nil } @@ -280,6 +337,35 @@ func (t *TCPClientSink) connect() (net.Conn, error) { tcpConn.SetKeepAlivePeriod(t.config.KeepAlive) } + // Wrap with TLS if configured + if t.tlsConfig != nil { + t.logger.Debug("msg", "Initiating TLS handshake", + "component", "tcp_client_sink", + "address", t.config.Address) + + tlsConn := tls.Client(conn, t.tlsConfig) + + // Perform handshake with timeout + handshakeCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := tlsConn.HandshakeContext(handshakeCtx); err != nil { + conn.Close() + return nil, fmt.Errorf("TLS handshake failed: %w", err) + } + + // Log connection details + state := tlsConn.ConnectionState() + t.logger.Info("msg", "TLS connection established", + "component", "tcp_client_sink", + "address", t.config.Address, + "tls_version", tlsVersionString(state.Version), + "cipher_suite", tls.CipherSuiteName(state.CipherSuite), + "server_name", state.ServerName) + + return tlsConn, nil + } + return conn, nil } @@ -295,7 +381,7 @@ func (t *TCPClientSink) monitorConnection(conn net.Conn) { return case <-ticker.C: // Set read deadline - // TODO: Add t.config.ReadTimeout instead of static value + // TODO: Add t.config.ReadTimeout and after addition use it instead of static value if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { t.logger.Debug("msg", "Failed to set read deadline", "error", err) return @@ -378,4 +464,20 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error { } return nil +} + +// tlsVersionString returns human-readable TLS version +func tlsVersionString(version uint16) string { + switch version { + case tls.VersionTLS10: + return "TLS1.0" + case tls.VersionTLS11: + return "TLS1.1" + case tls.VersionTLS12: + return "TLS1.2" + case tls.VersionTLS13: + return "TLS1.3" + default: + return fmt.Sprintf("0x%04x", version) + } } \ No newline at end of file diff --git a/src/internal/source/http.go b/src/internal/source/http.go index 8f03cfa..1fc20da 100644 --- a/src/internal/source/http.go +++ b/src/internal/source/http.go @@ -4,6 +4,8 @@ package source import ( "encoding/json" "fmt" + "logwisp/src/internal/tls" + "net" "sync" "sync/atomic" "time" @@ -29,6 +31,10 @@ type HTTPSource struct { netLimiter *limit.NetLimiter logger *log.Logger + // CHANGED: Add TLS support + tlsManager *tls.Manager + sslConfig *config.SSLConfig + // Statistics totalEntries atomic.Uint64 droppedEntries atomic.Uint64 @@ -94,6 +100,28 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err } } + // Extract SSL config after existing options + if ssl, ok := options["ssl"].(map[string]any); ok { + h.sslConfig = &config.SSLConfig{} + h.sslConfig.Enabled, _ = ssl["enabled"].(bool) + if certFile, ok := ssl["cert_file"].(string); ok { + h.sslConfig.CertFile = certFile + } + if keyFile, ok := ssl["key_file"].(string); ok { + h.sslConfig.KeyFile = keyFile + } + // TODO: extract other SSL options similar to tcp_client_sink + + // Create TLS manager + if h.sslConfig.Enabled { + tlsManager, err := tls.NewManager(h.sslConfig, logger) + if err != nil { + return nil, fmt.Errorf("failed to create TLS manager: %w", err) + } + h.tlsManager = tlsManager + } + } + return h, nil } @@ -123,9 +151,19 @@ func (h *HTTPSource) Start() error { h.logger.Info("msg", "HTTP source server starting", "component", "http_source", "port", h.port, - "ingest_path", h.ingestPath) + "ingest_path", h.ingestPath, + "tls_enabled", h.tlsManager != nil) - if err := h.server.ListenAndServe(addr); err != nil { + var err error + // Check for TLS manager and start the appropriate server type + if h.tlsManager != nil { + h.server.TLSConfig = h.tlsManager.GetHTTPConfig() + err = h.server.ListenAndServeTLS(addr, h.sslConfig.CertFile, h.sslConfig.KeyFile) + } else { + err = h.server.ListenAndServe(addr) + } + + if err != nil { h.logger.Error("msg", "HTTP source server failed", "component", "http_source", "port", h.port, @@ -134,7 +172,7 @@ func (h *HTTPSource) Start() error { }() // Give server time to start - time.Sleep(100 * time.Millisecond) + time.Sleep(100 * time.Millisecond) // TODO: standardize and better manage timers return nil } @@ -202,8 +240,21 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) { return } - // Check net limit + // Extract and validate IP remoteAddr := ctx.RemoteAddr().String() + ipStr, _, err := net.SplitHostPort(remoteAddr) + if err == nil { + if ip := net.ParseIP(ipStr); ip != nil && ip.To4() == nil { + ctx.SetStatusCode(fasthttp.StatusForbidden) + ctx.SetContentType("application/json") + json.NewEncoder(ctx).Encode(map[string]string{ + "error": "IPv4-only (IPv6 not supported)", + }) + return + } + } + + // Check net limit if h.netLimiter != nil { if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed { ctx.SetStatusCode(int(statusCode)) @@ -280,7 +331,7 @@ func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) { // Try to parse as JSON array var array []core.LogEntry if err := json.Unmarshal(body, &array); err == nil { - // TODO: Placeholder; For array, divide total size by entry count as approximation + // NOTE: Placeholder; For array, divide total size by entry count as approximation approxSizePerEntry := int64(len(body) / len(array)) for i, entry := range array { if entry.Message == "" { @@ -292,7 +343,7 @@ func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) { if entry.Source == "" { array[i].Source = "http" } - // TODO: Placeholder + // NOTE: Placeholder array[i].RawSize = approxSizePerEntry } return array, nil diff --git a/src/internal/source/tcp.go b/src/internal/source/tcp.go index 288b6cf..173094d 100644 --- a/src/internal/source/tcp.go +++ b/src/internal/source/tcp.go @@ -5,21 +5,31 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "net" "sync" "sync/atomic" "time" + "logwisp/src/internal/auth" "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/limit" + "logwisp/src/internal/tls" "github.com/lixenwraith/log" "github.com/lixenwraith/log/compat" "github.com/panjf2000/gnet/v2" ) +const ( + maxClientBufferSize = 10 * 1024 * 1024 // 10MB max per client + maxLineLength = 1 * 1024 * 1024 // 1MB max per log line + maxEncryptedDataPerRead = 1 * 1024 * 1024 // 1MB max encrypted data per read + maxCumulativeEncrypted = 20 * 1024 * 1024 // 20MB total encrypted before processing +) + // TCPSource receives log entries via TCP connections type TCPSource struct { port int64 @@ -32,6 +42,8 @@ type TCPSource struct { engineMu sync.Mutex wg sync.WaitGroup netLimiter *limit.NetLimiter + tlsManager *tls.Manager + sslConfig *config.SSLConfig logger *log.Logger // Statistics @@ -91,6 +103,32 @@ func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error } } + // Extract SSL config and initialize TLS manager + if ssl, ok := options["ssl"].(map[string]any); ok { + t.sslConfig = &config.SSLConfig{} + t.sslConfig.Enabled, _ = ssl["enabled"].(bool) + if certFile, ok := ssl["cert_file"].(string); ok { + t.sslConfig.CertFile = certFile + } + if keyFile, ok := ssl["key_file"].(string); ok { + t.sslConfig.KeyFile = keyFile + } + t.sslConfig.ClientAuth, _ = ssl["client_auth"].(bool) + if caFile, ok := ssl["client_ca_file"].(string); ok { + t.sslConfig.ClientCAFile = caFile + } + t.sslConfig.VerifyClientCert, _ = ssl["verify_client_cert"].(bool) + + // Create TLS manager if enabled + if t.sslConfig.Enabled { + tlsManager, err := tls.NewManager(t.sslConfig, logger) + if err != nil { + return nil, fmt.Errorf("failed to create TLS manager: %w", err) + } + t.tlsManager = tlsManager + } + } + return t, nil } @@ -121,7 +159,8 @@ func (t *TCPSource) Start() error { defer t.wg.Done() t.logger.Info("msg", "TCP source server starting", "component", "tcp_source", - "port", t.port) + "port", t.port, + "tls_enabled", t.tlsManager != nil) err := gnet.Run(t.server, addr, gnet.WithLogger(gnetLogger), @@ -233,8 +272,14 @@ func (t *TCPSource) publish(entry core.LogEntry) bool { // tcpClient represents a connected TCP client type tcpClient struct { - conn gnet.Conn - buffer bytes.Buffer + conn gnet.Conn + buffer bytes.Buffer + authenticated bool + session *auth.Session + authTimeout time.Time + tlsBridge *tls.GNetTLSConn + maxBufferSeen int + cumulativeEncrypted int64 } // tcpSourceServer handles gnet events @@ -265,8 +310,7 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { // Check net limit if s.source.netLimiter != nil { - remoteStr := c.RemoteAddr().String() - tcpAddr, err := net.ResolveTCPAddr("tcp", remoteStr) + tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr) if err != nil { s.source.logger.Warn("msg", "Failed to parse TCP address", "component", "tcp_source", @@ -283,7 +327,21 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { } // Track connection - s.source.netLimiter.AddConnection(remoteStr) + s.source.netLimiter.AddConnection(remoteAddr) + } + + // Create client state + client := &tcpClient{conn: c} + + // Initialize TLS bridge if enabled + if s.source.tlsManager != nil { + tlsConfig := s.source.tlsManager.GetTCPConfig() + client.tlsBridge = tls.NewServerConn(c, tlsConfig) + client.tlsBridge.Handshake() // Start async handshake + + s.source.logger.Debug("msg", "TLS handshake initiated", + "component", "tcp_source", + "remote_addr", remoteAddr) } // Create client state @@ -295,7 +353,8 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { s.source.logger.Debug("msg", "TCP connection opened", "component", "tcp_source", "remote_addr", remoteAddr, - "active_connections", newCount) + "active_connections", newCount, + "tls_enabled", s.source.tlsManager != nil) return nil, gnet.None } @@ -305,9 +364,18 @@ func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action { // Remove client state s.mu.Lock() + client := s.clients[c] delete(s.clients, c) s.mu.Unlock() + // Clean up TLS bridge if present + if client != nil && client.tlsBridge != nil { + client.tlsBridge.Close() + s.source.logger.Debug("msg", "TLS connection closed", + "component", "tcp_source", + "remote_addr", remoteAddr) + } + // Remove connection tracking if s.source.netLimiter != nil { s.source.netLimiter.RemoveConnection(remoteAddr) @@ -340,9 +408,113 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action { return gnet.Close } + // Check encrypted data size BEFORE processing through TLS + if len(data) > maxEncryptedDataPerRead { + s.source.logger.Warn("msg", "Encrypted data per read limit exceeded", + "component", "tcp_source", + "remote_addr", c.RemoteAddr().String(), + "data_size", len(data), + "limit", maxEncryptedDataPerRead) + s.source.invalidEntries.Add(1) + return gnet.Close + } + + // Track cumulative encrypted data to prevent slow accumulation + client.cumulativeEncrypted += int64(len(data)) + if client.cumulativeEncrypted > maxCumulativeEncrypted { + s.source.logger.Warn("msg", "Cumulative encrypted data limit exceeded", + "component", "tcp_source", + "remote_addr", c.RemoteAddr().String(), + "total_encrypted", client.cumulativeEncrypted, + "limit", maxCumulativeEncrypted) + s.source.invalidEntries.Add(1) + return gnet.Close + } + + // Process through TLS bridge if present + if client.tlsBridge != nil { + // Feed encrypted data into TLS engine + if err := client.tlsBridge.ProcessIncoming(data); err != nil { + if errors.Is(err, tls.ErrTLSBackpressure) { + s.source.logger.Warn("msg", "TLS backpressure, closing slow client", + "component", "tcp_source", + "remote_addr", c.RemoteAddr().String()) + } else { + s.source.logger.Error("msg", "TLS processing error", + "component", "tcp_source", + "remote_addr", c.RemoteAddr().String(), + "error", err) + } + return gnet.Close + } + + // Check if handshake is complete + if !client.tlsBridge.IsHandshakeDone() { + // Still handshaking, wait for more data + return gnet.None + } + + // Check handshake result + _, hsErr := client.tlsBridge.HandshakeComplete() + if hsErr != nil { + s.source.logger.Error("msg", "TLS handshake failed", + "component", "tcp_source", + "remote_addr", c.RemoteAddr().String(), + "error", hsErr) + return gnet.Close + } + + // Read decrypted plaintext + data = client.tlsBridge.Read() + if data == nil || len(data) == 0 { + // No plaintext available yet + return gnet.None + } + // Reset cumulative counter after successful decryption and processing + client.cumulativeEncrypted = 0 + } + + // Check buffer size before appending + if client.buffer.Len()+len(data) > maxClientBufferSize { + s.source.logger.Warn("msg", "Client buffer limit exceeded", + "component", "tcp_source", + "remote_addr", c.RemoteAddr().String(), + "buffer_size", client.buffer.Len(), + "incoming_size", len(data)) + s.source.invalidEntries.Add(1) + return gnet.Close + } + // Append to client buffer client.buffer.Write(data) + // Track high buffer + if client.buffer.Len() > client.maxBufferSeen { + client.maxBufferSeen = client.buffer.Len() + } + + // Check for suspiciously long lines before attempting to read + if client.buffer.Len() > maxLineLength { + // Scan for newline in current buffer + bufBytes := client.buffer.Bytes() + hasNewline := false + for _, b := range bufBytes { + if b == '\n' { + hasNewline = true + break + } + } + + if !hasNewline { + s.source.logger.Warn("msg", "Line too long without newline", + "component", "tcp_source", + "remote_addr", c.RemoteAddr().String(), + "buffer_size", client.buffer.Len()) + s.source.invalidEntries.Add(1) + return gnet.Close + } + } + // Process complete lines for { line, err := client.buffer.ReadBytes('\n') diff --git a/src/internal/tls/gnet_bridge.go b/src/internal/tls/gnet_bridge.go new file mode 100644 index 0000000..50f5a7b --- /dev/null +++ b/src/internal/tls/gnet_bridge.go @@ -0,0 +1,341 @@ +// FILE: src/internal/tls/gnet_bridge.go +package tls + +import ( + "crypto/tls" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/panjf2000/gnet/v2" +) + +var ( + ErrTLSBackpressure = errors.New("TLS processing backpressure") + ErrConnectionClosed = errors.New("connection closed") + ErrPlaintextBufferExceeded = errors.New("plaintext buffer size exceeded") +) + +// Maximum plaintext buffer size to prevent memory exhaustion +const maxPlaintextBufferSize = 32 * 1024 * 1024 // 32MB + +// GNetTLSConn bridges gnet.Conn with crypto/tls via io.Pipe +type GNetTLSConn struct { + gnetConn gnet.Conn + tlsConn *tls.Conn + config *tls.Config + + // Buffered channels for non-blocking operation + incomingCipher chan []byte // Network → TLS (encrypted) + outgoingCipher chan []byte // TLS → Network (encrypted) + + // Handshake state + handshakeOnce sync.Once + handshakeDone chan struct{} + handshakeErr error + + // Decrypted data buffer + plainBuf []byte + plainMu sync.Mutex + + // Lifecycle + closed atomic.Bool + closeOnce sync.Once + wg sync.WaitGroup + + // Error tracking + lastErr atomic.Value // error + logger interface{ Warn(args ...any) } // Minimal logger interface +} + +// NewServerConn creates a server-side TLS bridge +func NewServerConn(gnetConn gnet.Conn, config *tls.Config) *GNetTLSConn { + tc := &GNetTLSConn{ + gnetConn: gnetConn, + config: config, + handshakeDone: make(chan struct{}), + // Buffered channels sized for throughput without blocking + incomingCipher: make(chan []byte, 128), // 128 packets buffer + outgoingCipher: make(chan []byte, 128), + plainBuf: make([]byte, 0, 65536), // 64KB initial capacity + } + + // Create TLS conn with channel-based transport + rawConn := &channelConn{ + incoming: tc.incomingCipher, + outgoing: tc.outgoingCipher, + localAddr: gnetConn.LocalAddr(), + remoteAddr: gnetConn.RemoteAddr(), + tc: tc, + } + tc.tlsConn = tls.Server(rawConn, config) + + // Start pump goroutines + tc.wg.Add(2) + go tc.pumpCipherToNetwork() + go tc.pumpPlaintextFromTLS() + + return tc +} + +// NewClientConn creates a client-side TLS bridge (similar changes) +func NewClientConn(gnetConn gnet.Conn, config *tls.Config, serverName string) *GNetTLSConn { + tc := &GNetTLSConn{ + gnetConn: gnetConn, + config: config, + handshakeDone: make(chan struct{}), + incomingCipher: make(chan []byte, 128), + outgoingCipher: make(chan []byte, 128), + plainBuf: make([]byte, 0, 65536), + } + + if config.ServerName == "" { + config = config.Clone() + config.ServerName = serverName + } + + rawConn := &channelConn{ + incoming: tc.incomingCipher, + outgoing: tc.outgoingCipher, + localAddr: gnetConn.LocalAddr(), + remoteAddr: gnetConn.RemoteAddr(), + tc: tc, + } + tc.tlsConn = tls.Client(rawConn, config) + + tc.wg.Add(2) + go tc.pumpCipherToNetwork() + go tc.pumpPlaintextFromTLS() + + return tc +} + +// ProcessIncoming feeds encrypted data from network into TLS engine (non-blocking) +func (tc *GNetTLSConn) ProcessIncoming(encryptedData []byte) error { + if tc.closed.Load() { + return ErrConnectionClosed + } + + // Non-blocking send with backpressure detection + select { + case tc.incomingCipher <- encryptedData: + return nil + default: + // Channel full - TLS processing can't keep up + // Drop connection under backpressure vs blocking event loop + if tc.logger != nil { + tc.logger.Warn("msg", "TLS backpressure, dropping data", + "remote_addr", tc.gnetConn.RemoteAddr()) + } + return ErrTLSBackpressure + } +} + +// pumpCipherToNetwork sends TLS-encrypted data to network +func (tc *GNetTLSConn) pumpCipherToNetwork() { + defer tc.wg.Done() + + for { + select { + case data, ok := <-tc.outgoingCipher: + if !ok { + return + } + // Send to network + if err := tc.gnetConn.AsyncWrite(data, nil); err != nil { + tc.lastErr.Store(err) + tc.Close() + return + } + case <-time.After(30 * time.Second): + // Keepalive/timeout check + if tc.closed.Load() { + return + } + } + } +} + +// pumpPlaintextFromTLS reads decrypted data from TLS +func (tc *GNetTLSConn) pumpPlaintextFromTLS() { + defer tc.wg.Done() + buf := make([]byte, 32768) // 32KB read buffer + + for { + n, err := tc.tlsConn.Read(buf) + if n > 0 { + tc.plainMu.Lock() + // Check buffer size limit before appending to prevent memory exhaustion + if len(tc.plainBuf)+n > maxPlaintextBufferSize { + tc.plainMu.Unlock() + // Log warning about buffer limit + if tc.logger != nil { + tc.logger.Warn("msg", "Plaintext buffer limit exceeded, closing connection", + "remote_addr", tc.gnetConn.RemoteAddr(), + "buffer_size", len(tc.plainBuf), + "incoming_size", n, + "limit", maxPlaintextBufferSize) + } + // Store error and close connection + tc.lastErr.Store(ErrPlaintextBufferExceeded) + tc.Close() + return + } + tc.plainBuf = append(tc.plainBuf, buf[:n]...) + tc.plainMu.Unlock() + } + if err != nil { + if err != io.EOF { + tc.lastErr.Store(err) + } + tc.Close() + return + } + } +} + +// Read returns available decrypted plaintext (non-blocking) +func (tc *GNetTLSConn) Read() []byte { + tc.plainMu.Lock() + defer tc.plainMu.Unlock() + + if len(tc.plainBuf) == 0 { + return nil + } + + // Atomic buffer swap under mutex protection to prevent race condition + data := tc.plainBuf + tc.plainBuf = make([]byte, 0, cap(tc.plainBuf)) + return data +} + +// Write encrypts plaintext and queues for network transmission +func (tc *GNetTLSConn) Write(plaintext []byte) (int, error) { + if tc.closed.Load() { + return 0, ErrConnectionClosed + } + + if !tc.IsHandshakeDone() { + return 0, errors.New("handshake not complete") + } + + return tc.tlsConn.Write(plaintext) +} + +// Handshake initiates TLS handshake asynchronously +func (tc *GNetTLSConn) Handshake() { + tc.handshakeOnce.Do(func() { + go func() { + tc.handshakeErr = tc.tlsConn.Handshake() + close(tc.handshakeDone) + }() + }) +} + +// IsHandshakeDone checks if handshake is complete +func (tc *GNetTLSConn) IsHandshakeDone() bool { + select { + case <-tc.handshakeDone: + return true + default: + return false + } +} + +// HandshakeComplete waits for handshake completion +func (tc *GNetTLSConn) HandshakeComplete() (<-chan struct{}, error) { + <-tc.handshakeDone + return tc.handshakeDone, tc.handshakeErr +} + +// Close shuts down the bridge +func (tc *GNetTLSConn) Close() error { + tc.closeOnce.Do(func() { + tc.closed.Store(true) + + // Close TLS connection + tc.tlsConn.Close() + + // Close channels to stop pumps + close(tc.incomingCipher) + close(tc.outgoingCipher) + }) + + // Wait for pumps to finish + tc.wg.Wait() + return nil +} + +// GetConnectionState returns TLS connection state +func (tc *GNetTLSConn) GetConnectionState() tls.ConnectionState { + return tc.tlsConn.ConnectionState() +} + +// GetError returns last error +func (tc *GNetTLSConn) GetError() error { + if err, ok := tc.lastErr.Load().(error); ok { + return err + } + return nil +} + +// channelConn implements net.Conn over channels +type channelConn struct { + incoming <-chan []byte + outgoing chan<- []byte + localAddr net.Addr + remoteAddr net.Addr + tc *GNetTLSConn + readBuf []byte +} + +func (c *channelConn) Read(b []byte) (int, error) { + // Use buffered read for efficiency + if len(c.readBuf) > 0 { + n := copy(b, c.readBuf) + c.readBuf = c.readBuf[n:] + return n, nil + } + + // Wait for new data + select { + case data, ok := <-c.incoming: + if !ok { + return 0, io.EOF + } + n := copy(b, data) + if n < len(data) { + c.readBuf = data[n:] // Buffer remainder + } + return n, nil + case <-time.After(30 * time.Second): + return 0, errors.New("read timeout") + } +} + +func (c *channelConn) Write(b []byte) (int, error) { + if c.tc.closed.Load() { + return 0, ErrConnectionClosed + } + + // Make a copy since TLS may hold reference + data := make([]byte, len(b)) + copy(data, b) + + select { + case c.outgoing <- data: + return len(b), nil + case <-time.After(5 * time.Second): + return 0, errors.New("write timeout") + } +} + +func (c *channelConn) Close() error { return nil } +func (c *channelConn) LocalAddr() net.Addr { return c.localAddr } +func (c *channelConn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *channelConn) SetDeadline(t time.Time) error { return nil } +func (c *channelConn) SetReadDeadline(t time.Time) error { return nil } +func (c *channelConn) SetWriteDeadline(t time.Time) error { return nil } \ No newline at end of file diff --git a/src/internal/tls/manager.go b/src/internal/tls/manager.go index b6067d0..6c6c418 100644 --- a/src/internal/tls/manager.go +++ b/src/internal/tls/manager.go @@ -20,8 +20,8 @@ type Manager struct { logger *log.Logger } -// New creates a TLS configuration from SSL config -func New(cfg *config.SSLConfig, logger *log.Logger) (*Manager, error) { +// NewManager creates a TLS configuration from SSL config +func NewManager(cfg *config.SSLConfig, logger *log.Logger) (*Manager, error) { if cfg == nil || !cfg.Enabled { return nil, nil }