v0.3.3 pipeline rate limiter added

2025-07-13 03:20:47 -04:00
parent 0accb5f2d3
commit cc27f5cc1c
17 changed files with 742 additions and 588 deletions


@ -0,0 +1,432 @@
// FILE: src/internal/netlimit/limiter.go
package netlimit
import (
"context"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"logwisp/src/internal/config"
"github.com/lixenwraith/log"
)
// Limiter manages rate limiting and per-IP connection limits for a transport
type Limiter struct {
config config.NetLimitConfig
logger *log.Logger
// Per-IP limiters
ipLimiters map[string]*ipLimiter
ipMu sync.RWMutex
// Global limiter for the transport
globalLimiter *TokenBucket
// Connection tracking
ipConnections map[string]*atomic.Int32
connMu sync.RWMutex
// Statistics
totalRequests atomic.Uint64
blockedRequests atomic.Uint64
uniqueIPs atomic.Uint64
// Cleanup
lastCleanup time.Time
cleanupMu sync.Mutex
// Lifecycle management
ctx context.Context
cancel context.CancelFunc
cleanupDone chan struct{}
}
type ipLimiter struct {
bucket *TokenBucket
lastSeen time.Time
connections atomic.Int32
}
// New creates a rate limiter from cfg; it returns nil when limiting is disabled, and every method is safe to call on a nil *Limiter
func New(cfg config.NetLimitConfig, logger *log.Logger) *Limiter {
if !cfg.Enabled {
return nil
}
if logger == nil {
panic("netlimit.New: logger cannot be nil")
}
ctx, cancel := context.WithCancel(context.Background())
l := &Limiter{
config: cfg,
ipLimiters: make(map[string]*ipLimiter),
ipConnections: make(map[string]*atomic.Int32),
lastCleanup: time.Now(),
logger: logger,
ctx: ctx,
cancel: cancel,
cleanupDone: make(chan struct{}),
}
// Use a single shared token bucket when limiting globally rather than per IP
if cfg.LimitBy == "global" {
l.globalLimiter = NewTokenBucket(
float64(cfg.BurstSize),
cfg.RequestsPerSecond,
)
}
// Start cleanup goroutine
go l.cleanupLoop()
l.logger.Info("msg", "Net limiter initialized",
"component", "netlimit",
"requests_per_second", cfg.RequestsPerSecond,
"burst_size", cfg.BurstSize,
"limit_by", cfg.LimitBy)
return l
}
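// Shutdown stops the cleanup goroutine, waiting up to two seconds for it to exit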
func (l *Limiter) Shutdown() {
if l == nil {
return
}
l.logger.Info("msg", "Shutting down net limiter", "component", "netlimit")
// Cancel context to stop cleanup goroutine
l.cancel()
// Wait for cleanup goroutine to finish
select {
case <-l.cleanupDone:
l.logger.Debug("msg", "Cleanup goroutine stopped", "component", "netlimit")
case <-time.After(2 * time.Second):
l.logger.Warn("msg", "Cleanup goroutine shutdown timeout", "component", "netlimit")
}
}
// CheckHTTP reports whether an HTTP request from remoteAddr should be allowed, and returns the HTTP status code and message to send when it is blocked
func (l *Limiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int, message string) {
if l == nil {
return true, 0, ""
}
l.totalRequests.Add(1)
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
// If we can't parse the IP, allow the request but log
l.logger.Warn("msg", "Failed to parse remote addr",
"component", "netlimit",
"remote_addr", remoteAddr,
"error", err)
return true, 0, ""
}
// Only IPv4 is supported
if !isIPv4(ip) {
// Block non-IPv4 addresses rather than guessing how to bucket them
l.blockedRequests.Add(1)
l.logger.Warn("msg", "Non-IPv4 address blocked",
"component", "netlimit",
"ip", ip)
return false, 403, "IPv4 only"
}
// Check connection limit for streaming endpoint
if l.config.MaxConnectionsPerIP > 0 {
l.connMu.RLock()
counter, exists := l.ipConnections[ip]
l.connMu.RUnlock()
if exists && counter.Load() >= int32(l.config.MaxConnectionsPerIP) {
l.blockedRequests.Add(1)
statusCode = l.config.ResponseCode
if statusCode == 0 {
statusCode = 429
}
message = "Connection limit exceeded"
l.logger.Warn("msg", "Connection limit exceeded",
"component", "netlimit",
"ip", ip,
"connections", counter.Load(),
"limit", l.config.MaxConnectionsPerIP)
return false, statusCode, message
}
}
// Check net limit
allowed = l.checkLimit(ip)
if !allowed {
l.blockedRequests.Add(1)
statusCode = l.config.ResponseCode
if statusCode == 0 {
statusCode = 429
}
message = l.config.ResponseMessage
if message == "" {
message = "Net limit exceeded"
}
l.logger.Debug("msg", "Request net limited", "ip", ip)
}
return allowed, statusCode, message
}
// CheckTCP reports whether a TCP connection from remoteAddr should be allowed
func (l *Limiter) CheckTCP(remoteAddr net.Addr) bool {
if l == nil {
return true
}
l.totalRequests.Add(1)
// Extract IP from TCP addr
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
if !ok {
return true
}
ip := tcpAddr.IP.String()
// Only supporting ipv4
if !isIPv4(ip) {
l.blockedRequests.Add(1)
l.logger.Warn("msg", "Non-IPv4 TCP connection blocked",
"component", "netlimit",
"ip", ip)
return false
}
allowed := l.checkLimit(ip)
if !allowed {
l.blockedRequests.Add(1)
l.logger.Debug("msg", "TCP connection net limited", "ip", ip)
}
return allowed
}
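// isIPv4 reports whether ip is a dotted-quad IPv4 literal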
func isIPv4(ip string) bool {
// Simple check: IPv4 literals contain dots and no colons; this also rejects IPv4-mapped IPv6 forms such as "::ffff:192.0.2.1"
return strings.Contains(ip, ".") && !strings.Contains(ip, ":")
}
// AddConnection records a new connection for the IP in remoteAddr
func (l *Limiter) AddConnection(remoteAddr string) {
if l == nil {
return
}
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return
}
// Only supporting ipv4
if !isIPv4(ip) {
return
}
l.connMu.Lock()
counter, exists := l.ipConnections[ip]
if !exists {
counter = &atomic.Int32{}
l.ipConnections[ip] = counter
}
// Increment while still holding the lock so a concurrent RemoveConnection
// cannot delete the map entry between the lookup and the increment
newCount := counter.Add(1)
l.connMu.Unlock()
l.logger.Debug("msg", "Connection added",
"ip", ip,
"connections", newCount)
}
// RemoveConnection releases a previously recorded connection for the IP in remoteAddr
func (l *Limiter) RemoveConnection(remoteAddr string) {
if l == nil {
return
}
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return
}
// Only supporting ipv4
if !isIPv4(ip) {
return
}
l.connMu.RLock()
counter, exists := l.ipConnections[ip]
l.connMu.RUnlock()
if exists {
newCount := counter.Add(-1)
l.logger.Debug("msg", "Connection removed",
"ip", ip,
"connections", newCount)
if newCount <= 0 {
// Clean up if no more connections
l.connMu.Lock()
if counter.Load() <= 0 {
delete(l.ipConnections, ip)
}
l.connMu.Unlock()
}
}
}
// GetStats returns a snapshot of limiter statistics
func (l *Limiter) GetStats() map[string]any {
if l == nil {
return map[string]any{
"enabled": false,
}
}
l.ipMu.RLock()
activeIPs := len(l.ipLimiters)
l.ipMu.RUnlock()
l.connMu.RLock()
totalConnections := 0
for _, counter := range l.ipConnections {
totalConnections += int(counter.Load())
}
l.connMu.RUnlock()
return map[string]any{
"enabled": true,
"total_requests": l.totalRequests.Load(),
"blocked_requests": l.blockedRequests.Load(),
"active_ips": activeIPs,
"total_connections": totalConnections,
"config": map[string]any{
"requests_per_second": l.config.RequestsPerSecond,
"burst_size": l.config.BurstSize,
"limit_by": l.config.LimitBy,
},
}
}
// checkLimit applies the configured rate limit to a single IP
func (l *Limiter) checkLimit(ip string) bool {
// Maybe run cleanup
l.maybeCleanup()
switch l.config.LimitBy {
case "global":
return l.globalLimiter.Allow()
case "ip", "":
// Default to per-IP limiting
l.ipMu.Lock()
limiter, exists := l.ipLimiters[ip]
if !exists {
// Create new limiter for this IP
limiter = &ipLimiter{
bucket: NewTokenBucket(
float64(l.config.BurstSize),
l.config.RequestsPerSecond,
),
lastSeen: time.Now(),
}
l.ipLimiters[ip] = limiter
l.uniqueIPs.Add(1)
l.logger.Debug("msg", "Created new IP limiter",
"ip", ip,
"total_ips", l.uniqueIPs.Load())
} else {
limiter.lastSeen = time.Now()
}
l.ipMu.Unlock()
// Check connection limit if configured
if l.config.MaxConnectionsPerIP > 0 {
l.connMu.RLock()
counter, exists := l.ipConnections[ip]
l.connMu.RUnlock()
if exists && counter.Load() >= int32(l.config.MaxConnectionsPerIP) {
return false
}
}
return limiter.bucket.Allow()
default:
// Unknown limit_by value, allow by default
l.logger.Warn("msg", "Unknown limit_by value",
"limit_by", l.config.LimitBy)
return true
}
}
// maybeCleanup opportunistically triggers a cleanup pass from the request path when at least 30 seconds have passed since the last one
func (l *Limiter) maybeCleanup() {
l.cleanupMu.Lock()
defer l.cleanupMu.Unlock()
if time.Since(l.lastCleanup) < 30*time.Second {
return
}
l.lastCleanup = time.Now()
go l.cleanup()
}
// cleanup removes per-IP limiters that have been idle for more than five minutes
func (l *Limiter) cleanup() {
staleTimeout := 5 * time.Minute
now := time.Now()
l.ipMu.Lock()
defer l.ipMu.Unlock()
cleaned := 0
for ip, limiter := range l.ipLimiters {
if now.Sub(limiter.lastSeen) > staleTimeout {
delete(l.ipLimiters, ip)
cleaned++
}
}
if cleaned > 0 {
l.logger.Debug("msg", "Cleaned up stale IP limiters",
"cleaned", cleaned,
"remaining", len(l.ipLimiters))
}
}
// cleanupLoop runs cleanup once per minute until Shutdown cancels the limiter's context
func (l *Limiter) cleanupLoop() {
defer close(l.cleanupDone)
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-l.ctx.Done():
// Exit when context is cancelled
l.logger.Debug("msg", "Cleanup loop stopping", "component", "netlimit")
return
case <-ticker.C:
l.cleanup()
}
}
}

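Below is a minimal sketch (not part of this commit) of how an HTTP transport could wire the limiter into a streaming endpoint. The netlimit and config identifiers come from the files in this diff; the transport package name, newStreamMux, the /stream route, the field values, and the pre-built logger are illustrative assumptions.

// Sketch only: a hypothetical HTTP transport guarding a streaming endpoint
package transport

import (
	"net/http"

	"logwisp/src/internal/config"
	"logwisp/src/internal/netlimit"

	"github.com/lixenwraith/log"
)

// newStreamMux is illustrative; logger is assumed to be an already-configured *log.Logger
func newStreamMux(logger *log.Logger) (*http.ServeMux, *netlimit.Limiter) {
	limiter := netlimit.New(config.NetLimitConfig{
		Enabled:             true,
		RequestsPerSecond:   10, // sustained rate per IP
		BurstSize:           20, // extra headroom for short bursts
		LimitBy:             "ip",
		MaxConnectionsPerIP: 4,
	}, logger)

	mux := http.NewServeMux()
	mux.HandleFunc("/stream", func(w http.ResponseWriter, r *http.Request) {
		if allowed, code, msg := limiter.CheckHTTP(r.RemoteAddr); !allowed {
			http.Error(w, msg, code)
			return
		}
		// Track the long-lived connection so MaxConnectionsPerIP is enforced
		limiter.AddConnection(r.RemoteAddr)
		defer limiter.RemoveConnection(r.RemoteAddr)
		// ... stream log entries to w until the client disconnects ...
	})
	return mux, limiter
}

A TCP transport would call CheckTCP(conn.RemoteAddr()) after accepting each connection instead of CheckHTTP.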

@ -0,0 +1,62 @@
// FILE: src/internal/netlimit/netlimiter.go
package netlimit
import (
"sync"
"time"
)
// TokenBucket implements a thread-safe token bucket rate limiter
type TokenBucket struct {
capacity float64
tokens float64
refillRate float64
lastRefill time.Time
mu sync.Mutex
}
// NewTokenBucket creates a full bucket with the given capacity and refill rate in tokens per second
func NewTokenBucket(capacity float64, refillRate float64) *TokenBucket {
return &TokenBucket{
capacity: capacity,
tokens: capacity,
refillRate: refillRate,
lastRefill: time.Now(),
}
}
// Allow attempts to consume one token and reports whether it succeeded
func (tb *TokenBucket) Allow() bool {
return tb.AllowN(1)
}
// AllowN attempts to consume n tokens and reports whether it succeeded
func (tb *TokenBucket) AllowN(n float64) bool {
tb.mu.Lock()
defer tb.mu.Unlock()
// Refill tokens based on time elapsed
now := time.Now()
elapsed := now.Sub(tb.lastRefill).Seconds()
// Handle time sync issues causing negative elapsed time
if elapsed < 0 {
// Clock went backwards, reset to current time but don't add tokens
tb.lastRefill = now
// Don't log here as this is a hot path
elapsed = 0
}
tb.tokens += elapsed * tb.refillRate
if tb.tokens > tb.capacity {
tb.tokens = tb.capacity
}
tb.lastRefill = now
// Check if we have enough tokens
if tb.tokens >= n {
tb.tokens -= n
return true
}
return false
}
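
To make the refill arithmetic concrete, here is a standalone sketch (not part of this commit): with capacity 5 and refillRate 2, a full bucket absorbs a burst of 5 requests, rejects the 6th, and after roughly 600 ms elapsed * refillRate has restored more than one token, so the next request succeeds again. Exact output depends on timing.

// Sketch: burst-then-refill behavior of TokenBucket
package main

import (
	"fmt"
	"time"

	"logwisp/src/internal/netlimit"
)

func main() {
	tb := netlimit.NewTokenBucket(5, 2) // capacity 5, 2 tokens/second

	for i := 1; i <= 6; i++ {
		fmt.Printf("request %d allowed: %v\n", i, tb.Allow()) // 1-5 true, 6 false
	}

	time.Sleep(600 * time.Millisecond) // ~1.2 tokens refilled
	fmt.Println("after refill:", tb.Allow()) // true
}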