v0.4.1 authentication impelemented, not tested and docs not updated

This commit is contained in:
2025-09-23 12:03:42 -04:00
parent 4248d399b3
commit 45b2093569
21 changed files with 1779 additions and 453 deletions

1
go.mod
View File

@ -10,6 +10,7 @@ require (
github.com/valyala/fasthttp v1.65.0 github.com/valyala/fasthttp v1.65.0
golang.org/x/crypto v0.42.0 golang.org/x/crypto v0.42.0
golang.org/x/term v0.35.0 golang.org/x/term v0.35.0
golang.org/x/time v0.13.0
) )
require ( require (

2
go.sum
View File

@ -42,6 +42,8 @@ golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=

View File

@ -3,8 +3,10 @@ package auth
import ( import (
"bufio" "bufio"
"crypto/rand"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net"
"os" "os"
"strings" "strings"
"sync" "sync"
@ -15,8 +17,12 @@ import (
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/time/rate"
) )
// Prevent unbounded map growth
const maxAuthTrackedIPs = 10000
// Authenticator handles all authentication methods for a pipeline // Authenticator handles all authentication methods for a pipeline
type Authenticator struct { type Authenticator struct {
config *config.AuthConfig config *config.AuthConfig
@ -30,6 +36,18 @@ type Authenticator struct {
// Session tracking // Session tracking
sessions map[string]*Session sessions map[string]*Session
sessionMu sync.RWMutex sessionMu sync.RWMutex
// Brute-force protection
ipAuthAttempts map[string]*ipAuthState
authMu sync.RWMutex
}
// ADDED: Per-IP auth attempt tracking
type ipAuthState struct {
limiter *rate.Limiter
failCount int
lastAttempt time.Time
blockedUntil time.Time
} }
// Session represents an authenticated connection // Session represents an authenticated connection
@ -55,6 +73,7 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) {
basicUsers: make(map[string]string), basicUsers: make(map[string]string),
bearerTokens: make(map[string]bool), bearerTokens: make(map[string]bool),
sessions: make(map[string]*Session), sessions: make(map[string]*Session),
ipAuthAttempts: make(map[string]*ipAuthState),
} }
// Initialize Basic Auth users // Initialize Basic Auth users
@ -82,6 +101,7 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) {
a.jwtParser = jwt.NewParser( a.jwtParser = jwt.NewParser(
jwt.WithValidMethods([]string{"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}), jwt.WithValidMethods([]string{"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}),
jwt.WithLeeway(5*time.Second), jwt.WithLeeway(5*time.Second),
jwt.WithExpirationRequired(),
) )
// Setup key function // Setup key function
@ -102,6 +122,9 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) {
// Start session cleanup // Start session cleanup
go a.sessionCleanup() go a.sessionCleanup()
// Start auth attempt cleanup
go a.authAttemptCleanup()
logger.Info("msg", "Authenticator initialized", logger.Info("msg", "Authenticator initialized",
"component", "auth", "component", "auth",
"type", cfg.Type) "type", cfg.Type)
@ -109,6 +132,129 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) {
return a, nil return a, nil
} }
// Check and enforce rate limits
func (a *Authenticator) checkRateLimit(remoteAddr string) error {
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
ip = remoteAddr // Fallback for malformed addresses
}
a.authMu.Lock()
defer a.authMu.Unlock()
state, exists := a.ipAuthAttempts[ip]
now := time.Now()
if !exists {
// Check map size limit before creating new entry
if len(a.ipAuthAttempts) >= maxAuthTrackedIPs {
// Evict an old entry using simplified LRU
// Sample 20 random entries and evict the oldest
const sampleSize = 20
var oldestIP string
oldestTime := now
// Build sample
sampled := 0
for sampledIP, sampledState := range a.ipAuthAttempts {
if sampledState.lastAttempt.Before(oldestTime) {
oldestIP = sampledIP
oldestTime = sampledState.lastAttempt
}
sampled++
if sampled >= sampleSize {
break
}
}
// Evict the oldest from our sample
if oldestIP != "" {
delete(a.ipAuthAttempts, oldestIP)
a.logger.Debug("msg", "Evicted old auth attempt state",
"component", "auth",
"evicted_ip", oldestIP,
"last_seen", oldestTime)
}
}
// Create new state for this IP
// 5 attempts per minute, burst of 3
state = &ipAuthState{
limiter: rate.NewLimiter(rate.Every(12*time.Second), 3),
lastAttempt: now,
}
a.ipAuthAttempts[ip] = state
}
// Check if IP is temporarily blocked
if now.Before(state.blockedUntil) {
remaining := state.blockedUntil.Sub(now)
a.logger.Warn("msg", "IP temporarily blocked",
"component", "auth",
"ip", ip,
"remaining", remaining)
// Sleep to slow down even blocked attempts
time.Sleep(2 * time.Second)
return fmt.Errorf("temporarily blocked, try again in %v", remaining.Round(time.Second))
}
// Check rate limit
if !state.limiter.Allow() {
state.failCount++
// Only set new blockedUntil if not already blocked
// This prevents indefinite block extension
if state.blockedUntil.IsZero() || now.After(state.blockedUntil) {
// Progressive blocking: 2^failCount minutes
blockMinutes := 1 << min(state.failCount, 6) // Cap at 64 minutes
state.blockedUntil = now.Add(time.Duration(blockMinutes) * time.Minute)
a.logger.Warn("msg", "Rate limit exceeded, blocking IP",
"component", "auth",
"ip", ip,
"fail_count", state.failCount,
"block_duration", time.Duration(blockMinutes)*time.Minute)
}
return fmt.Errorf("rate limit exceeded")
}
state.lastAttempt = now
return nil
}
// Record failed attempt
func (a *Authenticator) recordFailure(remoteAddr string) {
ip, _, _ := net.SplitHostPort(remoteAddr)
if ip == "" {
ip = remoteAddr
}
a.authMu.Lock()
defer a.authMu.Unlock()
if state, exists := a.ipAuthAttempts[ip]; exists {
state.failCount++
state.lastAttempt = time.Now()
}
}
// Reset failure count on success
func (a *Authenticator) recordSuccess(remoteAddr string) {
ip, _, _ := net.SplitHostPort(remoteAddr)
if ip == "" {
ip = remoteAddr
}
a.authMu.Lock()
defer a.authMu.Unlock()
if state, exists := a.ipAuthAttempts[ip]; exists {
state.failCount = 0
state.blockedUntil = time.Time{}
}
}
// AuthenticateHTTP handles HTTP authentication headers // AuthenticateHTTP handles HTTP authentication headers
func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) { func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) {
if a == nil || a.config.Type == "none" { if a == nil || a.config.Type == "none" {
@ -120,14 +266,31 @@ func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Sessio
}, nil }, nil
} }
// Check rate limit
if err := a.checkRateLimit(remoteAddr); err != nil {
return nil, err
}
var session *Session
var err error
switch a.config.Type { switch a.config.Type {
case "basic": case "basic":
return a.authenticateBasic(authHeader, remoteAddr) session, err = a.authenticateBasic(authHeader, remoteAddr)
case "bearer": case "bearer":
return a.authenticateBearer(authHeader, remoteAddr) session, err = a.authenticateBearer(authHeader, remoteAddr)
default: default:
return nil, fmt.Errorf("unsupported auth type: %s", a.config.Type) err = fmt.Errorf("unsupported auth type: %s", a.config.Type)
} }
if err != nil {
a.recordFailure(remoteAddr)
time.Sleep(500 * time.Millisecond)
return nil, err
}
a.recordSuccess(remoteAddr)
return session, nil
} }
// AuthenticateTCP handles TCP connection authentication // AuthenticateTCP handles TCP connection authentication
@ -141,32 +304,54 @@ func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string)
}, nil }, nil
} }
// Check rate limit first
if err := a.checkRateLimit(remoteAddr); err != nil {
return nil, err
}
var session *Session
var err error
// TCP auth protocol: AUTH <method> <credentials> // TCP auth protocol: AUTH <method> <credentials>
switch strings.ToLower(method) { switch strings.ToLower(method) {
case "token": case "token":
if a.config.Type != "bearer" { if a.config.Type != "bearer" {
return nil, fmt.Errorf("token auth not configured") err = fmt.Errorf("token auth not configured")
} else {
session, err = a.validateToken(credentials, remoteAddr)
} }
return a.validateToken(credentials, remoteAddr)
case "basic": case "basic":
if a.config.Type != "basic" { if a.config.Type != "basic" {
return nil, fmt.Errorf("basic auth not configured") err = fmt.Errorf("basic auth not configured")
} } else {
// Expect base64(username:password) // Expect base64(username:password)
decoded, err := base64.StdEncoding.DecodeString(credentials) decoded, decErr := base64.StdEncoding.DecodeString(credentials)
if err != nil { if decErr != nil {
return nil, fmt.Errorf("invalid credentials encoding") err = fmt.Errorf("invalid credentials encoding")
} } else {
parts := strings.SplitN(string(decoded), ":", 2) parts := strings.SplitN(string(decoded), ":", 2)
if len(parts) != 2 { if len(parts) != 2 {
return nil, fmt.Errorf("invalid credentials format") err = fmt.Errorf("invalid credentials format")
} else {
session, err = a.validateBasicAuth(parts[0], parts[1], remoteAddr)
}
}
} }
return a.validateBasicAuth(parts[0], parts[1], remoteAddr)
default: default:
return nil, fmt.Errorf("unsupported auth method: %s", method) err = fmt.Errorf("unsupported auth method: %s", method)
} }
if err != nil {
a.recordFailure(remoteAddr)
// Add delay on failure
time.Sleep(500 * time.Millisecond)
return nil, err
}
a.recordSuccess(remoteAddr)
return session, nil
} }
func (a *Authenticator) authenticateBasic(authHeader, remoteAddr string) (*Session, error) { func (a *Authenticator) authenticateBasic(authHeader, remoteAddr string) (*Session, error) {
@ -255,6 +440,23 @@ func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error
return nil, fmt.Errorf("invalid JWT token") return nil, fmt.Errorf("invalid JWT token")
} }
// Explicit expiration check
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
return nil, fmt.Errorf("token expired")
}
} else {
// Reject tokens without expiration
return nil, fmt.Errorf("token missing expiration claim")
}
// Check not-before claim
if nbf, ok := claims["nbf"].(float64); ok {
if time.Now().Unix() < int64(nbf) {
return nil, fmt.Errorf("token not yet valid")
}
}
// Check issuer if configured // Check issuer if configured
if a.config.BearerAuth.JWT.Issuer != "" { if a.config.BearerAuth.JWT.Issuer != "" {
if iss, ok := claims["iss"].(string); !ok || iss != a.config.BearerAuth.JWT.Issuer { if iss, ok := claims["iss"].(string); !ok || iss != a.config.BearerAuth.JWT.Issuer {
@ -264,7 +466,20 @@ func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error
// Check audience if configured // Check audience if configured
if a.config.BearerAuth.JWT.Audience != "" { if a.config.BearerAuth.JWT.Audience != "" {
if aud, ok := claims["aud"].(string); !ok || aud != a.config.BearerAuth.JWT.Audience { // Handle both string and []string audience formats
audValid := false
switch aud := claims["aud"].(type) {
case string:
audValid = aud == a.config.BearerAuth.JWT.Audience
case []interface{}:
for _, aa := range aud {
if audStr, ok := aa.(string); ok && audStr == a.config.BearerAuth.JWT.Audience {
audValid = true
break
}
}
}
if !audValid {
return nil, fmt.Errorf("invalid token audience") return nil, fmt.Errorf("invalid token audience")
} }
} }
@ -322,6 +537,27 @@ func (a *Authenticator) sessionCleanup() {
} }
} }
// Cleanup old auth attempts
func (a *Authenticator) authAttemptCleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
a.authMu.Lock()
now := time.Now()
for ip, state := range a.ipAuthAttempts {
// Remove entries older than 1 hour with no recent activity
if now.Sub(state.lastAttempt) > time.Hour {
delete(a.ipAuthAttempts, ip)
a.logger.Debug("msg", "Cleaned up auth attempt state",
"component", "auth",
"ip", ip)
}
}
a.authMu.Unlock()
}
}
func (a *Authenticator) loadUsersFile(path string) error { func (a *Authenticator) loadUsersFile(path string) error {
file, err := os.Open(path) file, err := os.Open(path)
if err != nil { if err != nil {
@ -366,7 +602,12 @@ func (a *Authenticator) loadUsersFile(path string) error {
} }
func generateSessionID() string { func generateSessionID() string {
return fmt.Sprintf("%d-%d", time.Now().UnixNano(), time.Now().Unix()) b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
// Fallback to a less secure method if crypto/rand fails
return fmt.Sprintf("fallback-%d", time.Now().UnixNano())
}
return base64.URLEncoding.EncodeToString(b)
} }
// ValidateSession checks if a session is still valid // ValidateSession checks if a session is still valid

View File

@ -3,8 +3,6 @@ package config
import ( import (
"fmt" "fmt"
"net"
"strings"
) )
type AuthConfig struct { type AuthConfig struct {

View File

@ -3,7 +3,6 @@ package config
import ( import (
"fmt" "fmt"
"net"
"strings" "strings"
) )
@ -29,37 +28,6 @@ type RateLimitConfig struct {
MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"` MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"`
} }
func validateNetAccess(pipelineName string, cfg *NetAccessConfig) error {
if cfg == nil {
return nil
}
// Validate CIDR notation
for _, cidr := range cfg.IPWhitelist {
if !strings.Contains(cidr, "/") {
cidr = cidr + "/32"
}
if _, _, err := net.ParseCIDR(cidr); err != nil {
if net.ParseIP(cidr) == nil {
return fmt.Errorf("pipeline '%s': invalid IP whitelist entry: %s", pipelineName, cidr)
}
}
}
for _, cidr := range cfg.IPBlacklist {
if !strings.Contains(cidr, "/") {
cidr = cidr + "/32"
}
if _, _, err := net.ParseCIDR(cidr); err != nil {
if net.ParseIP(cidr) == nil {
return fmt.Errorf("pipeline '%s': invalid IP blacklist entry: %s", pipelineName, cidr)
}
}
}
return nil
}
func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error { func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error {
if cfg == nil { if cfg == nil {
return nil return nil

View File

@ -20,9 +20,6 @@ type PipelineConfig struct {
// Rate limiting // Rate limiting
RateLimit *RateLimitConfig `toml:"rate_limit"` RateLimit *RateLimitConfig `toml:"rate_limit"`
// Network access control (IP filtering)
NetAccess *NetAccessConfig `toml:"net_access"`
// Filter configuration // Filter configuration
Filters []FilterConfig `toml:"filters"` Filters []FilterConfig `toml:"filters"`
@ -37,12 +34,6 @@ type PipelineConfig struct {
Auth *AuthConfig `toml:"auth"` Auth *AuthConfig `toml:"auth"`
} }
// NetAccessConfig defines IP-based access control lists
type NetAccessConfig struct {
IPWhitelist []string `toml:"ip_whitelist"`
IPBlacklist []string `toml:"ip_blacklist"`
}
// SourceConfig represents an input data source // SourceConfig represents an input data source
type SourceConfig struct { type SourceConfig struct {
// Source type: "directory", "file", "stdin", etc. // Source type: "directory", "file", "stdin", etc.
@ -132,6 +123,13 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err
} }
} }
// CHANGED: Validate SSL if present
if ssl, ok := cfg.Options["ssl"].(map[string]any); ok {
if err := validateSSLOptions("HTTP source", pipelineName, sourceIndex, ssl); err != nil {
return err
}
}
case "tcp": case "tcp":
// Validate TCP source options // Validate TCP source options
port, ok := cfg.Options["port"].(int64) port, ok := cfg.Options["port"].(int64)
@ -147,6 +145,13 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err
} }
} }
// CHANGED: Validate SSL if present
if ssl, ok := cfg.Options["ssl"].(map[string]any); ok {
if err := validateSSLOptions("TCP source", pipelineName, sourceIndex, ssl); err != nil {
return err
}
}
default: default:
return fmt.Errorf("pipeline '%s' source[%d]: unknown source type '%s'", return fmt.Errorf("pipeline '%s' source[%d]: unknown source type '%s'",
pipelineName, sourceIndex, cfg.Type) pipelineName, sourceIndex, cfg.Type)

View File

@ -1,7 +1,11 @@
// FILE: logwisp/src/internal/config/server.go // FILE: logwisp/src/internal/config/server.go
package config package config
import "fmt" import (
"fmt"
"net"
"strings"
)
type TCPConfig struct { type TCPConfig struct {
Enabled bool `toml:"enabled"` Enabled bool `toml:"enabled"`
@ -49,6 +53,10 @@ type NetLimitConfig struct {
// Enable net limiting // Enable net limiting
Enabled bool `toml:"enabled"` Enabled bool `toml:"enabled"`
// IP Access Control Lists
IPWhitelist []string `toml:"ip_whitelist"`
IPBlacklist []string `toml:"ip_blacklist"`
// Requests per second per client // Requests per second per client
RequestsPerSecond float64 `toml:"requests_per_second"` RequestsPerSecond float64 `toml:"requests_per_second"`
@ -90,6 +98,33 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
return nil return nil
} }
// Validate IP lists if present
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
for i, entry := range ipWhitelist {
entryStr, ok := entry.(string)
if !ok {
continue
}
if err := validateIPv4Entry(entryStr); err != nil {
return fmt.Errorf("pipeline '%s' sink[%d] %s: whitelist[%d] %v",
pipelineName, sinkIndex, serverType, i, err)
}
}
}
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
for i, entry := range ipBlacklist {
entryStr, ok := entry.(string)
if !ok {
continue
}
if err := validateIPv4Entry(entryStr); err != nil {
return fmt.Errorf("pipeline '%s' sink[%d] %s: blacklist[%d] %v",
pipelineName, sinkIndex, serverType, i, err)
}
}
}
// Validate requests per second // Validate requests per second
rps, ok := rl["requests_per_second"].(float64) rps, ok := rl["requests_per_second"].(float64)
if !ok || rps <= 0 { if !ok || rps <= 0 {
@ -134,3 +169,37 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
return nil return nil
} }
// validateIPv4Entry ensures an IP or CIDR is IPv4
func validateIPv4Entry(entry string) error {
// Handle single IP
if !strings.Contains(entry, "/") {
ip := net.ParseIP(entry)
if ip == nil {
return fmt.Errorf("invalid IP address: %s", entry)
}
if ip.To4() == nil {
return fmt.Errorf("IPv6 not supported (IPv4-only): %s", entry)
}
return nil
}
// Handle CIDR
ipAddr, ipNet, err := net.ParseCIDR(entry)
if err != nil {
return fmt.Errorf("invalid CIDR: %s", entry)
}
// Check if the IP is IPv4
if ipAddr.To4() == nil {
return fmt.Errorf("IPv6 CIDR not supported (IPv4-only): %s", entry)
}
// Verify the network mask is appropriate for IPv4
_, bits := ipNet.Mask.Size()
if bits != 32 {
return fmt.Errorf("invalid IPv4 CIDR mask (got %d bits, expected 32): %s", bits, entry)
}
return nil
}

View File

@ -1,7 +1,10 @@
// FILE: logwisp/src/internal/config/ssl.go // FILE: logwisp/src/internal/config/ssl.go
package config package config
import "fmt" import (
"fmt"
"os"
)
type SSLConfig struct { type SSLConfig struct {
Enabled bool `toml:"enabled"` Enabled bool `toml:"enabled"`
@ -13,6 +16,9 @@ type SSLConfig struct {
ClientCAFile string `toml:"client_ca_file"` ClientCAFile string `toml:"client_ca_file"`
VerifyClientCert bool `toml:"verify_client_cert"` VerifyClientCert bool `toml:"verify_client_cert"`
// Option to skip verification for clients
InsecureSkipVerify bool `toml:"insecure_skip_verify"`
// TLS version constraints // TLS version constraints
MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3" MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3"
MaxVersion string `toml:"max_version"` MaxVersion string `toml:"max_version"`
@ -31,11 +37,27 @@ func validateSSLOptions(serverType, pipelineName string, sinkIndex int, ssl map[
pipelineName, sinkIndex, serverType) pipelineName, sinkIndex, serverType)
} }
// Validate that certificate files exist and are readable
if _, err := os.Stat(certFile); err != nil {
return fmt.Errorf("pipeline '%s' sink[%d] %s: cert_file is not accessible: %w",
pipelineName, sinkIndex, serverType, err)
}
if _, err := os.Stat(keyFile); err != nil {
return fmt.Errorf("pipeline '%s' sink[%d] %s: key_file is not accessible: %w",
pipelineName, sinkIndex, serverType, err)
}
if clientAuth, ok := ssl["client_auth"].(bool); ok && clientAuth { if clientAuth, ok := ssl["client_auth"].(bool); ok && clientAuth {
if caFile, ok := ssl["client_ca_file"].(string); !ok || caFile == "" { caFile, caOk := ssl["client_ca_file"].(string)
if !caOk || caFile == "" {
return fmt.Errorf("pipeline '%s' sink[%d] %s: client auth enabled but CA file not specified", return fmt.Errorf("pipeline '%s' sink[%d] %s: client auth enabled but CA file not specified",
pipelineName, sinkIndex, serverType) pipelineName, sinkIndex, serverType)
} }
// Validate that the client CA file exists and is readable
if _, err := os.Stat(caFile); err != nil {
return fmt.Errorf("pipeline '%s' sink[%d] %s: client_ca_file is not accessible: %w",
pipelineName, sinkIndex, serverType, err)
}
} }
// Validate TLS versions // Validate TLS versions

View File

@ -72,11 +72,6 @@ func (c *Config) validate() error {
} }
} }
// Validate net access if present
if err := validateNetAccess(pipeline.Name, pipeline.NetAccess); err != nil {
return err
}
// Validate auth if present // Validate auth if present
if err := validateAuth(pipeline.Name, pipeline.Auth); err != nil { if err := validateAuth(pipeline.Name, pipeline.Auth); err != nil {
return err return err

View File

@ -1,173 +0,0 @@
// FILE: src/internal/limit/ip.go
package limit
import (
"net"
"strings"
"logwisp/src/internal/config"
"github.com/lixenwraith/log"
)
// IPChecker handles IP-based access control lists
type IPChecker struct {
ipWhitelist []*net.IPNet
ipBlacklist []*net.IPNet
logger *log.Logger
}
// NewIPChecker creates a new IPChecker. Returns nil if no rules are defined.
func NewIPChecker(cfg *config.NetAccessConfig, logger *log.Logger) *IPChecker {
if cfg == nil || (len(cfg.IPWhitelist) == 0 && len(cfg.IPBlacklist) == 0) {
return nil
}
c := &IPChecker{
ipWhitelist: make([]*net.IPNet, 0),
ipBlacklist: make([]*net.IPNet, 0),
logger: logger,
}
// Parse whitelist entries
for _, cidr := range cfg.IPWhitelist {
if !strings.Contains(cidr, "/") {
cidr = cidr + "/32"
}
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
// Try parsing as plain IP
if ip := net.ParseIP(cidr); ip != nil {
if ip.To4() != nil {
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)}
} else {
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
}
} else {
logger.Warn("msg", "Skipping invalid IP whitelist entry",
"component", "ip_checker",
"entry", cidr,
"error", err)
continue
}
}
c.ipWhitelist = append(c.ipWhitelist, ipNet)
}
// Parse blacklist entries
for _, cidr := range cfg.IPBlacklist {
if !strings.Contains(cidr, "/") {
cidr = cidr + "/32"
}
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
// Try parsing as plain IP
if ip := net.ParseIP(cidr); ip != nil {
if ip.To4() != nil {
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)}
} else {
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
}
} else {
logger.Warn("msg", "Skipping invalid IP blacklist entry",
"component", "ip_checker",
"entry", cidr,
"error", err)
continue
}
}
c.ipBlacklist = append(c.ipBlacklist, ipNet)
}
logger.Info("msg", "IP checker initialized",
"component", "ip_checker",
"whitelist_rules", len(c.ipWhitelist),
"blacklist_rules", len(c.ipBlacklist))
return c
}
// IsAllowed validates if a remote address is permitted
func (c *IPChecker) IsAllowed(remoteAddr net.Addr) bool {
if c == nil {
return true // No checker = allow all
}
// No rules = allow all
if len(c.ipWhitelist) == 0 && len(c.ipBlacklist) == 0 {
return true
}
// Extract IP from address
var ipStr string
switch addr := remoteAddr.(type) {
case *net.TCPAddr:
ipStr = addr.IP.String()
case *net.UDPAddr:
ipStr = addr.IP.String()
default:
// Try string parsing
addrStr := remoteAddr.String()
host, _, err := net.SplitHostPort(addrStr)
if err != nil {
ipStr = addrStr
} else {
ipStr = host
}
}
ip := net.ParseIP(ipStr)
if ip == nil {
c.logger.Warn("msg", "Could not parse remote address to IP",
"component", "ip_checker",
"remote_addr", remoteAddr.String())
return false // Deny unparseable addresses
}
// Check blacklist first (deny takes precedence)
for _, ipNet := range c.ipBlacklist {
if ipNet.Contains(ip) {
c.logger.Warn("msg", "Blacklisted IP denied",
"component", "ip_checker",
"ip", ipStr,
"rule", ipNet.String())
return false
}
}
// If whitelist is configured, IP must be in it
if len(c.ipWhitelist) > 0 {
for _, ipNet := range c.ipWhitelist {
if ipNet.Contains(ip) {
c.logger.Debug("msg", "IP allowed by whitelist",
"component", "ip_checker",
"ip", ipStr,
"rule", ipNet.String())
return true
}
}
// No whitelist match = deny
c.logger.Warn("msg", "IP not in whitelist",
"component", "ip_checker",
"ip", ipStr)
return false
}
// No blacklist match + no whitelist configured = allow
return true
}
// GetStats returns IP checker statistics
func (c *IPChecker) GetStats() map[string]any {
if c == nil {
return map[string]any{"enabled": false}
}
return map[string]any{
"enabled": true,
"whitelist_rules": len(c.ipWhitelist),
"blacklist_rules": len(c.ipBlacklist),
}
}

View File

@ -14,11 +14,32 @@ import (
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
) )
// DenialReason indicates why a request was denied
type DenialReason string
const (
// IPv4Only is the enforcement message for IPv6 rejection
IPv4Only = "IPv4-only (IPv6 not supported)"
)
const (
ReasonAllowed DenialReason = ""
ReasonBlacklisted DenialReason = "IP denied by blacklist"
ReasonNotWhitelisted DenialReason = "IP not in whitelist"
ReasonRateLimited DenialReason = "Rate limit exceeded"
ReasonConnectionLimited DenialReason = "Connection limit exceeded"
ReasonInvalidIP DenialReason = "Invalid IP address"
)
// NetLimiter manages net limiting for a transport // NetLimiter manages net limiting for a transport
type NetLimiter struct { type NetLimiter struct {
config config.NetLimitConfig config config.NetLimitConfig
logger *log.Logger logger *log.Logger
// IP Access Control Lists
ipWhitelist []*net.IPNet
ipBlacklist []*net.IPNet
// Per-IP limiters // Per-IP limiters
ipLimiters map[string]*ipLimiter ipLimiters map[string]*ipLimiter
ipMu sync.RWMutex ipMu sync.RWMutex
@ -27,17 +48,22 @@ type NetLimiter struct {
globalLimiter *TokenBucket globalLimiter *TokenBucket
// Connection tracking // Connection tracking
ipConnections map[string]*atomic.Int64 ipConnections map[string]*connTracker
connMu sync.RWMutex connMu sync.RWMutex
// Statistics // Statistics
totalRequests atomic.Uint64 totalRequests atomic.Uint64
blockedRequests atomic.Uint64 blockedByBlacklist atomic.Uint64
blockedByWhitelist atomic.Uint64
blockedByRateLimit atomic.Uint64
blockedByConnLimit atomic.Uint64
blockedByInvalidIP atomic.Uint64
uniqueIPs atomic.Uint64 uniqueIPs atomic.Uint64
// Cleanup // Cleanup
lastCleanup time.Time lastCleanup time.Time
cleanupMu sync.Mutex cleanupMu sync.Mutex
cleanupActive atomic.Bool
// Lifecycle management // Lifecycle management
ctx context.Context ctx context.Context
@ -51,9 +77,20 @@ type ipLimiter struct {
connections atomic.Int64 connections atomic.Int64
} }
// Connection tracking with activity timestamp
type connTracker struct {
connections atomic.Int64
lastSeen time.Time
mu sync.Mutex
}
// Creates a new net limiter // Creates a new net limiter
func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter { func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
if !cfg.Enabled { // Return nil only if nothing is configured
hasACL := len(cfg.IPWhitelist) > 0 || len(cfg.IPBlacklist) > 0
hasRateLimit := cfg.Enabled
if !hasACL && !hasRateLimit {
return nil return nil
} }
@ -65,28 +102,39 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
l := &NetLimiter{ l := &NetLimiter{
config: cfg, config: cfg,
ipLimiters: make(map[string]*ipLimiter),
ipConnections: make(map[string]*atomic.Int64),
lastCleanup: time.Now(),
logger: logger, logger: logger,
ipWhitelist: make([]*net.IPNet, 0),
ipBlacklist: make([]*net.IPNet, 0),
ipLimiters: make(map[string]*ipLimiter),
ipConnections: make(map[string]*connTracker),
lastCleanup: time.Now(),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
cleanupDone: make(chan struct{}), cleanupDone: make(chan struct{}),
} }
// Create global limiter if not using per-IP limiting // Parse IP lists
if cfg.LimitBy == "global" { l.parseIPLists(cfg)
// Create global limiter if configured
if cfg.Enabled && cfg.LimitBy == "global" {
l.globalLimiter = NewTokenBucket( l.globalLimiter = NewTokenBucket(
float64(cfg.BurstSize), float64(cfg.BurstSize),
cfg.RequestsPerSecond, cfg.RequestsPerSecond,
) )
} }
// Start cleanup goroutine // Start cleanup goroutine only if rate limiting is enabled
if cfg.Enabled {
go l.cleanupLoop() go l.cleanupLoop()
}
l.logger.Info("msg", "Net limiter initialized", logger.Info("msg", "Net limiter initialized",
"component", "netlimit", "component", "netlimit",
"acl_enabled", hasACL,
"rate_limiting", cfg.Enabled,
"whitelist_rules", len(l.ipWhitelist),
"blacklist_rules", len(l.ipBlacklist),
"requests_per_second", cfg.RequestsPerSecond, "requests_per_second", cfg.RequestsPerSecond,
"burst_size", cfg.BurstSize, "burst_size", cfg.BurstSize,
"limit_by", cfg.LimitBy) "limit_by", cfg.LimitBy)
@ -94,6 +142,120 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
return l return l
} }
// parseIPLists parses and validates IP whitelist/blacklist
func (l *NetLimiter) parseIPLists(cfg config.NetLimitConfig) {
// Parse whitelist
for _, entry := range cfg.IPWhitelist {
if ipNet := l.parseIPEntry(entry, "whitelist"); ipNet != nil {
l.ipWhitelist = append(l.ipWhitelist, ipNet)
}
}
// Parse blacklist
for _, entry := range cfg.IPBlacklist {
if ipNet := l.parseIPEntry(entry, "blacklist"); ipNet != nil {
l.ipBlacklist = append(l.ipBlacklist, ipNet)
}
}
}
// parseIPEntry parses a single IP or CIDR entry
func (l *NetLimiter) parseIPEntry(entry, listType string) *net.IPNet {
// Handle single IP
if !strings.Contains(entry, "/") {
ip := net.ParseIP(entry)
if ip == nil {
l.logger.Warn("msg", "Invalid IP entry",
"component", "netlimit",
"list", listType,
"entry", entry)
return nil
}
// Reject IPv6
if ip.To4() == nil {
l.logger.Warn("msg", "IPv6 address rejected",
"component", "netlimit",
"list", listType,
"entry", entry,
"reason", IPv4Only)
return nil
}
return &net.IPNet{IP: ip.To4(), Mask: net.CIDRMask(32, 32)}
}
// Parse CIDR
ipAddr, ipNet, err := net.ParseCIDR(entry)
if err != nil {
l.logger.Warn("msg", "Invalid CIDR entry",
"component", "netlimit",
"list", listType,
"entry", entry,
"error", err)
return nil
}
// Reject IPv6 CIDR
if ipAddr.To4() == nil {
l.logger.Warn("msg", "IPv6 CIDR rejected",
"component", "netlimit",
"list", listType,
"entry", entry,
"reason", IPv4Only)
return nil
}
// Ensure mask is IPv4
_, bits := ipNet.Mask.Size()
if bits != 32 {
l.logger.Warn("msg", "Non-IPv4 CIDR mask rejected",
"component", "netlimit",
"list", listType,
"entry", entry,
"mask_bits", bits,
"reason", IPv4Only)
return nil
}
return &net.IPNet{IP: ipAddr.To4(), Mask: ipNet.Mask}
}
// checkIPAccess checks if an IP is allowed by ACLs
func (l *NetLimiter) checkIPAccess(ip net.IP) DenialReason {
// 1. Check blacklist first (deny takes precedence)
for _, ipNet := range l.ipBlacklist {
if ipNet.Contains(ip) {
l.blockedByBlacklist.Add(1)
l.logger.Debug("msg", "IP denied by blacklist",
"component", "netlimit",
"ip", ip.String(),
"rule", ipNet.String())
return ReasonBlacklisted
}
}
// 2. If whitelist is configured, IP must be in it
if len(l.ipWhitelist) > 0 {
for _, ipNet := range l.ipWhitelist {
if ipNet.Contains(ip) {
l.logger.Debug("msg", "IP allowed by whitelist",
"component", "netlimit",
"ip", ip.String(),
"rule", ipNet.String())
return ReasonAllowed
}
}
l.blockedByWhitelist.Add(1)
l.logger.Debug("msg", "IP not in whitelist",
"component", "netlimit",
"ip", ip.String())
return ReasonNotWhitelisted
}
return ReasonAllowed
}
func (l *NetLimiter) Shutdown() { func (l *NetLimiter) Shutdown() {
if l == nil { if l == nil {
return return
@ -121,9 +283,9 @@ func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int6
l.totalRequests.Add(1) l.totalRequests.Add(1)
ip, _, err := net.SplitHostPort(remoteAddr) // Parse IP address
ipStr, _, err := net.SplitHostPort(remoteAddr)
if err != nil { if err != nil {
// If we can't parse the IP, allow the request but log
l.logger.Warn("msg", "Failed to parse remote addr", l.logger.Warn("msg", "Failed to parse remote addr",
"component", "netlimit", "component", "netlimit",
"remote_addr", remoteAddr, "remote_addr", remoteAddr,
@ -131,56 +293,82 @@ func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int6
return true, 0, "" return true, 0, ""
} }
// Only supporting ipv4 ip := net.ParseIP(ipStr)
if !isIPv4(ip) { if ip == nil {
// Block non-IPv4 addresses to prevent complications l.blockedByInvalidIP.Add(1)
l.blockedRequests.Add(1) l.logger.Warn("msg", "Failed to parse IP",
l.logger.Warn("msg", "Non-IPv4 address blocked",
"component", "netlimit", "component", "netlimit",
"ip", ip) "ip", ipStr)
return false, 403, "IPv4 only" return false, 403, string(ReasonInvalidIP)
} }
// Check connection limit for streaming endpoint // Reject IPv6 connections
if !isIPv4(ip) {
l.blockedByInvalidIP.Add(1)
l.logger.Warn("msg", "IPv6 connection rejected",
"component", "netlimit",
"ip", ipStr,
"reason", IPv4Only)
return false, 403, IPv4Only
}
// Normalize to IPv4 representation
ip = ip.To4()
// Check IP access control
if reason := l.checkIPAccess(ip); reason != ReasonAllowed {
return false, 403, string(reason)
}
// If rate limiting is not enabled, allow
if !l.config.Enabled {
return true, 0, ""
}
// Check connection limits
if l.config.MaxConnectionsPerIP > 0 { if l.config.MaxConnectionsPerIP > 0 {
l.connMu.RLock() l.connMu.RLock()
counter, exists := l.ipConnections[ip] tracker, exists := l.ipConnections[ipStr]
l.connMu.RUnlock() l.connMu.RUnlock()
if exists && counter.Load() >= l.config.MaxConnectionsPerIP { if exists && tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
l.blockedRequests.Add(1) l.blockedByConnLimit.Add(1)
statusCode = l.config.ResponseCode statusCode = l.config.ResponseCode
if statusCode == 0 { if statusCode == 0 {
statusCode = 429 statusCode = 429
} }
message = "Connection limit exceeded" return false, statusCode, string(ReasonConnectionLimited)
l.logger.Warn("msg", "Connection limit exceeded",
"component", "netlimit",
"ip", ip,
"connections", counter.Load(),
"limit", l.config.MaxConnectionsPerIP)
return false, statusCode, message
} }
} }
// Check net limit // Check rate limit
allowed = l.checkLimit(ip) if !l.checkLimit(ipStr) {
if !allowed { l.blockedByRateLimit.Add(1)
l.blockedRequests.Add(1)
statusCode = l.config.ResponseCode statusCode = l.config.ResponseCode
if statusCode == 0 { if statusCode == 0 {
statusCode = 429 statusCode = 429
} }
message = l.config.ResponseMessage message = l.config.ResponseMessage
if message == "" { if message == "" {
message = "Net limit exceeded" message = string(ReasonRateLimited)
} }
l.logger.Debug("msg", "Request net limited", "ip", ip) return false, statusCode, message
} }
return allowed, statusCode, message return true, 0, ""
}
// Update connection activity
func (l *NetLimiter) updateConnectionActivity(ip string) {
l.connMu.RLock()
tracker, exists := l.ipConnections[ip]
l.connMu.RUnlock()
if exists {
tracker.mu.Lock()
tracker.lastSeen = time.Now()
tracker.mu.Unlock()
}
} }
// Checks if a TCP connection should be allowed // Checks if a TCP connection should be allowed
@ -194,32 +382,45 @@ func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool {
// Extract IP from TCP addr // Extract IP from TCP addr
tcpAddr, ok := remoteAddr.(*net.TCPAddr) tcpAddr, ok := remoteAddr.(*net.TCPAddr)
if !ok { if !ok {
return true l.blockedByInvalidIP.Add(1)
}
ip := tcpAddr.IP.String()
// Only supporting ipv4
if !isIPv4(ip) {
l.blockedRequests.Add(1)
l.logger.Warn("msg", "Non-IPv4 TCP connection blocked",
"component", "netlimit",
"ip", ip)
return false return false
} }
allowed := l.checkLimit(ip) // Reject IPv6 connections
if !allowed { if !isIPv4(tcpAddr.IP) {
l.blockedRequests.Add(1) l.blockedByInvalidIP.Add(1)
l.logger.Debug("msg", "TCP connection net limited", "ip", ip) l.logger.Warn("msg", "IPv6 TCP connection rejected",
"component", "netlimit",
"ip", tcpAddr.IP.String(),
"reason", IPv4Only)
return false
} }
return allowed // Normalize to IPv4 representation
ip := tcpAddr.IP.To4()
// Check IP access control
if reason := l.checkIPAccess(ip); reason != ReasonAllowed {
return false
} }
func isIPv4(ip string) bool { // If rate limiting is not enabled, allow
// Simple check: IPv4 addresses contain dots, IPv6 contain colons if !l.config.Enabled {
return strings.Contains(ip, ".") && !strings.Contains(ip, ":") return true
}
// Check rate limit
ipStr := tcpAddr.IP.String()
if !l.checkLimit(ipStr) {
l.blockedByRateLimit.Add(1)
return false
}
return true
}
func isIPv4(ip net.IP) bool {
return ip.To4() != nil
} }
// Tracks a new connection for an IP // Tracks a new connection for an IP
@ -230,23 +431,44 @@ func (l *NetLimiter) AddConnection(remoteAddr string) {
ip, _, err := net.SplitHostPort(remoteAddr) ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil { if err != nil {
l.logger.Warn("msg", "Failed to parse remote address in AddConnection",
"component", "netlimit",
"remote_addr", remoteAddr,
"error", err)
return
}
// IP validation
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
l.logger.Warn("msg", "Failed to parse IP in AddConnection",
"component", "netlimit",
"ip", ip)
return return
} }
// Only supporting ipv4 // Only supporting ipv4
if !isIPv4(ip) { if !isIPv4(parsedIP) {
return return
} }
l.connMu.Lock() l.connMu.Lock()
counter, exists := l.ipConnections[ip] tracker, exists := l.ipConnections[ip]
if !exists { if !exists {
counter = &atomic.Int64{} // Create new tracker with timestamp
l.ipConnections[ip] = counter tracker = &connTracker{
lastSeen: time.Now(),
}
l.ipConnections[ip] = tracker
} }
l.connMu.Unlock() l.connMu.Unlock()
newCount := counter.Add(1) newCount := tracker.connections.Add(1)
// Update activity timestamp
tracker.mu.Lock()
tracker.lastSeen = time.Now()
tracker.mu.Unlock()
l.logger.Debug("msg", "Connection added", l.logger.Debug("msg", "Connection added",
"ip", ip, "ip", ip,
"connections", newCount) "connections", newCount)
@ -260,20 +482,33 @@ func (l *NetLimiter) RemoveConnection(remoteAddr string) {
ip, _, err := net.SplitHostPort(remoteAddr) ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil { if err != nil {
l.logger.Warn("msg", "Failed to parse remote address in RemoveConnection",
"component", "netlimit",
"remote_addr", remoteAddr,
"error", err)
return
}
// IP validation
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
l.logger.Warn("msg", "Failed to parse IP in RemoveConnection",
"component", "netlimit",
"ip", ip)
return return
} }
// Only supporting ipv4 // Only supporting ipv4
if !isIPv4(ip) { if !isIPv4(parsedIP) {
return return
} }
l.connMu.RLock() l.connMu.RLock()
counter, exists := l.ipConnections[ip] tracker, exists := l.ipConnections[ip]
l.connMu.RUnlock() l.connMu.RUnlock()
if exists { if exists {
newCount := counter.Add(-1) newCount := tracker.connections.Add(-1)
l.logger.Debug("msg", "Connection removed", l.logger.Debug("msg", "Connection removed",
"ip", ip, "ip", ip,
"connections", newCount) "connections", newCount)
@ -281,7 +516,7 @@ func (l *NetLimiter) RemoveConnection(remoteAddr string) {
if newCount <= 0 { if newCount <= 0 {
// Clean up if no more connections // Clean up if no more connections
l.connMu.Lock() l.connMu.Lock()
if counter.Load() <= 0 { if tracker.connections.Load() <= 0 {
delete(l.ipConnections, ip) delete(l.ipConnections, ip)
} }
l.connMu.Unlock() l.connMu.Unlock()
@ -292,9 +527,7 @@ func (l *NetLimiter) RemoveConnection(remoteAddr string) {
// Returns net limiter statistics // Returns net limiter statistics
func (l *NetLimiter) GetStats() map[string]any { func (l *NetLimiter) GetStats() map[string]any {
if l == nil { if l == nil {
return map[string]any{ return map[string]any{"enabled": false}
"enabled": false,
}
} }
l.ipMu.RLock() l.ipMu.RLock()
@ -303,18 +536,36 @@ func (l *NetLimiter) GetStats() map[string]any {
l.connMu.RLock() l.connMu.RLock()
totalConnections := 0 totalConnections := 0
for _, counter := range l.ipConnections { for _, tracker := range l.ipConnections {
totalConnections += int(counter.Load()) totalConnections += int(tracker.connections.Load())
} }
l.connMu.RUnlock() l.connMu.RUnlock()
totalBlocked := l.blockedByBlacklist.Load() +
l.blockedByWhitelist.Load() +
l.blockedByRateLimit.Load() +
l.blockedByConnLimit.Load() +
l.blockedByInvalidIP.Load()
return map[string]any{ return map[string]any{
"enabled": true, "enabled": true,
"total_requests": l.totalRequests.Load(), "total_requests": l.totalRequests.Load(),
"blocked_requests": l.blockedRequests.Load(), "total_blocked": totalBlocked,
"blocked_breakdown": map[string]uint64{
"blacklist": l.blockedByBlacklist.Load(),
"whitelist": l.blockedByWhitelist.Load(),
"rate_limit": l.blockedByRateLimit.Load(),
"conn_limit": l.blockedByConnLimit.Load(),
"invalid_ip": l.blockedByInvalidIP.Load(),
},
"active_ips": activeIPs, "active_ips": activeIPs,
"total_connections": totalConnections, "total_connections": totalConnections,
"config": map[string]any{ "acl": map[string]int{
"whitelist_rules": len(l.ipWhitelist),
"blacklist_rules": len(l.ipBlacklist),
},
"rate_limit": map[string]any{
"enabled": l.config.Enabled,
"requests_per_second": l.config.RequestsPerSecond, "requests_per_second": l.config.RequestsPerSecond,
"burst_size": l.config.BurstSize, "burst_size": l.config.BurstSize,
"limit_by": l.config.LimitBy, "limit_by": l.config.LimitBy,
@ -324,6 +575,15 @@ func (l *NetLimiter) GetStats() map[string]any {
// Performs the actual net limit check // Performs the actual net limit check
func (l *NetLimiter) checkLimit(ip string) bool { func (l *NetLimiter) checkLimit(ip string) bool {
// Validate IP format
parsedIP := net.ParseIP(ip)
if parsedIP == nil || !isIPv4(parsedIP) {
l.logger.Warn("msg", "Invalid or non-IPv4 address in rate limiter",
"component", "netlimit",
"ip", ip)
return false
}
// Maybe run cleanup // Maybe run cleanup
l.maybeCleanup() l.maybeCleanup()
@ -358,10 +618,10 @@ func (l *NetLimiter) checkLimit(ip string) bool {
// Check connection limit if configured // Check connection limit if configured
if l.config.MaxConnectionsPerIP > 0 { if l.config.MaxConnectionsPerIP > 0 {
l.connMu.RLock() l.connMu.RLock()
counter, exists := l.ipConnections[ip] tracker, exists := l.ipConnections[ip]
l.connMu.RUnlock() l.connMu.RUnlock()
if exists && counter.Load() >= l.config.MaxConnectionsPerIP { if exists && tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
return false return false
} }
} }
@ -379,14 +639,27 @@ func (l *NetLimiter) checkLimit(ip string) bool {
// Runs cleanup if enough time has passed // Runs cleanup if enough time has passed
func (l *NetLimiter) maybeCleanup() { func (l *NetLimiter) maybeCleanup() {
l.cleanupMu.Lock() l.cleanupMu.Lock()
defer l.cleanupMu.Unlock()
// Check if enough time has passed
if time.Since(l.lastCleanup) < 30*time.Second { if time.Since(l.lastCleanup) < 30*time.Second {
l.cleanupMu.Unlock()
return
}
// Check if cleanup already running
if !l.cleanupActive.CompareAndSwap(false, true) {
l.cleanupMu.Unlock()
return return
} }
l.lastCleanup = time.Now() l.lastCleanup = time.Now()
go l.cleanup() l.cleanupMu.Unlock()
// Run cleanup async
go func() {
defer l.cleanupActive.Store(false)
l.cleanup()
}()
} }
// Removes stale IP limiters // Removes stale IP limiters
@ -397,6 +670,8 @@ func (l *NetLimiter) cleanup() {
l.ipMu.Lock() l.ipMu.Lock()
defer l.ipMu.Unlock() defer l.ipMu.Unlock()
// Clean up rate limiters
l.ipMu.Lock()
cleaned := 0 cleaned := 0
for ip, lim := range l.ipLimiters { for ip, lim := range l.ipLimiters {
if now.Sub(lim.lastSeen) > staleTimeout { if now.Sub(lim.lastSeen) > staleTimeout {
@ -404,12 +679,37 @@ func (l *NetLimiter) cleanup() {
cleaned++ cleaned++
} }
} }
l.ipMu.Unlock()
if cleaned > 0 { if cleaned > 0 {
l.logger.Debug("msg", "Cleaned up stale IP limiters", l.logger.Debug("msg", "Cleaned up stale IP limiters",
"component", "netlimit",
"cleaned", cleaned, "cleaned", cleaned,
"remaining", len(l.ipLimiters)) "remaining", len(l.ipLimiters))
} }
// Clean up stale connection trackers
l.connMu.Lock()
connCleaned := 0
for ip, tracker := range l.ipConnections {
tracker.mu.Lock()
lastSeen := tracker.lastSeen
tracker.mu.Unlock()
// Remove if no activity for 5 minutes AND no active connections
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
delete(l.ipConnections, ip)
connCleaned++
}
}
l.connMu.Unlock()
if connCleaned > 0 {
l.logger.Debug("msg", "Cleaned up stale connection trackers",
"component", "netlimit",
"cleaned", connCleaned,
"remaining", len(l.ipConnections))
}
} }
// Runs periodic cleanup // Runs periodic cleanup

View File

@ -139,9 +139,6 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
// Configure authentication for sinks that support it // Configure authentication for sinks that support it
for _, sinkInst := range pipeline.Sinks { for _, sinkInst := range pipeline.Sinks {
if setter, ok := sinkInst.(sink.NetAccessSetter); ok {
setter.SetNetAccessConfig(cfg.NetAccess)
}
if setter, ok := sinkInst.(sink.AuthSetter); ok { if setter, ok := sinkInst.(sink.AuthSetter); ok {
setter.SetAuthConfig(cfg.Auth) setter.SetAuthConfig(cfg.Auth)
} }

View File

@ -48,7 +48,6 @@ type HTTPSink struct {
// Net limiting // Net limiting
netLimiter *limit.NetLimiter netLimiter *limit.NetLimiter
ipChecker *limit.IPChecker
// Statistics // Statistics
totalProcessed atomic.Uint64 totalProcessed atomic.Uint64
@ -156,6 +155,22 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
if maxTotal, ok := rl["max_total_connections"].(int64); ok { if maxTotal, ok := rl["max_total_connections"].(int64); ok {
cfg.NetLimit.MaxTotalConnections = maxTotal cfg.NetLimit.MaxTotalConnections = maxTotal
} }
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
for _, entry := range ipWhitelist {
if str, ok := entry.(string); ok {
cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str)
}
}
}
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
for _, entry := range ipBlacklist {
if str, ok := entry.(string); ok {
cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str)
}
}
}
} }
h := &HTTPSink{ h := &HTTPSink{
@ -195,8 +210,10 @@ func (h *HTTPSink) Start(ctx context.Context) error {
// Configure TLS if enabled // Configure TLS if enabled
if h.tlsManager != nil { if h.tlsManager != nil {
tlsConfig := h.tlsManager.GetHTTPConfig() h.server.TLSConfig = h.tlsManager.GetHTTPConfig()
h.server.TLSConfig = tlsConfig h.logger.Info("msg", "TLS enabled for HTTP sink",
"component", "http_sink",
"port", h.config.Port)
} }
addr := fmt.Sprintf(":%d", h.config.Port) addr := fmt.Sprintf(":%d", h.config.Port)
@ -208,12 +225,13 @@ func (h *HTTPSink) Start(ctx context.Context) error {
"component", "http_sink", "component", "http_sink",
"port", h.config.Port, "port", h.config.Port,
"stream_path", h.streamPath, "stream_path", h.streamPath,
"status_path", h.statusPath) "status_path", h.statusPath,
"tls_enabled", h.tlsManager != nil)
var err error var err error
if h.tlsManager != nil { if h.tlsManager != nil {
// HTTPS server // HTTPS server
err = h.server.ListenAndServeTLS(addr, "", "") err = h.server.ListenAndServeTLS(addr, h.config.SSL.CertFile, h.config.SSL.KeyFile)
} else { } else {
// HTTP server // HTTP server
err = h.server.ListenAndServe(addr) err = h.server.ListenAndServe(addr)
@ -306,24 +324,18 @@ func (h *HTTPSink) GetStats() SinkStats {
func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
remoteAddr := ctx.RemoteAddr().String() remoteAddr := ctx.RemoteAddr().String()
// Check IP access control
if h.ipChecker != nil {
if !h.ipChecker.IsAllowed(ctx.RemoteAddr()) {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.SetContentType("text/plain")
ctx.SetBodyString("Forbidden")
return
}
}
// Check net limit // Check net limit
if h.netLimiter != nil { if h.netLimiter != nil {
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed { if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
ctx.SetStatusCode(int(statusCode)) ctx.SetStatusCode(int(statusCode))
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
h.logger.Warn("msg", "Net limited",
"component", "http_sink",
"remote_addr", remoteAddr,
"status_code", statusCode,
"error", message)
json.NewEncoder(ctx).Encode(map[string]any{ json.NewEncoder(ctx).Encode(map[string]any{
"error": message, "error": "Too many requests",
"retry_after": "60", // seconds
}) })
return return
} }
@ -355,7 +367,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil { if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
realm := h.authConfig.BasicAuth.Realm realm := h.authConfig.BasicAuth.Realm
if realm == "" { if realm == "" {
realm = "LogWisp" realm = "Restricted"
} }
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm)) ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm))
} else if h.authConfig.Type == "bearer" { } else if h.authConfig.Type == "bearer" {
@ -364,7 +376,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{ json.NewEncoder(ctx).Encode(map[string]string{
"error": "Authentication required", "error": "Unauthorized",
}) })
return return
} }
@ -379,8 +391,6 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]any{ json.NewEncoder(ctx).Encode(map[string]any{
"error": "Not Found", "error": "Not Found",
"message": fmt.Sprintf("Available endpoints: %s (SSE transport), %s (status)",
h.streamPath, h.statusPath),
}) })
} }
} }
@ -656,14 +666,6 @@ func (h *HTTPSink) GetStatusPath() string {
return h.statusPath return h.statusPath
} }
func (h *HTTPSink) SetNetAccessConfig(cfg *config.NetAccessConfig) {
h.ipChecker = limit.NewIPChecker(cfg, h.logger)
if h.ipChecker != nil {
h.logger.Info("msg", "IP access control configured for HTTP sink",
"component", "http_sink")
}
}
// SetAuthConfig configures http sink authentication // SetAuthConfig configures http sink authentication
func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) { func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" { if authCfg == nil || authCfg.Type == "none" {

View File

@ -4,8 +4,12 @@ package sink
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"crypto/x509"
"fmt" "fmt"
"net/url" "net/url"
"os"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -55,6 +59,7 @@ type HTTPClientConfig struct {
// TLS configuration // TLS configuration
InsecureSkipVerify bool InsecureSkipVerify bool
CAFile string
} }
// NewHTTPClientSink creates a new HTTP client sink // NewHTTPClientSink creates a new HTTP client sink
@ -126,6 +131,11 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
cfg.Headers["Content-Type"] = "application/json" cfg.Headers["Content-Type"] = "application/json"
} }
// Extract TLS options
if caFile, ok := options["ca_file"].(string); ok && caFile != "" {
cfg.CAFile = caFile
}
h := &HTTPClientSink{ h := &HTTPClientSink{
input: make(chan core.LogEntry, cfg.BufferSize), input: make(chan core.LogEntry, cfg.BufferSize),
config: cfg, config: cfg,
@ -147,17 +157,28 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
DisableHeaderNamesNormalizing: true, DisableHeaderNamesNormalizing: true,
} }
// TODO: Implement custom TLS configuration, including InsecureSkipVerify, // Configure TLS if using HTTPS
// by setting a custom dialer on the fasthttp.Client. if strings.HasPrefix(cfg.URL, "https://") {
// For example: tlsConfig := &tls.Config{
// if cfg.InsecureSkipVerify { InsecureSkipVerify: cfg.InsecureSkipVerify,
// h.client.Dial = func(addr string) (net.Conn, error) { }
// return fasthttp.DialDualStackTimeout(addr, cfg.Timeout, &tls.Config{
// InsecureSkipVerify: true, // Load custom CA if provided
// }) if cfg.CAFile != "" {
// } caCert, err := os.ReadFile(cfg.CAFile)
// } if err != nil {
// FIXED: Removed incorrect TLS configuration that referenced non-existent field return nil, fmt.Errorf("failed to read CA file: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("failed to parse CA certificate")
}
tlsConfig.RootCAs = caCertPool
}
// Set TLS config directly on the client
h.client.TLSConfig = tlsConfig
}
return h, nil return h, nil
} }
@ -341,15 +362,22 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
if attempt > 0 { if attempt > 0 {
// Wait before retry // Wait before retry
time.Sleep(retryDelay) time.Sleep(retryDelay)
retryDelay = time.Duration(float64(retryDelay) * h.config.RetryBackoff)
// Calculate new delay with overflow protection
newDelay := time.Duration(float64(retryDelay) * h.config.RetryBackoff)
// Cap at maximum to prevent integer overflow
if newDelay > h.config.Timeout || newDelay < retryDelay {
// Either exceeded max or overflowed (negative/wrapped)
retryDelay = h.config.Timeout
} else {
retryDelay = newDelay
}
} }
// TODO: defer placement issue // Acquire resources inside loop, release immediately after use
// Create request
req := fasthttp.AcquireRequest() req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)
resp := fasthttp.AcquireResponse() resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(h.config.URL) req.SetRequestURI(h.config.URL)
req.Header.SetMethod("POST") req.Header.SetMethod("POST")
@ -362,35 +390,50 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
// Send request // Send request
err := h.client.DoTimeout(req, resp, h.config.Timeout) err := h.client.DoTimeout(req, resp, h.config.Timeout)
// Capture response before releasing
statusCode := resp.StatusCode()
var responseBody []byte
if len(resp.Body()) > 0 {
responseBody = make([]byte, len(resp.Body()))
copy(responseBody, resp.Body())
}
// Release immediately, not deferred
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
// Handle errors
if err != nil { if err != nil {
lastErr = fmt.Errorf("request failed: %w", err) lastErr = fmt.Errorf("request failed: %w", err)
h.logger.Warn("msg", "HTTP request failed", h.logger.Warn("msg", "HTTP request failed",
"component", "http_client_sink", "component", "http_client_sink",
"attempt", attempt+1, "attempt", attempt+1,
"max_retries", h.config.MaxRetries,
"error", err) "error", err)
continue continue
} }
// Check response status // Check response status
statusCode := resp.StatusCode()
if statusCode >= 200 && statusCode < 300 { if statusCode >= 200 && statusCode < 300 {
// Success // Success
h.logger.Debug("msg", "Batch sent successfully", h.logger.Debug("msg", "Batch sent successfully",
"component", "http_client_sink", "component", "http_client_sink",
"batch_size", len(batch), "batch_size", len(batch),
"status_code", statusCode) "status_code", statusCode,
"attempt", attempt+1)
return return
} }
// Non-2xx status // Non-2xx status
lastErr = fmt.Errorf("server returned status %d: %s", statusCode, resp.Body()) lastErr = fmt.Errorf("server returned status %d: %s", statusCode, responseBody)
// Don't retry on 4xx errors (client errors) // Don't retry on 4xx errors (client errors)
if statusCode >= 400 && statusCode < 500 { if statusCode >= 400 && statusCode < 500 {
h.logger.Error("msg", "Batch rejected by server", h.logger.Error("msg", "Batch rejected by server",
"component", "http_client_sink", "component", "http_client_sink",
"status_code", statusCode, "status_code", statusCode,
"response", string(resp.Body()), "response", string(responseBody),
"batch_size", len(batch)) "batch_size", len(batch))
h.failedBatches.Add(1) h.failedBatches.Add(1)
return return
@ -400,13 +443,14 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
"component", "http_client_sink", "component", "http_client_sink",
"attempt", attempt+1, "attempt", attempt+1,
"status_code", statusCode, "status_code", statusCode,
"response", string(resp.Body())) "response", string(responseBody))
} }
// All retries failed // All retries exhausted
h.logger.Error("msg", "Failed to send batch after retries", h.logger.Error("msg", "Failed to send batch after all retries",
"component", "http_client_sink", "component", "http_client_sink",
"batch_size", len(batch), "batch_size", len(batch),
"retries", h.config.MaxRetries,
"last_error", lastErr) "last_error", lastErr)
h.failedBatches.Add(1) h.failedBatches.Add(1)
} }

View File

@ -34,11 +34,6 @@ type SinkStats struct {
Details map[string]any Details map[string]any
} }
// NetAccessSetter is an interface for sinks that can accept network access configuration
type NetAccessSetter interface {
SetNetAccessConfig(cfg *config.NetAccessConfig)
}
// AuthSetter is an interface for sinks that can accept an AuthConfig. // AuthSetter is an interface for sinks that can accept an AuthConfig.
type AuthSetter interface { type AuthSetter interface {
SetAuthConfig(auth *config.AuthConfig) SetAuthConfig(auth *config.AuthConfig)

View File

@ -36,7 +36,6 @@ type TCPSink struct {
engineMu sync.Mutex engineMu sync.Mutex
wg sync.WaitGroup wg sync.WaitGroup
netLimiter *limit.NetLimiter netLimiter *limit.NetLimiter
ipChecker *limit.IPChecker
logger *log.Logger logger *log.Logger
formatter format.Formatter formatter format.Formatter
@ -50,6 +49,11 @@ type TCPSink struct {
lastProcessed atomic.Value // time.Time lastProcessed atomic.Value // time.Time
authFailures atomic.Uint64 authFailures atomic.Uint64
authSuccesses atomic.Uint64 authSuccesses atomic.Uint64
// Write error tracking
writeErrors atomic.Uint64
consecutiveWriteErrors map[gnet.Conn]int
errorMu sync.Mutex
} }
// TCPConfig holds TCP sink configuration // TCPConfig holds TCP sink configuration
@ -141,6 +145,22 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
if maxTotal, ok := rl["max_total_connections"].(int64); ok { if maxTotal, ok := rl["max_total_connections"].(int64); ok {
cfg.NetLimit.MaxTotalConnections = maxTotal cfg.NetLimit.MaxTotalConnections = maxTotal
} }
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
for _, entry := range ipWhitelist {
if str, ok := entry.(string); ok {
cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str)
}
}
}
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
for _, entry := range ipBlacklist {
if str, ok := entry.(string); ok {
cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str)
}
}
}
} }
t := &TCPSink{ t := &TCPSink{
@ -191,17 +211,6 @@ func (t *TCPSink) Start(ctx context.Context) error {
gnet.WithReusePort(true), gnet.WithReusePort(true),
) )
// Add TLS if configured
if t.tlsManager != nil {
// tlsConfig := t.tlsManager.GetTCPConfig()
// TODO: tlsConfig is not used, wrapper to be implemented, non-TLS stream to be available without wrapper
// ☢ SECURITY: gnet doesn't support TLS natively - would need wrapper
// This is a limitation that requires implementing TLS at application layer
t.logger.Warn("msg", "TLS configured but gnet doesn't support native TLS",
"component", "tcp_sink",
"workaround", "Use stunnel or nginx TCP proxy for TLS termination")
}
// Start gnet server // Start gnet server
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
@ -338,7 +347,29 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
t.server.mu.RLock() t.server.mu.RLock()
for conn, client := range t.server.clients { for conn, client := range t.server.clients {
if client.authenticated { if client.authenticated {
conn.AsyncWrite(data, nil) // Send through TLS bridge if present
if client.tlsBridge != nil {
if _, err := client.tlsBridge.Write(data); err != nil {
// TLS write failed, connection likely dead
t.logger.Debug("msg", "TLS write failed",
"component", "tcp_sink",
"error", err)
conn.Close()
}
} else {
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
if err != nil {
t.writeErrors.Add(1)
t.handleWriteError(c, err)
} else {
// Reset consecutive error count on success
t.errorMu.Lock()
delete(t.consecutiveWriteErrors, c)
t.errorMu.Unlock()
}
return nil
})
}
} }
} }
t.server.mu.RUnlock() t.server.mu.RUnlock()
@ -364,7 +395,22 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
continue continue
} }
} }
conn.AsyncWrite(data, nil) if client.tlsBridge != nil {
if _, err := client.tlsBridge.Write(data); err != nil {
t.logger.Debug("msg", "TLS heartbeat write failed",
"component", "tcp_sink",
"error", err)
conn.Close()
}
} else {
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
if err != nil {
t.writeErrors.Add(1)
t.handleWriteError(c, err)
}
return nil
})
}
} }
} }
t.server.mu.RUnlock() t.server.mu.RUnlock()
@ -375,6 +421,36 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
} }
} }
// Handle write errors with threshold-based connection termination
func (t *TCPSink) handleWriteError(c gnet.Conn, err error) {
t.errorMu.Lock()
defer t.errorMu.Unlock()
// Track consecutive errors per connection
if t.consecutiveWriteErrors == nil {
t.consecutiveWriteErrors = make(map[gnet.Conn]int)
}
t.consecutiveWriteErrors[c]++
errorCount := t.consecutiveWriteErrors[c]
t.logger.Debug("msg", "AsyncWrite error",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr(),
"error", err,
"consecutive_errors", errorCount)
// Close connection after 3 consecutive write errors
if errorCount >= 3 {
t.logger.Warn("msg", "Closing connection due to repeated write errors",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr(),
"error_count", errorCount)
delete(t.consecutiveWriteErrors, c)
c.Close()
}
}
// Create heartbeat as a proper LogEntry // Create heartbeat as a proper LogEntry
func (t *TCPSink) createHeartbeatEntry() core.LogEntry { func (t *TCPSink) createHeartbeatEntry() core.LogEntry {
message := "heartbeat" message := "heartbeat"
@ -411,6 +487,8 @@ type tcpClient struct {
authenticated bool authenticated bool
session *auth.Session session *auth.Session
authTimeout time.Time authTimeout time.Time
tlsBridge *tls.GNetTLSConn
authTimeoutSet bool
} }
// tcpServer handles gnet events with authentication // tcpServer handles gnet events with authentication
@ -434,15 +512,13 @@ func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action {
} }
func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
remoteAddr := c.RemoteAddr().String() remoteAddr := c.RemoteAddr()
s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr) s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr)
// Check IP access control first // Reject IPv6 connections immediately
if s.sink.ipChecker != nil { if tcpAddr, ok := remoteAddr.(*net.TCPAddr); ok {
if !s.sink.ipChecker.IsAllowed(c.RemoteAddr()) { if tcpAddr.IP.To4() == nil {
s.sink.logger.Warn("msg", "TCP connection denied by IP filter", return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close
"remote_addr", remoteAddr)
return nil, gnet.Close
} }
} }
@ -467,11 +543,26 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
s.sink.netLimiter.AddConnection(remoteStr) s.sink.netLimiter.AddConnection(remoteStr)
} }
// Create client state // Create client state without auth timeout initially
client := &tcpClient{ client := &tcpClient{
conn: c, conn: c,
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
authTimeout: time.Now().Add(30 * time.Second), // 30s to authenticate authTimeoutSet: false, // Auth timeout not started yet
}
// Initialize TLS bridge if enabled
if s.sink.tlsManager != nil {
tlsConfig := s.sink.tlsManager.GetTCPConfig()
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
client.tlsBridge.Handshake() // Start async handshake
s.sink.logger.Debug("msg", "TLS handshake initiated",
"component", "tcp_sink",
"remote_addr", remoteAddr)
} else if s.sink.authenticator != nil {
// Only set auth timeout if no TLS (plain connection)
client.authTimeout = time.Now().Add(30 * time.Second) // TODO: configurable or non-hardcoded timer
client.authTimeoutSet = true
} }
s.mu.Lock() s.mu.Lock()
@ -485,7 +576,7 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
"requires_auth", s.sink.authenticator != nil) "requires_auth", s.sink.authenticator != nil)
// Send auth prompt if authentication is required // Send auth prompt if authentication is required
if s.sink.authenticator != nil { if s.sink.authenticator != nil && s.sink.tlsManager == nil {
authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH <method> <credentials>\nMethods: basic, token\n") authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH <method> <credentials>\nMethods: basic, token\n")
return authPrompt, gnet.None return authPrompt, gnet.None
} }
@ -498,9 +589,22 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
// Remove client state // Remove client state
s.mu.Lock() s.mu.Lock()
client := s.clients[c]
delete(s.clients, c) delete(s.clients, c)
s.mu.Unlock() s.mu.Unlock()
// Clean up TLS bridge if present
if client != nil && client.tlsBridge != nil {
client.tlsBridge.Close()
s.sink.logger.Debug("msg", "TLS connection closed",
"remote_addr", remoteAddr)
}
// Clean up write error tracking
s.sink.errorMu.Lock()
delete(s.sink.consecutiveWriteErrors, c)
s.sink.errorMu.Unlock()
// Remove connection tracking // Remove connection tracking
if s.sink.netLimiter != nil { if s.sink.netLimiter != nil {
s.sink.netLimiter.RemoveConnection(remoteAddr) s.sink.netLimiter.RemoveConnection(remoteAddr)
@ -523,13 +627,18 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
return gnet.Close return gnet.Close
} }
// Check auth timeout // // Check auth timeout
if !client.authenticated && time.Now().After(client.authTimeout) { // if !client.authenticated && time.Now().After(client.authTimeout) {
s.sink.logger.Warn("msg", "Authentication timeout", // s.sink.logger.Warn("msg", "Authentication timeout",
"remote_addr", c.RemoteAddr().String()) // "component", "tcp_sink",
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil) // "remote_addr", c.RemoteAddr().String())
return gnet.Close // if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() {
} // client.tlsBridge.Write([]byte("AUTH TIMEOUT\n"))
// } else if client.tlsBridge == nil {
// c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
// }
// return gnet.Close
// }
// Read all available data // Read all available data
data, err := c.Next(-1) data, err := c.Next(-1)
@ -540,6 +649,70 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
return gnet.Close return gnet.Close
} }
// Process through TLS bridge if present
if client.tlsBridge != nil {
// Feed encrypted data into TLS engine
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
s.sink.logger.Error("msg", "TLS processing error",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String(),
"error", err)
return gnet.Close
}
// Check if handshake is complete
if !client.tlsBridge.IsHandshakeDone() {
// Still handshaking, wait for more data
return gnet.None
}
// Check handshake result
_, hsErr := client.tlsBridge.HandshakeComplete()
if hsErr != nil {
s.sink.logger.Error("msg", "TLS handshake failed",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String(),
"error", hsErr)
return gnet.Close
}
// Set auth timeout only after TLS handshake completes
if !client.authTimeoutSet && s.sink.authenticator != nil && !client.authenticated {
client.authTimeout = time.Now().Add(30 * time.Second)
client.authTimeoutSet = true
s.sink.logger.Debug("msg", "Auth timeout started after TLS handshake",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String())
}
// Read decrypted plaintext
data = client.tlsBridge.Read()
if data == nil || len(data) == 0 {
// No plaintext available yet
return gnet.None
}
// First data after TLS handshake - send auth prompt if needed
if s.sink.authenticator != nil && !client.authenticated &&
len(client.buffer.Bytes()) == 0 {
authPrompt := []byte("AUTH REQUIRED\n")
client.tlsBridge.Write(authPrompt)
}
}
// Only check auth timeout if it has been set
if !client.authenticated && client.authTimeoutSet && time.Now().After(client.authTimeout) {
s.sink.logger.Warn("msg", "Authentication timeout",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String())
if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() {
client.tlsBridge.Write([]byte("AUTH TIMEOUT\n"))
} else if client.tlsBridge == nil {
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
}
return gnet.Close
}
// If not authenticated, expect auth command // If not authenticated, expect auth command
if !client.authenticated { if !client.authenticated {
client.buffer.Write(data) client.buffer.Write(data)
@ -551,7 +724,13 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
// Parse AUTH command: AUTH <method> <credentials> // Parse AUTH command: AUTH <method> <credentials>
parts := strings.SplitN(string(line), " ", 3) parts := strings.SplitN(string(line), " ", 3)
if len(parts) != 3 || parts[0] != "AUTH" { if len(parts) != 3 || parts[0] != "AUTH" {
c.AsyncWrite([]byte("ERROR: Invalid auth format\n"), nil) // Send error through TLS if enabled
errMsg := []byte("AUTH FAILED\n")
if client.tlsBridge != nil {
client.tlsBridge.Write(errMsg)
} else {
c.AsyncWrite(errMsg, nil)
}
return gnet.None return gnet.None
} }
@ -563,7 +742,13 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
"remote_addr", c.RemoteAddr().String(), "remote_addr", c.RemoteAddr().String(),
"method", parts[1], "method", parts[1],
"error", err) "error", err)
c.AsyncWrite([]byte(fmt.Sprintf("AUTH FAILED: %v\n", err)), nil) // Send error through TLS if enabled
errMsg := []byte("AUTH FAILED\n")
if client.tlsBridge != nil {
client.tlsBridge.Write(errMsg)
} else {
c.AsyncWrite(errMsg, nil)
}
return gnet.Close return gnet.Close
} }
@ -575,11 +760,19 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
s.mu.Unlock() s.mu.Unlock()
s.sink.logger.Info("msg", "TCP client authenticated", s.sink.logger.Info("msg", "TCP client authenticated",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String(), "remote_addr", c.RemoteAddr().String(),
"username", session.Username, "username", session.Username,
"method", session.Method) "method", session.Method,
"tls", client.tlsBridge != nil)
c.AsyncWrite([]byte("AUTH OK\n"), nil) // Send success through TLS if enabled
successMsg := []byte("AUTH OK\n")
if client.tlsBridge != nil {
client.tlsBridge.Write(successMsg)
} else {
c.AsyncWrite(successMsg, nil)
}
// Clear buffer after auth // Clear buffer after auth
client.buffer.Reset() client.buffer.Reset()
@ -610,7 +803,7 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
// Initialize TLS manager if SSL is configured // Initialize TLS manager if SSL is configured
if t.config.SSL != nil && t.config.SSL.Enabled { if t.config.SSL != nil && t.config.SSL.Enabled {
tlsManager, err := tls.New(t.config.SSL, t.logger) tlsManager, err := tls.NewManager(t.config.SSL, t.logger)
if err != nil { if err != nil {
t.logger.Error("msg", "Failed to create TLS manager", t.logger.Error("msg", "Failed to create TLS manager",
"component", "tcp_sink", "component", "tcp_sink",
@ -624,5 +817,6 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
t.logger.Info("msg", "Authentication configured for TCP sink", t.logger.Info("msg", "Authentication configured for TCP sink",
"component", "tcp_sink", "component", "tcp_sink",
"auth_type", authCfg.Type, "auth_type", authCfg.Type,
"tls_enabled", t.tlsManager != nil) "tls_enabled", t.tlsManager != nil,
"tls_bridge", t.tlsManager != nil)
} }

View File

@ -3,6 +3,7 @@ package sink
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -10,8 +11,10 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"logwisp/src/internal/format" "logwisp/src/internal/format"
tlspkg "logwisp/src/internal/tls"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
) )
@ -28,6 +31,10 @@ type TCPClientSink struct {
logger *log.Logger logger *log.Logger
formatter format.Formatter formatter format.Formatter
// TLS support
tlsManager *tlspkg.Manager
tlsConfig *tls.Config
// Reconnection state // Reconnection state
reconnecting atomic.Bool reconnecting atomic.Bool
lastConnectErr error lastConnectErr error
@ -53,6 +60,9 @@ type TCPClientConfig struct {
ReconnectDelay time.Duration ReconnectDelay time.Duration
MaxReconnectDelay time.Duration MaxReconnectDelay time.Duration
ReconnectBackoff float64 ReconnectBackoff float64
// TLS config
SSL *config.SSLConfig
} }
// NewTCPClientSink creates a new TCP client sink // NewTCPClientSink creates a new TCP client sink
@ -103,6 +113,25 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
cfg.ReconnectBackoff = backoff cfg.ReconnectBackoff = backoff
} }
// Extract SSL config
if ssl, ok := options["ssl"].(map[string]any); ok {
cfg.SSL = &config.SSLConfig{}
cfg.SSL.Enabled, _ = ssl["enabled"].(bool)
if certFile, ok := ssl["cert_file"].(string); ok {
cfg.SSL.CertFile = certFile
}
if keyFile, ok := ssl["key_file"].(string); ok {
cfg.SSL.KeyFile = keyFile
}
cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool)
if caFile, ok := ssl["client_ca_file"].(string); ok {
cfg.SSL.ClientCAFile = caFile
}
if insecure, ok := ssl["insecure_skip_verify"].(bool); ok {
cfg.SSL.InsecureSkipVerify = insecure
}
}
t := &TCPClientSink{ t := &TCPClientSink{
input: make(chan core.LogEntry, cfg.BufferSize), input: make(chan core.LogEntry, cfg.BufferSize),
config: cfg, config: cfg,
@ -114,6 +143,34 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
t.lastProcessed.Store(time.Time{}) t.lastProcessed.Store(time.Time{})
t.connectionUptime.Store(time.Duration(0)) t.connectionUptime.Store(time.Duration(0))
// Initialize TLS manager if SSL is configured
if cfg.SSL != nil && cfg.SSL.Enabled {
tlsManager, err := tlspkg.NewManager(cfg.SSL, logger)
if err != nil {
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
}
t.tlsManager = tlsManager
// Get client TLS config
t.tlsConfig = tlsManager.GetTCPConfig()
// ADDED: Client-specific TLS config adjustments
t.tlsConfig.InsecureSkipVerify = cfg.SSL.InsecureSkipVerify
// Extract server name from address for SNI
host, _, err := net.SplitHostPort(cfg.Address)
if err != nil {
return nil, fmt.Errorf("failed to parse address for SNI: %w", err)
}
t.tlsConfig.ServerName = host
logger.Info("msg", "TLS enabled for TCP client",
"component", "tcp_client_sink",
"address", cfg.Address,
"server_name", host,
"insecure", cfg.SSL.InsecureSkipVerify)
}
return t, nil return t, nil
} }
@ -280,6 +337,35 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive) tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
} }
// Wrap with TLS if configured
if t.tlsConfig != nil {
t.logger.Debug("msg", "Initiating TLS handshake",
"component", "tcp_client_sink",
"address", t.config.Address)
tlsConn := tls.Client(conn, t.tlsConfig)
// Perform handshake with timeout
handshakeCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := tlsConn.HandshakeContext(handshakeCtx); err != nil {
conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
}
// Log connection details
state := tlsConn.ConnectionState()
t.logger.Info("msg", "TLS connection established",
"component", "tcp_client_sink",
"address", t.config.Address,
"tls_version", tlsVersionString(state.Version),
"cipher_suite", tls.CipherSuiteName(state.CipherSuite),
"server_name", state.ServerName)
return tlsConn, nil
}
return conn, nil return conn, nil
} }
@ -295,7 +381,7 @@ func (t *TCPClientSink) monitorConnection(conn net.Conn) {
return return
case <-ticker.C: case <-ticker.C:
// Set read deadline // Set read deadline
// TODO: Add t.config.ReadTimeout instead of static value // TODO: Add t.config.ReadTimeout and after addition use it instead of static value
if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
t.logger.Debug("msg", "Failed to set read deadline", "error", err) t.logger.Debug("msg", "Failed to set read deadline", "error", err)
return return
@ -379,3 +465,19 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
return nil return nil
} }
// tlsVersionString returns human-readable TLS version
func tlsVersionString(version uint16) string {
switch version {
case tls.VersionTLS10:
return "TLS1.0"
case tls.VersionTLS11:
return "TLS1.1"
case tls.VersionTLS12:
return "TLS1.2"
case tls.VersionTLS13:
return "TLS1.3"
default:
return fmt.Sprintf("0x%04x", version)
}
}

View File

@ -4,6 +4,8 @@ package source
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"logwisp/src/internal/tls"
"net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -29,6 +31,10 @@ type HTTPSource struct {
netLimiter *limit.NetLimiter netLimiter *limit.NetLimiter
logger *log.Logger logger *log.Logger
// CHANGED: Add TLS support
tlsManager *tls.Manager
sslConfig *config.SSLConfig
// Statistics // Statistics
totalEntries atomic.Uint64 totalEntries atomic.Uint64
droppedEntries atomic.Uint64 droppedEntries atomic.Uint64
@ -94,6 +100,28 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err
} }
} }
// Extract SSL config after existing options
if ssl, ok := options["ssl"].(map[string]any); ok {
h.sslConfig = &config.SSLConfig{}
h.sslConfig.Enabled, _ = ssl["enabled"].(bool)
if certFile, ok := ssl["cert_file"].(string); ok {
h.sslConfig.CertFile = certFile
}
if keyFile, ok := ssl["key_file"].(string); ok {
h.sslConfig.KeyFile = keyFile
}
// TODO: extract other SSL options similar to tcp_client_sink
// Create TLS manager
if h.sslConfig.Enabled {
tlsManager, err := tls.NewManager(h.sslConfig, logger)
if err != nil {
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
}
h.tlsManager = tlsManager
}
}
return h, nil return h, nil
} }
@ -123,9 +151,19 @@ func (h *HTTPSource) Start() error {
h.logger.Info("msg", "HTTP source server starting", h.logger.Info("msg", "HTTP source server starting",
"component", "http_source", "component", "http_source",
"port", h.port, "port", h.port,
"ingest_path", h.ingestPath) "ingest_path", h.ingestPath,
"tls_enabled", h.tlsManager != nil)
if err := h.server.ListenAndServe(addr); err != nil { var err error
// Check for TLS manager and start the appropriate server type
if h.tlsManager != nil {
h.server.TLSConfig = h.tlsManager.GetHTTPConfig()
err = h.server.ListenAndServeTLS(addr, h.sslConfig.CertFile, h.sslConfig.KeyFile)
} else {
err = h.server.ListenAndServe(addr)
}
if err != nil {
h.logger.Error("msg", "HTTP source server failed", h.logger.Error("msg", "HTTP source server failed",
"component", "http_source", "component", "http_source",
"port", h.port, "port", h.port,
@ -134,7 +172,7 @@ func (h *HTTPSource) Start() error {
}() }()
// Give server time to start // Give server time to start
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond) // TODO: standardize and better manage timers
return nil return nil
} }
@ -202,8 +240,21 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
return return
} }
// Check net limit // Extract and validate IP
remoteAddr := ctx.RemoteAddr().String() remoteAddr := ctx.RemoteAddr().String()
ipStr, _, err := net.SplitHostPort(remoteAddr)
if err == nil {
if ip := net.ParseIP(ipStr); ip != nil && ip.To4() == nil {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{
"error": "IPv4-only (IPv6 not supported)",
})
return
}
}
// Check net limit
if h.netLimiter != nil { if h.netLimiter != nil {
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed { if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
ctx.SetStatusCode(int(statusCode)) ctx.SetStatusCode(int(statusCode))
@ -280,7 +331,7 @@ func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) {
// Try to parse as JSON array // Try to parse as JSON array
var array []core.LogEntry var array []core.LogEntry
if err := json.Unmarshal(body, &array); err == nil { if err := json.Unmarshal(body, &array); err == nil {
// TODO: Placeholder; For array, divide total size by entry count as approximation // NOTE: Placeholder; For array, divide total size by entry count as approximation
approxSizePerEntry := int64(len(body) / len(array)) approxSizePerEntry := int64(len(body) / len(array))
for i, entry := range array { for i, entry := range array {
if entry.Message == "" { if entry.Message == "" {
@ -292,7 +343,7 @@ func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) {
if entry.Source == "" { if entry.Source == "" {
array[i].Source = "http" array[i].Source = "http"
} }
// TODO: Placeholder // NOTE: Placeholder
array[i].RawSize = approxSizePerEntry array[i].RawSize = approxSizePerEntry
} }
return array, nil return array, nil

View File

@ -5,21 +5,31 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"logwisp/src/internal/auth"
"logwisp/src/internal/config" "logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"logwisp/src/internal/limit" "logwisp/src/internal/limit"
"logwisp/src/internal/tls"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
"github.com/lixenwraith/log/compat" "github.com/lixenwraith/log/compat"
"github.com/panjf2000/gnet/v2" "github.com/panjf2000/gnet/v2"
) )
const (
maxClientBufferSize = 10 * 1024 * 1024 // 10MB max per client
maxLineLength = 1 * 1024 * 1024 // 1MB max per log line
maxEncryptedDataPerRead = 1 * 1024 * 1024 // 1MB max encrypted data per read
maxCumulativeEncrypted = 20 * 1024 * 1024 // 20MB total encrypted before processing
)
// TCPSource receives log entries via TCP connections // TCPSource receives log entries via TCP connections
type TCPSource struct { type TCPSource struct {
port int64 port int64
@ -32,6 +42,8 @@ type TCPSource struct {
engineMu sync.Mutex engineMu sync.Mutex
wg sync.WaitGroup wg sync.WaitGroup
netLimiter *limit.NetLimiter netLimiter *limit.NetLimiter
tlsManager *tls.Manager
sslConfig *config.SSLConfig
logger *log.Logger logger *log.Logger
// Statistics // Statistics
@ -91,6 +103,32 @@ func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error
} }
} }
// Extract SSL config and initialize TLS manager
if ssl, ok := options["ssl"].(map[string]any); ok {
t.sslConfig = &config.SSLConfig{}
t.sslConfig.Enabled, _ = ssl["enabled"].(bool)
if certFile, ok := ssl["cert_file"].(string); ok {
t.sslConfig.CertFile = certFile
}
if keyFile, ok := ssl["key_file"].(string); ok {
t.sslConfig.KeyFile = keyFile
}
t.sslConfig.ClientAuth, _ = ssl["client_auth"].(bool)
if caFile, ok := ssl["client_ca_file"].(string); ok {
t.sslConfig.ClientCAFile = caFile
}
t.sslConfig.VerifyClientCert, _ = ssl["verify_client_cert"].(bool)
// Create TLS manager if enabled
if t.sslConfig.Enabled {
tlsManager, err := tls.NewManager(t.sslConfig, logger)
if err != nil {
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
}
t.tlsManager = tlsManager
}
}
return t, nil return t, nil
} }
@ -121,7 +159,8 @@ func (t *TCPSource) Start() error {
defer t.wg.Done() defer t.wg.Done()
t.logger.Info("msg", "TCP source server starting", t.logger.Info("msg", "TCP source server starting",
"component", "tcp_source", "component", "tcp_source",
"port", t.port) "port", t.port,
"tls_enabled", t.tlsManager != nil)
err := gnet.Run(t.server, addr, err := gnet.Run(t.server, addr,
gnet.WithLogger(gnetLogger), gnet.WithLogger(gnetLogger),
@ -235,6 +274,12 @@ func (t *TCPSource) publish(entry core.LogEntry) bool {
type tcpClient struct { type tcpClient struct {
conn gnet.Conn conn gnet.Conn
buffer bytes.Buffer buffer bytes.Buffer
authenticated bool
session *auth.Session
authTimeout time.Time
tlsBridge *tls.GNetTLSConn
maxBufferSeen int
cumulativeEncrypted int64
} }
// tcpSourceServer handles gnet events // tcpSourceServer handles gnet events
@ -265,8 +310,7 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
// Check net limit // Check net limit
if s.source.netLimiter != nil { if s.source.netLimiter != nil {
remoteStr := c.RemoteAddr().String() tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr)
tcpAddr, err := net.ResolveTCPAddr("tcp", remoteStr)
if err != nil { if err != nil {
s.source.logger.Warn("msg", "Failed to parse TCP address", s.source.logger.Warn("msg", "Failed to parse TCP address",
"component", "tcp_source", "component", "tcp_source",
@ -283,7 +327,21 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
} }
// Track connection // Track connection
s.source.netLimiter.AddConnection(remoteStr) s.source.netLimiter.AddConnection(remoteAddr)
}
// Create client state
client := &tcpClient{conn: c}
// Initialize TLS bridge if enabled
if s.source.tlsManager != nil {
tlsConfig := s.source.tlsManager.GetTCPConfig()
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
client.tlsBridge.Handshake() // Start async handshake
s.source.logger.Debug("msg", "TLS handshake initiated",
"component", "tcp_source",
"remote_addr", remoteAddr)
} }
// Create client state // Create client state
@ -295,7 +353,8 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
s.source.logger.Debug("msg", "TCP connection opened", s.source.logger.Debug("msg", "TCP connection opened",
"component", "tcp_source", "component", "tcp_source",
"remote_addr", remoteAddr, "remote_addr", remoteAddr,
"active_connections", newCount) "active_connections", newCount,
"tls_enabled", s.source.tlsManager != nil)
return nil, gnet.None return nil, gnet.None
} }
@ -305,9 +364,18 @@ func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action {
// Remove client state // Remove client state
s.mu.Lock() s.mu.Lock()
client := s.clients[c]
delete(s.clients, c) delete(s.clients, c)
s.mu.Unlock() s.mu.Unlock()
// Clean up TLS bridge if present
if client != nil && client.tlsBridge != nil {
client.tlsBridge.Close()
s.source.logger.Debug("msg", "TLS connection closed",
"component", "tcp_source",
"remote_addr", remoteAddr)
}
// Remove connection tracking // Remove connection tracking
if s.source.netLimiter != nil { if s.source.netLimiter != nil {
s.source.netLimiter.RemoveConnection(remoteAddr) s.source.netLimiter.RemoveConnection(remoteAddr)
@ -340,9 +408,113 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
return gnet.Close return gnet.Close
} }
// Check encrypted data size BEFORE processing through TLS
if len(data) > maxEncryptedDataPerRead {
s.source.logger.Warn("msg", "Encrypted data per read limit exceeded",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"data_size", len(data),
"limit", maxEncryptedDataPerRead)
s.source.invalidEntries.Add(1)
return gnet.Close
}
// Track cumulative encrypted data to prevent slow accumulation
client.cumulativeEncrypted += int64(len(data))
if client.cumulativeEncrypted > maxCumulativeEncrypted {
s.source.logger.Warn("msg", "Cumulative encrypted data limit exceeded",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"total_encrypted", client.cumulativeEncrypted,
"limit", maxCumulativeEncrypted)
s.source.invalidEntries.Add(1)
return gnet.Close
}
// Process through TLS bridge if present
if client.tlsBridge != nil {
// Feed encrypted data into TLS engine
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
if errors.Is(err, tls.ErrTLSBackpressure) {
s.source.logger.Warn("msg", "TLS backpressure, closing slow client",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String())
} else {
s.source.logger.Error("msg", "TLS processing error",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"error", err)
}
return gnet.Close
}
// Check if handshake is complete
if !client.tlsBridge.IsHandshakeDone() {
// Still handshaking, wait for more data
return gnet.None
}
// Check handshake result
_, hsErr := client.tlsBridge.HandshakeComplete()
if hsErr != nil {
s.source.logger.Error("msg", "TLS handshake failed",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"error", hsErr)
return gnet.Close
}
// Read decrypted plaintext
data = client.tlsBridge.Read()
if data == nil || len(data) == 0 {
// No plaintext available yet
return gnet.None
}
// Reset cumulative counter after successful decryption and processing
client.cumulativeEncrypted = 0
}
// Check buffer size before appending
if client.buffer.Len()+len(data) > maxClientBufferSize {
s.source.logger.Warn("msg", "Client buffer limit exceeded",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"buffer_size", client.buffer.Len(),
"incoming_size", len(data))
s.source.invalidEntries.Add(1)
return gnet.Close
}
// Append to client buffer // Append to client buffer
client.buffer.Write(data) client.buffer.Write(data)
// Track high buffer
if client.buffer.Len() > client.maxBufferSeen {
client.maxBufferSeen = client.buffer.Len()
}
// Check for suspiciously long lines before attempting to read
if client.buffer.Len() > maxLineLength {
// Scan for newline in current buffer
bufBytes := client.buffer.Bytes()
hasNewline := false
for _, b := range bufBytes {
if b == '\n' {
hasNewline = true
break
}
}
if !hasNewline {
s.source.logger.Warn("msg", "Line too long without newline",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"buffer_size", client.buffer.Len())
s.source.invalidEntries.Add(1)
return gnet.Close
}
}
// Process complete lines // Process complete lines
for { for {
line, err := client.buffer.ReadBytes('\n') line, err := client.buffer.ReadBytes('\n')

View File

@ -0,0 +1,341 @@
// FILE: src/internal/tls/gnet_bridge.go
package tls
import (
"crypto/tls"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/panjf2000/gnet/v2"
)
var (
ErrTLSBackpressure = errors.New("TLS processing backpressure")
ErrConnectionClosed = errors.New("connection closed")
ErrPlaintextBufferExceeded = errors.New("plaintext buffer size exceeded")
)
// Maximum plaintext buffer size to prevent memory exhaustion
const maxPlaintextBufferSize = 32 * 1024 * 1024 // 32MB
// GNetTLSConn bridges gnet.Conn with crypto/tls via io.Pipe
type GNetTLSConn struct {
gnetConn gnet.Conn
tlsConn *tls.Conn
config *tls.Config
// Buffered channels for non-blocking operation
incomingCipher chan []byte // Network → TLS (encrypted)
outgoingCipher chan []byte // TLS → Network (encrypted)
// Handshake state
handshakeOnce sync.Once
handshakeDone chan struct{}
handshakeErr error
// Decrypted data buffer
plainBuf []byte
plainMu sync.Mutex
// Lifecycle
closed atomic.Bool
closeOnce sync.Once
wg sync.WaitGroup
// Error tracking
lastErr atomic.Value // error
logger interface{ Warn(args ...any) } // Minimal logger interface
}
// NewServerConn creates a server-side TLS bridge
func NewServerConn(gnetConn gnet.Conn, config *tls.Config) *GNetTLSConn {
tc := &GNetTLSConn{
gnetConn: gnetConn,
config: config,
handshakeDone: make(chan struct{}),
// Buffered channels sized for throughput without blocking
incomingCipher: make(chan []byte, 128), // 128 packets buffer
outgoingCipher: make(chan []byte, 128),
plainBuf: make([]byte, 0, 65536), // 64KB initial capacity
}
// Create TLS conn with channel-based transport
rawConn := &channelConn{
incoming: tc.incomingCipher,
outgoing: tc.outgoingCipher,
localAddr: gnetConn.LocalAddr(),
remoteAddr: gnetConn.RemoteAddr(),
tc: tc,
}
tc.tlsConn = tls.Server(rawConn, config)
// Start pump goroutines
tc.wg.Add(2)
go tc.pumpCipherToNetwork()
go tc.pumpPlaintextFromTLS()
return tc
}
// NewClientConn creates a client-side TLS bridge (similar changes)
func NewClientConn(gnetConn gnet.Conn, config *tls.Config, serverName string) *GNetTLSConn {
tc := &GNetTLSConn{
gnetConn: gnetConn,
config: config,
handshakeDone: make(chan struct{}),
incomingCipher: make(chan []byte, 128),
outgoingCipher: make(chan []byte, 128),
plainBuf: make([]byte, 0, 65536),
}
if config.ServerName == "" {
config = config.Clone()
config.ServerName = serverName
}
rawConn := &channelConn{
incoming: tc.incomingCipher,
outgoing: tc.outgoingCipher,
localAddr: gnetConn.LocalAddr(),
remoteAddr: gnetConn.RemoteAddr(),
tc: tc,
}
tc.tlsConn = tls.Client(rawConn, config)
tc.wg.Add(2)
go tc.pumpCipherToNetwork()
go tc.pumpPlaintextFromTLS()
return tc
}
// ProcessIncoming feeds encrypted data from network into TLS engine (non-blocking)
func (tc *GNetTLSConn) ProcessIncoming(encryptedData []byte) error {
if tc.closed.Load() {
return ErrConnectionClosed
}
// Non-blocking send with backpressure detection
select {
case tc.incomingCipher <- encryptedData:
return nil
default:
// Channel full - TLS processing can't keep up
// Drop connection under backpressure vs blocking event loop
if tc.logger != nil {
tc.logger.Warn("msg", "TLS backpressure, dropping data",
"remote_addr", tc.gnetConn.RemoteAddr())
}
return ErrTLSBackpressure
}
}
// pumpCipherToNetwork sends TLS-encrypted data to network
func (tc *GNetTLSConn) pumpCipherToNetwork() {
defer tc.wg.Done()
for {
select {
case data, ok := <-tc.outgoingCipher:
if !ok {
return
}
// Send to network
if err := tc.gnetConn.AsyncWrite(data, nil); err != nil {
tc.lastErr.Store(err)
tc.Close()
return
}
case <-time.After(30 * time.Second):
// Keepalive/timeout check
if tc.closed.Load() {
return
}
}
}
}
// pumpPlaintextFromTLS reads decrypted data from TLS
func (tc *GNetTLSConn) pumpPlaintextFromTLS() {
defer tc.wg.Done()
buf := make([]byte, 32768) // 32KB read buffer
for {
n, err := tc.tlsConn.Read(buf)
if n > 0 {
tc.plainMu.Lock()
// Check buffer size limit before appending to prevent memory exhaustion
if len(tc.plainBuf)+n > maxPlaintextBufferSize {
tc.plainMu.Unlock()
// Log warning about buffer limit
if tc.logger != nil {
tc.logger.Warn("msg", "Plaintext buffer limit exceeded, closing connection",
"remote_addr", tc.gnetConn.RemoteAddr(),
"buffer_size", len(tc.plainBuf),
"incoming_size", n,
"limit", maxPlaintextBufferSize)
}
// Store error and close connection
tc.lastErr.Store(ErrPlaintextBufferExceeded)
tc.Close()
return
}
tc.plainBuf = append(tc.plainBuf, buf[:n]...)
tc.plainMu.Unlock()
}
if err != nil {
if err != io.EOF {
tc.lastErr.Store(err)
}
tc.Close()
return
}
}
}
// Read returns available decrypted plaintext (non-blocking)
func (tc *GNetTLSConn) Read() []byte {
tc.plainMu.Lock()
defer tc.plainMu.Unlock()
if len(tc.plainBuf) == 0 {
return nil
}
// Atomic buffer swap under mutex protection to prevent race condition
data := tc.plainBuf
tc.plainBuf = make([]byte, 0, cap(tc.plainBuf))
return data
}
// Write encrypts plaintext and queues for network transmission
func (tc *GNetTLSConn) Write(plaintext []byte) (int, error) {
if tc.closed.Load() {
return 0, ErrConnectionClosed
}
if !tc.IsHandshakeDone() {
return 0, errors.New("handshake not complete")
}
return tc.tlsConn.Write(plaintext)
}
// Handshake initiates TLS handshake asynchronously
func (tc *GNetTLSConn) Handshake() {
tc.handshakeOnce.Do(func() {
go func() {
tc.handshakeErr = tc.tlsConn.Handshake()
close(tc.handshakeDone)
}()
})
}
// IsHandshakeDone checks if handshake is complete
func (tc *GNetTLSConn) IsHandshakeDone() bool {
select {
case <-tc.handshakeDone:
return true
default:
return false
}
}
// HandshakeComplete waits for handshake completion
func (tc *GNetTLSConn) HandshakeComplete() (<-chan struct{}, error) {
<-tc.handshakeDone
return tc.handshakeDone, tc.handshakeErr
}
// Close shuts down the bridge
func (tc *GNetTLSConn) Close() error {
tc.closeOnce.Do(func() {
tc.closed.Store(true)
// Close TLS connection
tc.tlsConn.Close()
// Close channels to stop pumps
close(tc.incomingCipher)
close(tc.outgoingCipher)
})
// Wait for pumps to finish
tc.wg.Wait()
return nil
}
// GetConnectionState returns TLS connection state
func (tc *GNetTLSConn) GetConnectionState() tls.ConnectionState {
return tc.tlsConn.ConnectionState()
}
// GetError returns last error
func (tc *GNetTLSConn) GetError() error {
if err, ok := tc.lastErr.Load().(error); ok {
return err
}
return nil
}
// channelConn implements net.Conn over channels
type channelConn struct {
incoming <-chan []byte
outgoing chan<- []byte
localAddr net.Addr
remoteAddr net.Addr
tc *GNetTLSConn
readBuf []byte
}
func (c *channelConn) Read(b []byte) (int, error) {
// Use buffered read for efficiency
if len(c.readBuf) > 0 {
n := copy(b, c.readBuf)
c.readBuf = c.readBuf[n:]
return n, nil
}
// Wait for new data
select {
case data, ok := <-c.incoming:
if !ok {
return 0, io.EOF
}
n := copy(b, data)
if n < len(data) {
c.readBuf = data[n:] // Buffer remainder
}
return n, nil
case <-time.After(30 * time.Second):
return 0, errors.New("read timeout")
}
}
func (c *channelConn) Write(b []byte) (int, error) {
if c.tc.closed.Load() {
return 0, ErrConnectionClosed
}
// Make a copy since TLS may hold reference
data := make([]byte, len(b))
copy(data, b)
select {
case c.outgoing <- data:
return len(b), nil
case <-time.After(5 * time.Second):
return 0, errors.New("write timeout")
}
}
func (c *channelConn) Close() error { return nil }
func (c *channelConn) LocalAddr() net.Addr { return c.localAddr }
func (c *channelConn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *channelConn) SetDeadline(t time.Time) error { return nil }
func (c *channelConn) SetReadDeadline(t time.Time) error { return nil }
func (c *channelConn) SetWriteDeadline(t time.Time) error { return nil }

View File

@ -20,8 +20,8 @@ type Manager struct {
logger *log.Logger logger *log.Logger
} }
// New creates a TLS configuration from SSL config // NewManager creates a TLS configuration from SSL config
func New(cfg *config.SSLConfig, logger *log.Logger) (*Manager, error) { func NewManager(cfg *config.SSLConfig, logger *log.Logger) (*Manager, error) {
if cfg == nil || !cfg.Enabled { if cfg == nil || !cfg.Enabled {
return nil, nil return nil, nil
} }