From cc27f5cc1c209661e405f6cc0dcce162936042eeabbf820d12abf30db27b6ada Mon Sep 17 00:00:00 2001 From: Lixen Wraith Date: Sun, 13 Jul 2025 03:20:47 -0400 Subject: [PATCH] v0.3.3 pipeline rate limiter added --- src/cmd/logwisp/status.go | 12 +- src/internal/config/pipeline.go | 17 +- src/internal/config/ratelimit.go | 52 ++ src/internal/config/server.go | 20 +- src/internal/config/validation.go | 5 + src/internal/netlimit/limiter.go | 432 ++++++++++++++++ .../ratelimiter.go => netlimit/netlimiter.go} | 6 +- src/internal/ratelimit/limiter.go | 467 +++--------------- src/internal/service/pipeline.go | 55 ++- src/internal/service/service.go | 19 + src/internal/sink/http.go | 66 +-- src/internal/sink/tcp.go | 56 +-- src/internal/source/directory.go | 12 - src/internal/source/http.go | 43 +- src/internal/source/source.go | 5 - src/internal/source/stdin.go | 12 - src/internal/source/tcp.go | 51 +- 17 files changed, 742 insertions(+), 588 deletions(-) create mode 100644 src/internal/config/ratelimit.go create mode 100644 src/internal/netlimit/limiter.go rename src/internal/{ratelimit/ratelimiter.go => netlimit/netlimiter.go} (91%) diff --git a/src/cmd/logwisp/status.go b/src/cmd/logwisp/status.go index 8134509..56d850c 100644 --- a/src/cmd/logwisp/status.go +++ b/src/cmd/logwisp/status.go @@ -118,10 +118,10 @@ func displayPipelineEndpoints(cfg config.PipelineConfig, routerMode bool) { "sink_index", i, "port", port) - // Display rate limit info if configured - if rl, ok := sinkCfg.Options["rate_limit"].(map[string]any); ok { + // Display net limit info if configured + if rl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok { if enabled, ok := rl["enabled"].(bool); ok && enabled { - logger.Info("msg", "TCP rate limiting enabled", + logger.Info("msg", "TCP net limiting enabled", "pipeline", cfg.Name, "sink_index", i, "requests_per_second", rl["requests_per_second"], @@ -155,10 +155,10 @@ func displayPipelineEndpoints(cfg config.PipelineConfig, routerMode bool) { "status_url", fmt.Sprintf("http://localhost:%d%s", port, statusPath)) } - // Display rate limit info if configured - if rl, ok := sinkCfg.Options["rate_limit"].(map[string]any); ok { + // Display net limit info if configured + if rl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok { if enabled, ok := rl["enabled"].(bool); ok && enabled { - logger.Info("msg", "HTTP rate limiting enabled", + logger.Info("msg", "HTTP net limiting enabled", "pipeline", cfg.Name, "sink_index", i, "requests_per_second", rl["requests_per_second"], diff --git a/src/internal/config/pipeline.go b/src/internal/config/pipeline.go index 89e4423..4343a11 100644 --- a/src/internal/config/pipeline.go +++ b/src/internal/config/pipeline.go @@ -17,6 +17,9 @@ type PipelineConfig struct { // Data sources for this pipeline Sources []SourceConfig `toml:"sources"` + // Rate limiting + RateLimit *RateLimitConfig `toml:"rate_limit"` + // Filter configuration Filters []FilterConfig `toml:"filters"` @@ -37,7 +40,7 @@ type SourceConfig struct { // Placeholder for future source-side rate limiting // This will be used for features like aggregation and summarization - RateLimit *RateLimitConfig `toml:"rate_limit"` + NetLimit *NetLimitConfig `toml:"net_limit"` } // SinkConfig represents an output destination @@ -187,9 +190,9 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts } } - // Validate rate limit if present - if rl, ok := cfg.Options["rate_limit"].(map[string]any); ok { - if err := validateRateLimitOptions("HTTP", pipelineName, sinkIndex, rl); err != nil { + // Validate net limit if present + if rl, ok := cfg.Options["net_limit"].(map[string]any); ok { + if err := validateNetLimitOptions("HTTP", pipelineName, sinkIndex, rl); err != nil { return err } } @@ -231,9 +234,9 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts } } - // Validate rate limit if present - if rl, ok := cfg.Options["rate_limit"].(map[string]any); ok { - if err := validateRateLimitOptions("TCP", pipelineName, sinkIndex, rl); err != nil { + // Validate net limit if present + if rl, ok := cfg.Options["net_limit"].(map[string]any); ok { + if err := validateNetLimitOptions("TCP", pipelineName, sinkIndex, rl); err != nil { return err } } diff --git a/src/internal/config/ratelimit.go b/src/internal/config/ratelimit.go new file mode 100644 index 0000000..4a0225d --- /dev/null +++ b/src/internal/config/ratelimit.go @@ -0,0 +1,52 @@ +// FILE: src/internal/config/ratelimit.go +package config + +import ( + "fmt" + "strings" +) + +// RateLimitPolicy defines the action to take when a rate limit is exceeded. +type RateLimitPolicy int + +const ( + // PolicyPass allows all logs through, effectively disabling the limiter. + PolicyPass RateLimitPolicy = iota + // PolicyDrop drops logs that exceed the rate limit. + PolicyDrop +) + +// RateLimitConfig defines the configuration for pipeline-level rate limiting. +type RateLimitConfig struct { + // Rate is the number of log entries allowed per second. Default: 0 (disabled). + Rate float64 `toml:"rate"` + // Burst is the maximum number of log entries that can be sent in a short burst. Defaults to the Rate. + Burst float64 `toml:"burst"` + // Policy defines the action to take when the limit is exceeded. "pass" or "drop". + Policy string `toml:"policy"` +} + +func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error { + if cfg == nil { + return nil + } + + if cfg.Rate < 0 { + return fmt.Errorf("pipeline '%s': rate limit rate cannot be negative", pipelineName) + } + + if cfg.Burst < 0 { + return fmt.Errorf("pipeline '%s': rate limit burst cannot be negative", pipelineName) + } + + // Validate policy + switch strings.ToLower(cfg.Policy) { + case "", "pass", "drop": + // Valid policies + default: + return fmt.Errorf("pipeline '%s': invalid rate limit policy '%s' (must be 'pass' or 'drop')", + pipelineName, cfg.Policy) + } + + return nil +} \ No newline at end of file diff --git a/src/internal/config/server.go b/src/internal/config/server.go index a161d1c..82e56e2 100644 --- a/src/internal/config/server.go +++ b/src/internal/config/server.go @@ -11,8 +11,8 @@ type TCPConfig struct { // SSL/TLS Configuration SSL *SSLConfig `toml:"ssl"` - // Rate limiting - RateLimit *RateLimitConfig `toml:"rate_limit"` + // Net limiting + NetLimit *NetLimitConfig `toml:"net_limit"` // Heartbeat Heartbeat HeartbeatConfig `toml:"heartbeat"` @@ -30,8 +30,8 @@ type HTTPConfig struct { // SSL/TLS Configuration SSL *SSLConfig `toml:"ssl"` - // Rate limiting - RateLimit *RateLimitConfig `toml:"rate_limit"` + // Nate limiting + NetLimit *NetLimitConfig `toml:"net_limit"` // Heartbeat Heartbeat HeartbeatConfig `toml:"heartbeat"` @@ -45,8 +45,8 @@ type HeartbeatConfig struct { Format string `toml:"format"` // "comment" or "json" } -type RateLimitConfig struct { - // Enable rate limiting +type NetLimitConfig struct { + // Enable net limiting Enabled bool `toml:"enabled"` // Requests per second per client @@ -55,12 +55,12 @@ type RateLimitConfig struct { // Burst size (token bucket) BurstSize int `toml:"burst_size"` - // Rate limit by: "ip", "user", "token", "global" + // Net limit by: "ip", "user", "token", "global" LimitBy string `toml:"limit_by"` - // Response when rate limited + // Response when net limited ResponseCode int `toml:"response_code"` // Default: 429 - ResponseMessage string `toml:"response_message"` // Default: "Rate limit exceeded" + ResponseMessage string `toml:"response_message"` // Default: "Net limit exceeded" // Connection limits MaxConnectionsPerIP int `toml:"max_connections_per_ip"` @@ -85,7 +85,7 @@ func validateHeartbeatOptions(serverType, pipelineName string, sinkIndex int, hb return nil } -func validateRateLimitOptions(serverType, pipelineName string, sinkIndex int, rl map[string]any) error { +func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl map[string]any) error { if enabled, ok := rl["enabled"].(bool); !ok || !enabled { return nil } diff --git a/src/internal/config/validation.go b/src/internal/config/validation.go index 855b767..783f411 100644 --- a/src/internal/config/validation.go +++ b/src/internal/config/validation.go @@ -48,6 +48,11 @@ func (c *Config) validate() error { } } + // Validate rate limit if present + if err := validateRateLimit(pipeline.Name, pipeline.RateLimit); err != nil { + return err + } + // Validate filters for j, filterCfg := range pipeline.Filters { if err := validateFilter(pipeline.Name, j, &filterCfg); err != nil { diff --git a/src/internal/netlimit/limiter.go b/src/internal/netlimit/limiter.go new file mode 100644 index 0000000..7d52a9e --- /dev/null +++ b/src/internal/netlimit/limiter.go @@ -0,0 +1,432 @@ +// FILE: src/internal/netlimit/limiter.go +package netlimit + +import ( + "context" + "net" + "strings" + "sync" + "sync/atomic" + "time" + + "logwisp/src/internal/config" + + "github.com/lixenwraith/log" +) + +// Manages net limiting for a transport +type Limiter struct { + config config.NetLimitConfig + logger *log.Logger + + // Per-IP limiters + ipLimiters map[string]*ipLimiter + ipMu sync.RWMutex + + // Global limiter for the transport + globalLimiter *TokenBucket + + // Connection tracking + ipConnections map[string]*atomic.Int32 + connMu sync.RWMutex + + // Statistics + totalRequests atomic.Uint64 + blockedRequests atomic.Uint64 + uniqueIPs atomic.Uint64 + + // Cleanup + lastCleanup time.Time + cleanupMu sync.Mutex + + // Lifecycle management + ctx context.Context + cancel context.CancelFunc + cleanupDone chan struct{} +} + +type ipLimiter struct { + bucket *TokenBucket + lastSeen time.Time + connections atomic.Int32 +} + +// Creates a new net limiter +func New(cfg config.NetLimitConfig, logger *log.Logger) *Limiter { + if !cfg.Enabled { + return nil + } + + if logger == nil { + panic("netlimit.New: logger cannot be nil") + } + + ctx, cancel := context.WithCancel(context.Background()) + + l := &Limiter{ + config: cfg, + ipLimiters: make(map[string]*ipLimiter), + ipConnections: make(map[string]*atomic.Int32), + lastCleanup: time.Now(), + logger: logger, + ctx: ctx, + cancel: cancel, + cleanupDone: make(chan struct{}), + } + + // Create global limiter if not using per-IP limiting + if cfg.LimitBy == "global" { + l.globalLimiter = NewTokenBucket( + float64(cfg.BurstSize), + cfg.RequestsPerSecond, + ) + } + + // Start cleanup goroutine + go l.cleanupLoop() + + l.logger.Info("msg", "Net limiter initialized", + "component", "netlimit", + "requests_per_second", cfg.RequestsPerSecond, + "burst_size", cfg.BurstSize, + "limit_by", cfg.LimitBy) + + return l +} + +func (l *Limiter) Shutdown() { + if l == nil { + return + } + + l.logger.Info("msg", "Shutting down net limiter", "component", "netlimit") + + // Cancel context to stop cleanup goroutine + l.cancel() + + // Wait for cleanup goroutine to finish + select { + case <-l.cleanupDone: + l.logger.Debug("msg", "Cleanup goroutine stopped", "component", "netlimit") + case <-time.After(2 * time.Second): + l.logger.Warn("msg", "Cleanup goroutine shutdown timeout", "component", "netlimit") + } +} + +// Checks if an HTTP request should be allowed +func (l *Limiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int, message string) { + if l == nil { + return true, 0, "" + } + + l.totalRequests.Add(1) + + ip, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + // If we can't parse the IP, allow the request but log + l.logger.Warn("msg", "Failed to parse remote addr", + "component", "netlimit", + "remote_addr", remoteAddr, + "error", err) + return true, 0, "" + } + + // Only supporting ipv4 + if !isIPv4(ip) { + // Block non-IPv4 addresses to prevent complications + l.blockedRequests.Add(1) + l.logger.Warn("msg", "Non-IPv4 address blocked", + "component", "netlimit", + "ip", ip) + return false, 403, "IPv4 only" + } + + // Check connection limit for streaming endpoint + if l.config.MaxConnectionsPerIP > 0 { + l.connMu.RLock() + counter, exists := l.ipConnections[ip] + l.connMu.RUnlock() + + if exists && counter.Load() >= int32(l.config.MaxConnectionsPerIP) { + l.blockedRequests.Add(1) + statusCode = l.config.ResponseCode + if statusCode == 0 { + statusCode = 429 + } + message = "Connection limit exceeded" + + l.logger.Warn("msg", "Connection limit exceeded", + "component", "netlimit", + "ip", ip, + "connections", counter.Load(), + "limit", l.config.MaxConnectionsPerIP) + + return false, statusCode, message + } + } + + // Check net limit + allowed = l.checkLimit(ip) + if !allowed { + l.blockedRequests.Add(1) + statusCode = l.config.ResponseCode + if statusCode == 0 { + statusCode = 429 + } + message = l.config.ResponseMessage + if message == "" { + message = "Net limit exceeded" + } + l.logger.Debug("msg", "Request net limited", "ip", ip) + } + + return allowed, statusCode, message +} + +// Checks if a TCP connection should be allowed +func (l *Limiter) CheckTCP(remoteAddr net.Addr) bool { + if l == nil { + return true + } + + l.totalRequests.Add(1) + + // Extract IP from TCP addr + tcpAddr, ok := remoteAddr.(*net.TCPAddr) + if !ok { + return true + } + + ip := tcpAddr.IP.String() + + // Only supporting ipv4 + if !isIPv4(ip) { + l.blockedRequests.Add(1) + l.logger.Warn("msg", "Non-IPv4 TCP connection blocked", + "component", "netlimit", + "ip", ip) + return false + } + + allowed := l.checkLimit(ip) + if !allowed { + l.blockedRequests.Add(1) + l.logger.Debug("msg", "TCP connection net limited", "ip", ip) + } + + return allowed +} + +func isIPv4(ip string) bool { + // Simple check: IPv4 addresses contain dots, IPv6 contain colons + return strings.Contains(ip, ".") && !strings.Contains(ip, ":") +} + +// Tracks a new connection for an IP +func (l *Limiter) AddConnection(remoteAddr string) { + if l == nil { + return + } + + ip, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return + } + + // Only supporting ipv4 + if !isIPv4(ip) { + return + } + + l.connMu.Lock() + counter, exists := l.ipConnections[ip] + if !exists { + counter = &atomic.Int32{} + l.ipConnections[ip] = counter + } + l.connMu.Unlock() + + newCount := counter.Add(1) + l.logger.Debug("msg", "Connection added", + "ip", ip, + "connections", newCount) +} + +// Removes a connection for an IP +func (l *Limiter) RemoveConnection(remoteAddr string) { + if l == nil { + return + } + + ip, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return + } + + // Only supporting ipv4 + if !isIPv4(ip) { + return + } + + l.connMu.RLock() + counter, exists := l.ipConnections[ip] + l.connMu.RUnlock() + + if exists { + newCount := counter.Add(-1) + l.logger.Debug("msg", "Connection removed", + "ip", ip, + "connections", newCount) + + if newCount <= 0 { + // Clean up if no more connections + l.connMu.Lock() + if counter.Load() <= 0 { + delete(l.ipConnections, ip) + } + l.connMu.Unlock() + } + } +} + +// Returns net limiter statistics +func (l *Limiter) GetStats() map[string]any { + if l == nil { + return map[string]any{ + "enabled": false, + } + } + + l.ipMu.RLock() + activeIPs := len(l.ipLimiters) + l.ipMu.RUnlock() + + l.connMu.RLock() + totalConnections := 0 + for _, counter := range l.ipConnections { + totalConnections += int(counter.Load()) + } + l.connMu.RUnlock() + + return map[string]any{ + "enabled": true, + "total_requests": l.totalRequests.Load(), + "blocked_requests": l.blockedRequests.Load(), + "active_ips": activeIPs, + "total_connections": totalConnections, + "config": map[string]any{ + "requests_per_second": l.config.RequestsPerSecond, + "burst_size": l.config.BurstSize, + "limit_by": l.config.LimitBy, + }, + } +} + +// Performs the actual net limit check +func (l *Limiter) checkLimit(ip string) bool { + // Maybe run cleanup + l.maybeCleanup() + + switch l.config.LimitBy { + case "global": + return l.globalLimiter.Allow() + + case "ip", "": + // Default to per-IP limiting + l.ipMu.Lock() + limiter, exists := l.ipLimiters[ip] + if !exists { + // Create new limiter for this IP + limiter = &ipLimiter{ + bucket: NewTokenBucket( + float64(l.config.BurstSize), + l.config.RequestsPerSecond, + ), + lastSeen: time.Now(), + } + l.ipLimiters[ip] = limiter + l.uniqueIPs.Add(1) + + l.logger.Debug("msg", "Created new IP limiter", + "ip", ip, + "total_ips", l.uniqueIPs.Load()) + } else { + limiter.lastSeen = time.Now() + } + l.ipMu.Unlock() + + // Check connection limit if configured + if l.config.MaxConnectionsPerIP > 0 { + l.connMu.RLock() + counter, exists := l.ipConnections[ip] + l.connMu.RUnlock() + + if exists && counter.Load() >= int32(l.config.MaxConnectionsPerIP) { + return false + } + } + + return limiter.bucket.Allow() + + default: + // Unknown limit_by value, allow by default + l.logger.Warn("msg", "Unknown limit_by value", + "limit_by", l.config.LimitBy) + return true + } +} + +// Runs cleanup if enough time has passed +func (l *Limiter) maybeCleanup() { + l.cleanupMu.Lock() + defer l.cleanupMu.Unlock() + + if time.Since(l.lastCleanup) < 30*time.Second { + return + } + + l.lastCleanup = time.Now() + go l.cleanup() +} + +// Removes stale IP limiters +func (l *Limiter) cleanup() { + staleTimeout := 5 * time.Minute + now := time.Now() + + l.ipMu.Lock() + defer l.ipMu.Unlock() + + cleaned := 0 + for ip, limiter := range l.ipLimiters { + if now.Sub(limiter.lastSeen) > staleTimeout { + delete(l.ipLimiters, ip) + cleaned++ + } + } + + if cleaned > 0 { + l.logger.Debug("msg", "Cleaned up stale IP limiters", + "cleaned", cleaned, + "remaining", len(l.ipLimiters)) + } +} + +// Runs periodic cleanup +func (l *Limiter) cleanupLoop() { + defer close(l.cleanupDone) + + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-l.ctx.Done(): + // Exit when context is cancelled + l.logger.Debug("msg", "Cleanup loop stopping", "component", "netlimit") + return + case <-ticker.C: + l.cleanup() + } + } +} \ No newline at end of file diff --git a/src/internal/ratelimit/ratelimiter.go b/src/internal/netlimit/netlimiter.go similarity index 91% rename from src/internal/ratelimit/ratelimiter.go rename to src/internal/netlimit/netlimiter.go index 6777e14..29d2990 100644 --- a/src/internal/ratelimit/ratelimiter.go +++ b/src/internal/netlimit/netlimiter.go @@ -1,12 +1,12 @@ -// FILE: src/internal/ratelimit/ratelimiter.go -package ratelimit +// FILE: src/internal/netlimit/netlimiter.go +package netlimit import ( "sync" "time" ) -// TokenBucket implements a token bucket rate limiter +// TokenBucket implements a token bucket net limiter type TokenBucket struct { capacity float64 tokens float64 diff --git a/src/internal/ratelimit/limiter.go b/src/internal/ratelimit/limiter.go index 0579cf3..8971acb 100644 --- a/src/internal/ratelimit/limiter.go +++ b/src/internal/ratelimit/limiter.go @@ -2,431 +2,114 @@ package ratelimit import ( - "context" - "net" "strings" "sync" "sync/atomic" "time" - "logwisp/src/internal/config" - "github.com/lixenwraith/log" + "logwisp/src/internal/config" + "logwisp/src/internal/source" ) -// Manages rate limiting for a transport +// Limiter enforces rate limits on log entries flowing through a pipeline. type Limiter struct { - config config.RateLimitConfig - logger *log.Logger - - // Per-IP limiters - ipLimiters map[string]*ipLimiter - ipMu sync.RWMutex - - // Global limiter for the transport - globalLimiter *TokenBucket - - // Connection tracking - ipConnections map[string]*atomic.Int32 - connMu sync.RWMutex + mu sync.Mutex + rate float64 + burst float64 + tokens float64 + lastToken time.Time + policy config.RateLimitPolicy + logger *log.Logger // Statistics - totalRequests atomic.Uint64 - blockedRequests atomic.Uint64 - uniqueIPs atomic.Uint64 - - // Cleanup - lastCleanup time.Time - cleanupMu sync.Mutex - - // Lifecycle management - ctx context.Context - cancel context.CancelFunc - cleanupDone chan struct{} + droppedCount atomic.Uint64 } -type ipLimiter struct { - bucket *TokenBucket - lastSeen time.Time - connections atomic.Int32 -} - -// Creates a new rate limiter -func New(cfg config.RateLimitConfig, logger *log.Logger) *Limiter { - if !cfg.Enabled { - return nil +// New creates a new rate limiter. If cfg.Rate is 0, it returns nil. +func New(cfg config.RateLimitConfig, logger *log.Logger) (*Limiter, error) { + if cfg.Rate <= 0 { + return nil, nil // No rate limit } - if logger == nil { - panic("ratelimit.New: logger cannot be nil") + burst := cfg.Burst + if burst <= 0 { + burst = cfg.Rate // Default burst to rate } - ctx, cancel := context.WithCancel(context.Background()) + var policy config.RateLimitPolicy + switch strings.ToLower(cfg.Policy) { + case "drop": + policy = config.PolicyDrop + default: + policy = config.PolicyPass + } l := &Limiter{ - config: cfg, - ipLimiters: make(map[string]*ipLimiter), - ipConnections: make(map[string]*atomic.Int32), - lastCleanup: time.Now(), - logger: logger, - ctx: ctx, - cancel: cancel, - cleanupDone: make(chan struct{}), + rate: cfg.Rate, + burst: burst, + tokens: burst, + lastToken: time.Now(), + policy: policy, + logger: logger, } - // Create global limiter if not using per-IP limiting - if cfg.LimitBy == "global" { - l.globalLimiter = NewTokenBucket( - float64(cfg.BurstSize), - cfg.RequestsPerSecond, - ) - } - - // Start cleanup goroutine - go l.cleanupLoop() - - l.logger.Info("msg", "Rate limiter initialized", - "component", "ratelimit", - "requests_per_second", cfg.RequestsPerSecond, - "burst_size", cfg.BurstSize, - "limit_by", cfg.LimitBy) - - return l + return l, nil } -func (l *Limiter) Shutdown() { - if l == nil { - return - } - - l.logger.Info("msg", "Shutting down rate limiter", "component", "ratelimit") - - // Cancel context to stop cleanup goroutine - l.cancel() - - // Wait for cleanup goroutine to finish - select { - case <-l.cleanupDone: - l.logger.Debug("msg", "Cleanup goroutine stopped", "component", "ratelimit") - case <-time.After(2 * time.Second): - l.logger.Warn("msg", "Cleanup goroutine shutdown timeout", "component", "ratelimit") - } -} - -// Checks if an HTTP request should be allowed -func (l *Limiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int, message string) { - if l == nil { - return true, 0, "" - } - - l.totalRequests.Add(1) - - ip, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - // If we can't parse the IP, allow the request but log - l.logger.Warn("msg", "Failed to parse remote addr", - "component", "ratelimit", - "remote_addr", remoteAddr, - "error", err) - return true, 0, "" - } - - // Only supporting ipv4 - if !isIPv4(ip) { - // Block non-IPv4 addresses to prevent complications - l.blockedRequests.Add(1) - l.logger.Warn("msg", "Non-IPv4 address blocked", - "component", "ratelimit", - "ip", ip) - return false, 403, "IPv4 only" - } - - // Check connection limit for streaming endpoint - if l.config.MaxConnectionsPerIP > 0 { - l.connMu.RLock() - counter, exists := l.ipConnections[ip] - l.connMu.RUnlock() - - if exists && counter.Load() >= int32(l.config.MaxConnectionsPerIP) { - l.blockedRequests.Add(1) - statusCode = l.config.ResponseCode - if statusCode == 0 { - statusCode = 429 - } - message = "Connection limit exceeded" - - l.logger.Warn("msg", "Connection limit exceeded", - "component", "ratelimit", - "ip", ip, - "connections", counter.Load(), - "limit", l.config.MaxConnectionsPerIP) - - return false, statusCode, message - } - } - - // Check rate limit - allowed = l.checkLimit(ip) - if !allowed { - l.blockedRequests.Add(1) - statusCode = l.config.ResponseCode - if statusCode == 0 { - statusCode = 429 - } - message = l.config.ResponseMessage - if message == "" { - message = "Rate limit exceeded" - } - l.logger.Debug("msg", "Request rate limited", "ip", ip) - } - - return allowed, statusCode, message -} - -// Checks if a TCP connection should be allowed -func (l *Limiter) CheckTCP(remoteAddr net.Addr) bool { - if l == nil { +// Allow checks if a log entry is allowed to pass based on the rate limit. +// It returns true if the entry should pass, false if it should be dropped. +func (l *Limiter) Allow(entry source.LogEntry) bool { + if l.policy == config.PolicyPass { return true } - l.totalRequests.Add(1) + l.mu.Lock() + defer l.mu.Unlock() - // Extract IP from TCP addr - tcpAddr, ok := remoteAddr.(*net.TCPAddr) - if !ok { - return true - } - - ip := tcpAddr.IP.String() - - // Only supporting ipv4 - if !isIPv4(ip) { - l.blockedRequests.Add(1) - l.logger.Warn("msg", "Non-IPv4 TCP connection blocked", - "component", "ratelimit", - "ip", ip) - return false - } - - allowed := l.checkLimit(ip) - if !allowed { - l.blockedRequests.Add(1) - l.logger.Debug("msg", "TCP connection rate limited", "ip", ip) - } - - return allowed -} - -func isIPv4(ip string) bool { - // Simple check: IPv4 addresses contain dots, IPv6 contain colons - return strings.Contains(ip, ".") && !strings.Contains(ip, ":") -} - -// Tracks a new connection for an IP -func (l *Limiter) AddConnection(remoteAddr string) { - if l == nil { - return - } - - ip, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - return - } - - // Only supporting ipv4 - if !isIPv4(ip) { - return - } - - l.connMu.Lock() - counter, exists := l.ipConnections[ip] - if !exists { - counter = &atomic.Int32{} - l.ipConnections[ip] = counter - } - l.connMu.Unlock() - - newCount := counter.Add(1) - l.logger.Debug("msg", "Connection added", - "ip", ip, - "connections", newCount) -} - -// Removes a connection for an IP -func (l *Limiter) RemoveConnection(remoteAddr string) { - if l == nil { - return - } - - ip, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - return - } - - // Only supporting ipv4 - if !isIPv4(ip) { - return - } - - l.connMu.RLock() - counter, exists := l.ipConnections[ip] - l.connMu.RUnlock() - - if exists { - newCount := counter.Add(-1) - l.logger.Debug("msg", "Connection removed", - "ip", ip, - "connections", newCount) - - if newCount <= 0 { - // Clean up if no more connections - l.connMu.Lock() - if counter.Load() <= 0 { - delete(l.ipConnections, ip) - } - l.connMu.Unlock() - } - } -} - -// Returns rate limiter statistics -func (l *Limiter) GetStats() map[string]any { - if l == nil { - return map[string]any{ - "enabled": false, - } - } - - l.ipMu.RLock() - activeIPs := len(l.ipLimiters) - l.ipMu.RUnlock() - - l.connMu.RLock() - totalConnections := 0 - for _, counter := range l.ipConnections { - totalConnections += int(counter.Load()) - } - l.connMu.RUnlock() - - return map[string]any{ - "enabled": true, - "total_requests": l.totalRequests.Load(), - "blocked_requests": l.blockedRequests.Load(), - "active_ips": activeIPs, - "total_connections": totalConnections, - "config": map[string]any{ - "requests_per_second": l.config.RequestsPerSecond, - "burst_size": l.config.BurstSize, - "limit_by": l.config.LimitBy, - }, - } -} - -// Performs the actual rate limit check -func (l *Limiter) checkLimit(ip string) bool { - // Maybe run cleanup - l.maybeCleanup() - - switch l.config.LimitBy { - case "global": - return l.globalLimiter.Allow() - - case "ip", "": - // Default to per-IP limiting - l.ipMu.Lock() - limiter, exists := l.ipLimiters[ip] - if !exists { - // Create new limiter for this IP - limiter = &ipLimiter{ - bucket: NewTokenBucket( - float64(l.config.BurstSize), - l.config.RequestsPerSecond, - ), - lastSeen: time.Now(), - } - l.ipLimiters[ip] = limiter - l.uniqueIPs.Add(1) - - l.logger.Debug("msg", "Created new IP limiter", - "ip", ip, - "total_ips", l.uniqueIPs.Load()) - } else { - limiter.lastSeen = time.Now() - } - l.ipMu.Unlock() - - // Check connection limit if configured - if l.config.MaxConnectionsPerIP > 0 { - l.connMu.RLock() - counter, exists := l.ipConnections[ip] - l.connMu.RUnlock() - - if exists && counter.Load() >= int32(l.config.MaxConnectionsPerIP) { - return false - } - } - - return limiter.bucket.Allow() - - default: - // Unknown limit_by value, allow by default - l.logger.Warn("msg", "Unknown limit_by value", - "limit_by", l.config.LimitBy) - return true - } -} - -// Runs cleanup if enough time has passed -func (l *Limiter) maybeCleanup() { - l.cleanupMu.Lock() - defer l.cleanupMu.Unlock() - - if time.Since(l.lastCleanup) < 30*time.Second { - return - } - - l.lastCleanup = time.Now() - go l.cleanup() -} - -// Removes stale IP limiters -func (l *Limiter) cleanup() { - staleTimeout := 5 * time.Minute now := time.Now() + elapsed := now.Sub(l.lastToken).Seconds() - l.ipMu.Lock() - defer l.ipMu.Unlock() - - cleaned := 0 - for ip, limiter := range l.ipLimiters { - if now.Sub(limiter.lastSeen) > staleTimeout { - delete(l.ipLimiters, ip) - cleaned++ - } + if elapsed < 0 { + // Clock went backwards, don't add tokens + l.lastToken = now + elapsed = 0 } - if cleaned > 0 { - l.logger.Debug("msg", "Cleaned up stale IP limiters", - "cleaned", cleaned, - "remaining", len(l.ipLimiters)) + l.tokens += elapsed * l.rate + if l.tokens > l.burst { + l.tokens = l.burst + } + l.lastToken = now + + if l.tokens >= 1 { + l.tokens-- + return true + } + + // Not enough tokens, drop the entry + l.droppedCount.Add(1) + return false +} + +// GetStats returns the statistics for the limiter. +func (l *Limiter) GetStats() map[string]any { + return map[string]any{ + "dropped_total": l.droppedCount.Load(), + "policy": policyString(l.policy), + "rate": l.rate, + "burst": l.burst, } } -// Runs periodic cleanup -func (l *Limiter) cleanupLoop() { - defer close(l.cleanupDone) - - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-l.ctx.Done(): - // Exit when context is cancelled - l.logger.Debug("msg", "Cleanup loop stopping", "component", "ratelimit") - return - case <-ticker.C: - l.cleanup() - } +// policyString returns the string representation of the policy. +func policyString(p config.RateLimitPolicy) string { + switch p { + case config.PolicyDrop: + return "drop" + case config.PolicyPass: + return "pass" + default: + return "unknown" } } \ No newline at end of file diff --git a/src/internal/service/pipeline.go b/src/internal/service/pipeline.go index 3ac45b3..a77a51e 100644 --- a/src/internal/service/pipeline.go +++ b/src/internal/service/pipeline.go @@ -3,6 +3,7 @@ package service import ( "context" + "logwisp/src/internal/ratelimit" "sync" "sync/atomic" "time" @@ -20,6 +21,7 @@ type Pipeline struct { Name string Config config.PipelineConfig Sources []source.Source + RateLimiter *ratelimit.Limiter FilterChain *filter.Chain Sinks []sink.Sink Stats *PipelineStats @@ -36,12 +38,13 @@ type Pipeline struct { // PipelineStats contains statistics for a pipeline type PipelineStats struct { - StartTime time.Time - TotalEntriesProcessed atomic.Uint64 - TotalEntriesFiltered atomic.Uint64 - SourceStats []source.SourceStats - SinkStats []sink.SinkStats - FilterStats map[string]any + StartTime time.Time + TotalEntriesProcessed atomic.Uint64 + TotalEntriesDroppedByRateLimit atomic.Uint64 + TotalEntriesFiltered atomic.Uint64 + SourceStats []source.SourceStats + SinkStats []sink.SinkStats + FilterStats map[string]any } // Shutdown gracefully stops the pipeline @@ -112,6 +115,18 @@ func (p *Pipeline) GetStats() map[string]any { }) } + // Collect rate limit stats + var rateLimitStats map[string]any + if p.RateLimiter != nil { + rateLimitStats = p.RateLimiter.GetStats() + } + + // Collect filter stats + var filterStats map[string]any + if p.FilterChain != nil { + filterStats = p.FilterChain.GetStats() + } + // Collect sink stats sinkStats := make([]map[string]any, 0, len(p.Sinks)) for _, s := range p.Sinks { @@ -130,23 +145,19 @@ func (p *Pipeline) GetStats() map[string]any { }) } - // Collect filter stats - var filterStats map[string]any - if p.FilterChain != nil { - filterStats = p.FilterChain.GetStats() - } - return map[string]any{ - "name": p.Name, - "uptime_seconds": int(time.Since(p.Stats.StartTime).Seconds()), - "total_processed": p.Stats.TotalEntriesProcessed.Load(), - "total_filtered": p.Stats.TotalEntriesFiltered.Load(), - "sources": sourceStats, - "sinks": sinkStats, - "filters": filterStats, - "source_count": len(p.Sources), - "sink_count": len(p.Sinks), - "filter_count": len(p.Config.Filters), + "name": p.Name, + "uptime_seconds": int(time.Since(p.Stats.StartTime).Seconds()), + "total_processed": p.Stats.TotalEntriesProcessed.Load(), + "total_dropped_rate_limit": p.Stats.TotalEntriesDroppedByRateLimit.Load(), + "total_filtered": p.Stats.TotalEntriesFiltered.Load(), + "sources": sourceStats, + "rate_limiter": rateLimitStats, + "sinks": sinkStats, + "filters": filterStats, + "source_count": len(p.Sources), + "sink_count": len(p.Sinks), + "filter_count": len(p.Config.Filters), } } diff --git a/src/internal/service/service.go b/src/internal/service/service.go index d9f4b3a..24f55e9 100644 --- a/src/internal/service/service.go +++ b/src/internal/service/service.go @@ -4,6 +4,7 @@ package service import ( "context" "fmt" + "logwisp/src/internal/ratelimit" "sync" "time" @@ -77,6 +78,16 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error { pipeline.Sources = append(pipeline.Sources, src) } + // Create pipeline rate limiter + if cfg.RateLimit != nil { + limiter, err := ratelimit.New(*cfg.RateLimit, s.logger) + if err != nil { + pipelineCancel() + return fmt.Errorf("failed to create pipeline rate limiter: %w", err) + } + pipeline.RateLimiter = limiter + } + // Create filter chain if len(cfg.Filters) > 0 { chain, err := filter.NewChain(cfg.Filters, s.logger) @@ -175,6 +186,14 @@ func (s *Service) wirePipeline(p *Pipeline) { p.Stats.TotalEntriesProcessed.Add(1) + // Apply pipeline rate limiter + if p.RateLimiter != nil { + if !p.RateLimiter.Allow(entry) { + p.Stats.TotalEntriesDroppedByRateLimit.Add(1) + continue // Drop the entry + } + } + // Apply filters if configured if p.FilterChain != nil { if !p.FilterChain.Apply(entry) { diff --git a/src/internal/sink/http.go b/src/internal/sink/http.go index f5885bf..0695ef7 100644 --- a/src/internal/sink/http.go +++ b/src/internal/sink/http.go @@ -12,7 +12,7 @@ import ( "time" "logwisp/src/internal/config" - "logwisp/src/internal/ratelimit" + "logwisp/src/internal/netlimit" "logwisp/src/internal/source" "logwisp/src/internal/version" @@ -40,8 +40,8 @@ type HTTPSink struct { // For router integration standalone bool - // Rate limiting - rateLimiter *ratelimit.Limiter + // Net limiting + netLimiter *netlimit.Limiter // Statistics totalProcessed atomic.Uint64 @@ -56,7 +56,7 @@ type HTTPConfig struct { StatusPath string Heartbeat config.HeartbeatConfig SSL *config.SSLConfig - RateLimit *config.RateLimitConfig + NetLimit *config.NetLimitConfig } // NewHTTPSink creates a new HTTP streaming sink @@ -95,30 +95,30 @@ func NewHTTPSink(options map[string]any, logger *log.Logger) (*HTTPSink, error) } } - // Extract rate limit config - if rl, ok := options["rate_limit"].(map[string]any); ok { - cfg.RateLimit = &config.RateLimitConfig{} - cfg.RateLimit.Enabled, _ = rl["enabled"].(bool) + // Extract net limit config + if rl, ok := options["net_limit"].(map[string]any); ok { + cfg.NetLimit = &config.NetLimitConfig{} + cfg.NetLimit.Enabled, _ = rl["enabled"].(bool) if rps, ok := toFloat(rl["requests_per_second"]); ok { - cfg.RateLimit.RequestsPerSecond = rps + cfg.NetLimit.RequestsPerSecond = rps } if burst, ok := toInt(rl["burst_size"]); ok { - cfg.RateLimit.BurstSize = burst + cfg.NetLimit.BurstSize = burst } if limitBy, ok := rl["limit_by"].(string); ok { - cfg.RateLimit.LimitBy = limitBy + cfg.NetLimit.LimitBy = limitBy } if respCode, ok := toInt(rl["response_code"]); ok { - cfg.RateLimit.ResponseCode = respCode + cfg.NetLimit.ResponseCode = respCode } if msg, ok := rl["response_message"].(string); ok { - cfg.RateLimit.ResponseMessage = msg + cfg.NetLimit.ResponseMessage = msg } if maxPerIP, ok := toInt(rl["max_connections_per_ip"]); ok { - cfg.RateLimit.MaxConnectionsPerIP = maxPerIP + cfg.NetLimit.MaxConnectionsPerIP = maxPerIP } if maxTotal, ok := toInt(rl["max_total_connections"]); ok { - cfg.RateLimit.MaxTotalConnections = maxTotal + cfg.NetLimit.MaxTotalConnections = maxTotal } } @@ -134,9 +134,9 @@ func NewHTTPSink(options map[string]any, logger *log.Logger) (*HTTPSink, error) } h.lastProcessed.Store(time.Time{}) - // Initialize rate limiter if configured - if cfg.RateLimit != nil && cfg.RateLimit.Enabled { - h.rateLimiter = ratelimit.New(*cfg.RateLimit, logger) + // Initialize net limiter if configured + if cfg.NetLimit != nil && cfg.NetLimit.Enabled { + h.netLimiter = netlimit.New(*cfg.NetLimit, logger) } return h, nil @@ -212,9 +212,9 @@ func (h *HTTPSink) Stop() { func (h *HTTPSink) GetStats() SinkStats { lastProc, _ := h.lastProcessed.Load().(time.Time) - var rateLimitStats map[string]any - if h.rateLimiter != nil { - rateLimitStats = h.rateLimiter.GetStats() + var netLimitStats map[string]any + if h.netLimiter != nil { + netLimitStats = h.netLimiter.GetStats() } return SinkStats{ @@ -230,7 +230,7 @@ func (h *HTTPSink) GetStats() SinkStats { "stream": h.streamPath, "status": h.statusPath, }, - "rate_limit": rateLimitStats, + "net_limit": netLimitStats, }, } } @@ -248,9 +248,9 @@ func (h *HTTPSink) RouteRequest(ctx *fasthttp.RequestCtx) { } func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { - // Check rate limit first + // Check net limit first remoteAddr := ctx.RemoteAddr().String() - if allowed, statusCode, message := h.rateLimiter.CheckHTTP(remoteAddr); !allowed { + if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed { ctx.SetStatusCode(statusCode) ctx.SetContentType("application/json") json.NewEncoder(ctx).Encode(map[string]any{ @@ -279,11 +279,11 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { } func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) { - // Track connection for rate limiting + // Track connection for net limiting remoteAddr := ctx.RemoteAddr().String() - if h.rateLimiter != nil { - h.rateLimiter.AddConnection(remoteAddr) - defer h.rateLimiter.RemoveConnection(remoteAddr) + if h.netLimiter != nil { + h.netLimiter.AddConnection(remoteAddr) + defer h.netLimiter.RemoveConnection(remoteAddr) } // Set SSE headers @@ -450,11 +450,11 @@ func (h *HTTPSink) formatHeartbeat() string { func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) { ctx.SetContentType("application/json") - var rateLimitStats any - if h.rateLimiter != nil { - rateLimitStats = h.rateLimiter.GetStats() + var netLimitStats any + if h.netLimiter != nil { + netLimitStats = h.netLimiter.GetStats() } else { - rateLimitStats = map[string]any{ + netLimitStats = map[string]any{ "enabled": false, } } @@ -483,7 +483,7 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) { "ssl": map[string]bool{ "enabled": h.config.SSL != nil && h.config.SSL.Enabled, }, - "rate_limit": rateLimitStats, + "net_limit": netLimitStats, }, } diff --git a/src/internal/sink/tcp.go b/src/internal/sink/tcp.go index 8edfa08..50fa0e2 100644 --- a/src/internal/sink/tcp.go +++ b/src/internal/sink/tcp.go @@ -11,7 +11,7 @@ import ( "time" "logwisp/src/internal/config" - "logwisp/src/internal/ratelimit" + "logwisp/src/internal/netlimit" "logwisp/src/internal/source" "github.com/lixenwraith/log" @@ -29,7 +29,7 @@ type TCPSink struct { engine *gnet.Engine engineMu sync.Mutex wg sync.WaitGroup - rateLimiter *ratelimit.Limiter + netLimiter *netlimit.Limiter logger *log.Logger // Statistics @@ -43,7 +43,7 @@ type TCPConfig struct { BufferSize int Heartbeat config.HeartbeatConfig SSL *config.SSLConfig - RateLimit *config.RateLimitConfig + NetLimit *config.NetLimitConfig } // NewTCPSink creates a new TCP streaming sink @@ -74,30 +74,30 @@ func NewTCPSink(options map[string]any, logger *log.Logger) (*TCPSink, error) { } } - // Extract rate limit config - if rl, ok := options["rate_limit"].(map[string]any); ok { - cfg.RateLimit = &config.RateLimitConfig{} - cfg.RateLimit.Enabled, _ = rl["enabled"].(bool) + // Extract net limit config + if rl, ok := options["net_limit"].(map[string]any); ok { + cfg.NetLimit = &config.NetLimitConfig{} + cfg.NetLimit.Enabled, _ = rl["enabled"].(bool) if rps, ok := toFloat(rl["requests_per_second"]); ok { - cfg.RateLimit.RequestsPerSecond = rps + cfg.NetLimit.RequestsPerSecond = rps } if burst, ok := toInt(rl["burst_size"]); ok { - cfg.RateLimit.BurstSize = burst + cfg.NetLimit.BurstSize = burst } if limitBy, ok := rl["limit_by"].(string); ok { - cfg.RateLimit.LimitBy = limitBy + cfg.NetLimit.LimitBy = limitBy } if respCode, ok := toInt(rl["response_code"]); ok { - cfg.RateLimit.ResponseCode = respCode + cfg.NetLimit.ResponseCode = respCode } if msg, ok := rl["response_message"].(string); ok { - cfg.RateLimit.ResponseMessage = msg + cfg.NetLimit.ResponseMessage = msg } if maxPerIP, ok := toInt(rl["max_connections_per_ip"]); ok { - cfg.RateLimit.MaxConnectionsPerIP = maxPerIP + cfg.NetLimit.MaxConnectionsPerIP = maxPerIP } if maxTotal, ok := toInt(rl["max_total_connections"]); ok { - cfg.RateLimit.MaxTotalConnections = maxTotal + cfg.NetLimit.MaxTotalConnections = maxTotal } } @@ -110,8 +110,8 @@ func NewTCPSink(options map[string]any, logger *log.Logger) (*TCPSink, error) { } t.lastProcessed.Store(time.Time{}) - if cfg.RateLimit != nil && cfg.RateLimit.Enabled { - t.rateLimiter = ratelimit.New(*cfg.RateLimit, logger) + if cfg.NetLimit != nil && cfg.NetLimit.Enabled { + t.netLimiter = netlimit.New(*cfg.NetLimit, logger) } return t, nil @@ -194,9 +194,9 @@ func (t *TCPSink) Stop() { func (t *TCPSink) GetStats() SinkStats { lastProc, _ := t.lastProcessed.Load().(time.Time) - var rateLimitStats map[string]any - if t.rateLimiter != nil { - rateLimitStats = t.rateLimiter.GetStats() + var netLimitStats map[string]any + if t.netLimiter != nil { + netLimitStats = t.netLimiter.GetStats() } return SinkStats{ @@ -208,7 +208,7 @@ func (t *TCPSink) GetStats() SinkStats { Details: map[string]any{ "port": t.config.Port, "buffer_size": t.config.BufferSize, - "rate_limit": rateLimitStats, + "net_limit": netLimitStats, }, } } @@ -313,8 +313,8 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { remoteAddr := c.RemoteAddr().String() s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr) - // Check rate limit - if s.sink.rateLimiter != nil { + // Check net limit + if s.sink.netLimiter != nil { // Parse the remote address to get proper net.Addr remoteStr := c.RemoteAddr().String() tcpAddr, err := net.ResolveTCPAddr("tcp", remoteStr) @@ -325,15 +325,15 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { return nil, gnet.Close } - if !s.sink.rateLimiter.CheckTCP(tcpAddr) { - s.sink.logger.Warn("msg", "TCP connection rate limited", + if !s.sink.netLimiter.CheckTCP(tcpAddr) { + s.sink.logger.Warn("msg", "TCP connection net limited", "remote_addr", remoteAddr) - // Silently close connection when rate limited + // Silently close connection when net limited return nil, gnet.Close } // Track connection - s.sink.rateLimiter.AddConnection(remoteStr) + s.sink.netLimiter.AddConnection(remoteStr) } s.connections.Store(c, struct{}{}) @@ -352,8 +352,8 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action { remoteAddr := c.RemoteAddr().String() // Remove connection tracking - if s.sink.rateLimiter != nil { - s.sink.rateLimiter.RemoveConnection(c.RemoteAddr().String()) + if s.sink.netLimiter != nil { + s.sink.netLimiter.RemoveConnection(c.RemoteAddr().String()) } newCount := s.sink.activeConns.Add(-1) diff --git a/src/internal/source/directory.go b/src/internal/source/directory.go index 80d1a90..6707a52 100644 --- a/src/internal/source/directory.go +++ b/src/internal/source/directory.go @@ -144,19 +144,7 @@ func (ds *DirectorySource) GetStats() SourceStats { } } -func (ds *DirectorySource) ApplyRateLimit(entry LogEntry) (LogEntry, bool) { - // TODO: Implement source-side rate limiting for aggregation/summarization - // For now, just pass through unchanged - return entry, true -} - func (ds *DirectorySource) publish(entry LogEntry) { - // Apply rate limiting (placeholder for now) - entry, allowed := ds.ApplyRateLimit(entry) - if !allowed { - return - } - ds.mu.RLock() defer ds.mu.RUnlock() diff --git a/src/internal/source/http.go b/src/internal/source/http.go index 55f6f06..854619f 100644 --- a/src/internal/source/http.go +++ b/src/internal/source/http.go @@ -9,7 +9,7 @@ import ( "time" "logwisp/src/internal/config" - "logwisp/src/internal/ratelimit" + "logwisp/src/internal/netlimit" "github.com/lixenwraith/log" "github.com/valyala/fasthttp" @@ -25,7 +25,7 @@ type HTTPSource struct { mu sync.RWMutex done chan struct{} wg sync.WaitGroup - rateLimiter *ratelimit.Limiter + netLimiter *netlimit.Limiter logger *log.Logger // Statistics @@ -63,10 +63,10 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err } h.lastEntryTime.Store(time.Time{}) - // Initialize rate limiter if configured - if rl, ok := options["rate_limit"].(map[string]any); ok { + // Initialize net limiter if configured + if rl, ok := options["net_limit"].(map[string]any); ok { if enabled, _ := rl["enabled"].(bool); enabled { - cfg := config.RateLimitConfig{ + cfg := config.NetLimitConfig{ Enabled: true, } @@ -89,7 +89,7 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err cfg.MaxConnectionsPerIP = maxPerIP } - h.rateLimiter = ratelimit.New(cfg, logger) + h.netLimiter = netlimit.New(cfg, logger) } } @@ -149,9 +149,9 @@ func (h *HTTPSource) Stop() { } } - // Shutdown rate limiter - if h.rateLimiter != nil { - h.rateLimiter.Shutdown() + // Shutdown net limiter + if h.netLimiter != nil { + h.netLimiter.Shutdown() } h.wg.Wait() @@ -169,9 +169,9 @@ func (h *HTTPSource) Stop() { func (h *HTTPSource) GetStats() SourceStats { lastEntry, _ := h.lastEntryTime.Load().(time.Time) - var rateLimitStats map[string]any - if h.rateLimiter != nil { - rateLimitStats = h.rateLimiter.GetStats() + var netLimitStats map[string]any + if h.netLimiter != nil { + netLimitStats = h.netLimiter.GetStats() } return SourceStats{ @@ -184,16 +184,11 @@ func (h *HTTPSource) GetStats() SourceStats { "port": h.port, "ingest_path": h.ingestPath, "invalid_entries": h.invalidEntries.Load(), - "rate_limit": rateLimitStats, + "net_limit": netLimitStats, }, } } -func (h *HTTPSource) ApplyRateLimit(entry LogEntry) (LogEntry, bool) { - // TODO: Implement source-side rate limiting for aggregation/summarization - return entry, true -} - func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) { // Only handle POST to the configured ingest path if string(ctx.Method()) != "POST" || string(ctx.Path()) != h.ingestPath { @@ -206,10 +201,10 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) { return } - // Check rate limit + // Check net limit remoteAddr := ctx.RemoteAddr().String() - if h.rateLimiter != nil { - if allowed, statusCode, message := h.rateLimiter.CheckHTTP(remoteAddr); !allowed { + if h.netLimiter != nil { + if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed { ctx.SetStatusCode(statusCode) ctx.SetContentType("application/json") json.NewEncoder(ctx).Encode(map[string]any{ @@ -330,12 +325,6 @@ func (h *HTTPSource) parseEntries(body []byte) ([]LogEntry, error) { } func (h *HTTPSource) publish(entry LogEntry) bool { - // Apply rate limiting - entry, allowed := h.ApplyRateLimit(entry) - if !allowed { - return false - } - h.mu.RLock() defer h.mu.RUnlock() diff --git a/src/internal/source/source.go b/src/internal/source/source.go index 2a57953..657e3ff 100644 --- a/src/internal/source/source.go +++ b/src/internal/source/source.go @@ -28,11 +28,6 @@ type Source interface { // GetStats returns source statistics GetStats() SourceStats - - // ApplyRateLimit applies source-side rate limiting - // TODO: This is a placeholder for future features like aggregation and summarization - // Currently just returns the entry unchanged - ApplyRateLimit(entry LogEntry) (LogEntry, bool) } // SourceStats contains statistics about a source diff --git a/src/internal/source/stdin.go b/src/internal/source/stdin.go index ed47625..8ae9ef0 100644 --- a/src/internal/source/stdin.go +++ b/src/internal/source/stdin.go @@ -65,12 +65,6 @@ func (s *StdinSource) GetStats() SourceStats { } } -func (s *StdinSource) ApplyRateLimit(entry LogEntry) (LogEntry, bool) { - // TODO: Implement source-side rate limiting for aggregation/summarization - // For now, just pass through unchanged - return entry, true -} - func (s *StdinSource) readLoop() { scanner := bufio.NewScanner(os.Stdin) for scanner.Scan() { @@ -90,12 +84,6 @@ func (s *StdinSource) readLoop() { Level: extractLogLevel(line), } - // Apply rate limiting - entry, allowed := s.ApplyRateLimit(entry) - if !allowed { - continue - } - s.publish(entry) } } diff --git a/src/internal/source/tcp.go b/src/internal/source/tcp.go index c8809eb..ed72857 100644 --- a/src/internal/source/tcp.go +++ b/src/internal/source/tcp.go @@ -12,7 +12,7 @@ import ( "time" "logwisp/src/internal/config" - "logwisp/src/internal/ratelimit" + "logwisp/src/internal/netlimit" "github.com/lixenwraith/log" "github.com/panjf2000/gnet/v2" @@ -29,7 +29,7 @@ type TCPSource struct { engine *gnet.Engine engineMu sync.Mutex wg sync.WaitGroup - rateLimiter *ratelimit.Limiter + netLimiter *netlimit.Limiter logger *log.Logger // Statistics @@ -62,10 +62,10 @@ func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error } t.lastEntryTime.Store(time.Time{}) - // Initialize rate limiter if configured - if rl, ok := options["rate_limit"].(map[string]any); ok { + // Initialize net limiter if configured + if rl, ok := options["net_limit"].(map[string]any); ok { if enabled, _ := rl["enabled"].(bool); enabled { - cfg := config.RateLimitConfig{ + cfg := config.NetLimitConfig{ Enabled: true, } @@ -85,7 +85,7 @@ func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error cfg.MaxTotalConnections = maxTotal } - t.rateLimiter = ratelimit.New(cfg, logger) + t.netLimiter = netlimit.New(cfg, logger) } } @@ -150,9 +150,9 @@ func (t *TCPSource) Stop() { (*engine).Stop(ctx) } - // Shutdown rate limiter - if t.rateLimiter != nil { - t.rateLimiter.Shutdown() + // Shutdown net limiter + if t.netLimiter != nil { + t.netLimiter.Shutdown() } t.wg.Wait() @@ -170,9 +170,9 @@ func (t *TCPSource) Stop() { func (t *TCPSource) GetStats() SourceStats { lastEntry, _ := t.lastEntryTime.Load().(time.Time) - var rateLimitStats map[string]any - if t.rateLimiter != nil { - rateLimitStats = t.rateLimiter.GetStats() + var netLimitStats map[string]any + if t.netLimiter != nil { + netLimitStats = t.netLimiter.GetStats() } return SourceStats{ @@ -185,23 +185,12 @@ func (t *TCPSource) GetStats() SourceStats { "port": t.port, "active_connections": t.activeConns.Load(), "invalid_entries": t.invalidEntries.Load(), - "rate_limit": rateLimitStats, + "net_limit": netLimitStats, }, } } -func (t *TCPSource) ApplyRateLimit(entry LogEntry) (LogEntry, bool) { - // TODO: Implement source-side rate limiting for aggregation/summarization - return entry, true -} - func (t *TCPSource) publish(entry LogEntry) bool { - // Apply rate limiting - entry, allowed := t.ApplyRateLimit(entry) - if !allowed { - return false - } - t.mu.RLock() defer t.mu.RUnlock() @@ -258,8 +247,8 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { "component", "tcp_source", "remote_addr", remoteAddr) - // Check rate limit - if s.source.rateLimiter != nil { + // Check net limit + if s.source.netLimiter != nil { remoteStr := c.RemoteAddr().String() tcpAddr, err := net.ResolveTCPAddr("tcp", remoteStr) if err != nil { @@ -270,15 +259,15 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { return nil, gnet.Close } - if !s.source.rateLimiter.CheckTCP(tcpAddr) { - s.source.logger.Warn("msg", "TCP connection rate limited", + if !s.source.netLimiter.CheckTCP(tcpAddr) { + s.source.logger.Warn("msg", "TCP connection net limited", "component", "tcp_source", "remote_addr", remoteAddr) return nil, gnet.Close } // Track connection - s.source.rateLimiter.AddConnection(remoteStr) + s.source.netLimiter.AddConnection(remoteStr) } // Create client state @@ -304,8 +293,8 @@ func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action { s.mu.Unlock() // Remove connection tracking - if s.source.rateLimiter != nil { - s.source.rateLimiter.RemoveConnection(remoteAddr) + if s.source.netLimiter != nil { + s.source.netLimiter.RemoveConnection(remoteAddr) } newCount := s.source.activeConns.Add(-1)