v0.1.8 rate limiter added, improved http router, config templates added, docs updated
This commit is contained in:
@ -53,10 +53,14 @@ type RateLimitConfig struct {
|
||||
// Burst size (token bucket)
|
||||
BurstSize int `toml:"burst_size"`
|
||||
|
||||
// Rate limit by: "ip", "user", "token"
|
||||
// Rate limit by: "ip", "user", "token", "global"
|
||||
LimitBy string `toml:"limit_by"`
|
||||
|
||||
// Response when rate limited
|
||||
ResponseCode int `toml:"response_code"` // Default: 429
|
||||
ResponseMessage string `toml:"response_message"` // Default: "Rate limit exceeded"
|
||||
|
||||
// Connection limits
|
||||
MaxConnectionsPerIP int `toml:"max_connections_per_ip"`
|
||||
MaxTotalConnections int `toml:"max_total_connections"`
|
||||
}
|
||||
@ -68,6 +68,10 @@ func (c *Config) validate() error {
|
||||
if err := validateSSL("TCP", stream.Name, stream.TCPServer.SSL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateRateLimit("TCP", stream.Name, stream.TCPServer.RateLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate HTTP server
|
||||
@ -109,6 +113,10 @@ func (c *Config) validate() error {
|
||||
if err := validateSSL("HTTP", stream.Name, stream.HTTPServer.SSL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateRateLimit("HTTP", stream.Name, stream.HTTPServer.RateLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// At least one server must be enabled
|
||||
@ -185,5 +193,34 @@ func validateAuth(streamName string, auth *AuthConfig) error {
|
||||
return fmt.Errorf("stream '%s': bearer auth type specified but config missing", streamName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateRateLimit(serverType, streamName string, rl *RateLimitConfig) error {
|
||||
if rl == nil || !rl.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rl.RequestsPerSecond <= 0 {
|
||||
return fmt.Errorf("stream '%s' %s: requests_per_second must be positive: %f",
|
||||
streamName, serverType, rl.RequestsPerSecond)
|
||||
}
|
||||
|
||||
if rl.BurstSize < 1 {
|
||||
return fmt.Errorf("stream '%s' %s: burst_size must be at least 1: %d",
|
||||
streamName, serverType, rl.BurstSize)
|
||||
}
|
||||
|
||||
validLimitBy := map[string]bool{"ip": true, "global": true, "": true}
|
||||
if !validLimitBy[rl.LimitBy] {
|
||||
return fmt.Errorf("stream '%s' %s: invalid limit_by value: %s (must be 'ip' or 'global')",
|
||||
streamName, serverType, rl.LimitBy)
|
||||
}
|
||||
|
||||
if rl.ResponseCode > 0 && (rl.ResponseCode < 400 || rl.ResponseCode >= 600) {
|
||||
return fmt.Errorf("stream '%s' %s: response_code must be 4xx or 5xx: %d",
|
||||
streamName, serverType, rl.ResponseCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@ -13,12 +15,19 @@ type HTTPRouter struct {
|
||||
service *Service
|
||||
servers map[int]*routerServer // port -> server
|
||||
mu sync.RWMutex
|
||||
|
||||
// Statistics
|
||||
startTime time.Time
|
||||
totalRequests atomic.Uint64
|
||||
routedRequests atomic.Uint64
|
||||
failedRequests atomic.Uint64
|
||||
}
|
||||
|
||||
func NewHTTPRouter(service *Service) *HTTPRouter {
|
||||
return &HTTPRouter{
|
||||
service: service,
|
||||
servers: make(map[int]*routerServer),
|
||||
service: service,
|
||||
servers: make(map[int]*routerServer),
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -34,24 +43,31 @@ func (r *HTTPRouter) RegisterStream(stream *LogStream) error {
|
||||
if !exists {
|
||||
// Create new server for this port
|
||||
rs = &routerServer{
|
||||
port: port,
|
||||
routes: make(map[string]*LogStream),
|
||||
port: port,
|
||||
routes: make(map[string]*LogStream),
|
||||
router: r,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
rs.server = &fasthttp.Server{
|
||||
Handler: rs.requestHandler,
|
||||
DisableKeepalive: false,
|
||||
StreamRequestBody: true,
|
||||
CloseOnShutdown: true, // Ensure connections close on shutdown
|
||||
}
|
||||
r.servers[port] = rs
|
||||
|
||||
// Start server in background
|
||||
go func() {
|
||||
addr := fmt.Sprintf(":%d", port)
|
||||
fmt.Printf("[ROUTER] Starting server on port %d\n", port)
|
||||
if err := rs.server.ListenAndServe(addr); err != nil {
|
||||
// Log error but don't crash
|
||||
fmt.Printf("Router server on port %d failed: %v\n", port, err)
|
||||
fmt.Printf("[ROUTER] Server on port %d failed: %v\n", port, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait briefly to ensure server starts
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
@ -71,6 +87,7 @@ func (r *HTTPRouter) RegisterStream(stream *LogStream) error {
|
||||
}
|
||||
|
||||
rs.routes[pathPrefix] = stream
|
||||
fmt.Printf("[ROUTER] Registered stream '%s' at path '%s' on port %d\n", stream.Name, pathPrefix, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -78,18 +95,27 @@ func (r *HTTPRouter) UnregisterStream(streamName string) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, rs := range r.servers {
|
||||
for port, rs := range r.servers {
|
||||
rs.routeMu.Lock()
|
||||
for path, stream := range rs.routes {
|
||||
if stream.Name == streamName {
|
||||
delete(rs.routes, path)
|
||||
fmt.Printf("[ROUTER] Unregistered stream '%s' from path '%s' on port %d\n",
|
||||
streamName, path, port)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if server has no more routes
|
||||
if len(rs.routes) == 0 {
|
||||
fmt.Printf("[ROUTER] No routes left on port %d, considering shutdown\n", port)
|
||||
}
|
||||
rs.routeMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *HTTPRouter) Shutdown() {
|
||||
fmt.Println("[ROUTER] Starting router shutdown...")
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
@ -98,10 +124,46 @@ func (r *HTTPRouter) Shutdown() {
|
||||
wg.Add(1)
|
||||
go func(p int, s *routerServer) {
|
||||
defer wg.Done()
|
||||
fmt.Printf("[ROUTER] Shutting down server on port %d\n", p)
|
||||
if err := s.server.Shutdown(); err != nil {
|
||||
fmt.Printf("Error shutting down router server on port %d: %v\n", p, err)
|
||||
fmt.Printf("[ROUTER] Error shutting down server on port %d: %v\n", p, err)
|
||||
}
|
||||
}(port, rs)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
fmt.Println("[ROUTER] Router shutdown complete")
|
||||
}
|
||||
|
||||
func (r *HTTPRouter) GetStats() map[string]interface{} {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
serverStats := make(map[int]interface{})
|
||||
totalRoutes := 0
|
||||
|
||||
for port, rs := range r.servers {
|
||||
rs.routeMu.RLock()
|
||||
routes := make([]string, 0, len(rs.routes))
|
||||
for path := range rs.routes {
|
||||
routes = append(routes, path)
|
||||
totalRoutes++
|
||||
}
|
||||
rs.routeMu.RUnlock()
|
||||
|
||||
serverStats[port] = map[string]interface{}{
|
||||
"routes": routes,
|
||||
"requests": rs.requests.Load(),
|
||||
"uptime": int(time.Since(rs.startTime).Seconds()),
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"uptime_seconds": int(time.Since(r.startTime).Seconds()),
|
||||
"total_requests": r.totalRequests.Load(),
|
||||
"routed_requests": r.routedRequests.Load(),
|
||||
"failed_requests": r.failedRequests.Load(),
|
||||
"servers": serverStats,
|
||||
"total_routes": totalRoutes,
|
||||
}
|
||||
}
|
||||
@ -6,21 +6,32 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"logwisp/src/internal/version"
|
||||
)
|
||||
|
||||
type routerServer struct {
|
||||
port int
|
||||
server *fasthttp.Server
|
||||
routes map[string]*LogStream // path prefix -> stream
|
||||
routeMu sync.RWMutex
|
||||
port int
|
||||
server *fasthttp.Server
|
||||
routes map[string]*LogStream // path prefix -> stream
|
||||
routeMu sync.RWMutex
|
||||
router *HTTPRouter
|
||||
startTime time.Time
|
||||
requests atomic.Uint64
|
||||
}
|
||||
|
||||
func (rs *routerServer) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
rs.requests.Add(1)
|
||||
rs.router.totalRequests.Add(1)
|
||||
|
||||
path := string(ctx.Path())
|
||||
|
||||
// Log request for debugging
|
||||
fmt.Printf("[ROUTER] Request: %s %s from %s\n", ctx.Method(), path, ctx.RemoteAddr())
|
||||
|
||||
// Special case: global status at /status
|
||||
if path == "/status" {
|
||||
rs.handleGlobalStatus(ctx)
|
||||
@ -40,26 +51,48 @@ func (rs *routerServer) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
matchedPrefix = prefix
|
||||
matchedStream = stream
|
||||
remainingPath = strings.TrimPrefix(path, prefix)
|
||||
// Ensure remaining path starts with / or is empty
|
||||
if remainingPath != "" && !strings.HasPrefix(remainingPath, "/") {
|
||||
remainingPath = "/" + remainingPath
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rs.routeMu.RUnlock()
|
||||
|
||||
if matchedStream == nil {
|
||||
rs.router.failedRequests.Add(1)
|
||||
rs.handleNotFound(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
rs.router.routedRequests.Add(1)
|
||||
|
||||
// Route to stream's handler
|
||||
if matchedStream.HTTPServer != nil {
|
||||
// Save original path
|
||||
originalPath := string(ctx.URI().Path())
|
||||
|
||||
// Rewrite path to remove stream prefix
|
||||
if remainingPath == "" {
|
||||
// Default to stream path if no remaining path
|
||||
remainingPath = matchedStream.Config.HTTPServer.StreamPath
|
||||
}
|
||||
|
||||
fmt.Printf("[ROUTER] Routing to stream '%s': %s -> %s\n",
|
||||
matchedStream.Name, originalPath, remainingPath)
|
||||
|
||||
ctx.URI().SetPath(remainingPath)
|
||||
matchedStream.HTTPServer.RouteRequest(ctx)
|
||||
|
||||
// Restore original path
|
||||
ctx.URI().SetPath(originalPath)
|
||||
} else {
|
||||
ctx.SetStatusCode(fasthttp.StatusServiceUnavailable)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Stream HTTP server not available",
|
||||
"error": "Stream HTTP server not available",
|
||||
"stream": matchedStream.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -70,26 +103,37 @@ func (rs *routerServer) handleGlobalStatus(ctx *fasthttp.RequestCtx) {
|
||||
rs.routeMu.RLock()
|
||||
streams := make(map[string]interface{})
|
||||
for prefix, stream := range rs.routes {
|
||||
streams[stream.Name] = map[string]interface{}{
|
||||
streamStats := stream.GetStats()
|
||||
|
||||
// Add routing information
|
||||
streamStats["routing"] = map[string]interface{}{
|
||||
"path_prefix": prefix,
|
||||
"config": map[string]interface{}{
|
||||
"stream_path": stream.Config.HTTPServer.StreamPath,
|
||||
"status_path": stream.Config.HTTPServer.StatusPath,
|
||||
"endpoints": map[string]string{
|
||||
"stream": prefix + stream.Config.HTTPServer.StreamPath,
|
||||
"status": prefix + stream.Config.HTTPServer.StatusPath,
|
||||
},
|
||||
"stats": stream.GetStats(),
|
||||
}
|
||||
|
||||
streams[stream.Name] = streamStats
|
||||
}
|
||||
rs.routeMu.RUnlock()
|
||||
|
||||
// Get router stats
|
||||
routerStats := rs.router.GetStats()
|
||||
|
||||
status := map[string]interface{}{
|
||||
"service": "LogWisp Router",
|
||||
"version": version.Short(),
|
||||
"version": version.String(),
|
||||
"port": rs.port,
|
||||
"streams": streams,
|
||||
"total_streams": len(streams),
|
||||
"router": routerStats,
|
||||
"endpoints": map[string]string{
|
||||
"global_status": "/status",
|
||||
},
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(status)
|
||||
data, _ := json.MarshalIndent(status, "", " ")
|
||||
ctx.SetBody(data)
|
||||
}
|
||||
|
||||
@ -113,9 +157,11 @@ func (rs *routerServer) handleNotFound(ctx *fasthttp.RequestCtx) {
|
||||
|
||||
response := map[string]interface{}{
|
||||
"error": "Not Found",
|
||||
"requested_path": string(ctx.Path()),
|
||||
"available_routes": availableRoutes,
|
||||
"hint": "Use /status for global router status",
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(response)
|
||||
data, _ := json.MarshalIndent(response, "", " ")
|
||||
ctx.SetBody(data)
|
||||
}
|
||||
311
src/internal/ratelimit/limiter.go
Normal file
311
src/internal/ratelimit/limiter.go
Normal file
@ -0,0 +1,311 @@
|
||||
// FILE: src/internal/ratelimit/limiter.go
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
)
|
||||
|
||||
// Manages rate limiting for a stream
|
||||
type Limiter struct {
|
||||
config config.RateLimitConfig
|
||||
|
||||
// Per-IP limiters
|
||||
ipLimiters map[string]*ipLimiter
|
||||
ipMu sync.RWMutex
|
||||
|
||||
// Global limiter for the stream
|
||||
globalLimiter *TokenBucket
|
||||
|
||||
// Connection tracking
|
||||
ipConnections map[string]*atomic.Int32
|
||||
connMu sync.RWMutex
|
||||
|
||||
// Statistics
|
||||
totalRequests atomic.Uint64
|
||||
blockedRequests atomic.Uint64
|
||||
uniqueIPs atomic.Uint64
|
||||
|
||||
// Cleanup
|
||||
lastCleanup time.Time
|
||||
cleanupMu sync.Mutex
|
||||
}
|
||||
|
||||
type ipLimiter struct {
|
||||
bucket *TokenBucket
|
||||
lastSeen time.Time
|
||||
connections atomic.Int32
|
||||
}
|
||||
|
||||
// Creates a new rate limiter
|
||||
func New(cfg config.RateLimitConfig) *Limiter {
|
||||
if !cfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
l := &Limiter{
|
||||
config: cfg,
|
||||
ipLimiters: make(map[string]*ipLimiter),
|
||||
ipConnections: make(map[string]*atomic.Int32),
|
||||
lastCleanup: time.Now(),
|
||||
}
|
||||
|
||||
// Create global limiter if not using per-IP limiting
|
||||
if cfg.LimitBy == "global" {
|
||||
l.globalLimiter = NewTokenBucket(
|
||||
float64(cfg.BurstSize),
|
||||
cfg.RequestsPerSecond,
|
||||
)
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go l.cleanupLoop()
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// Checks if an HTTP request should be allowed
|
||||
func (l *Limiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int, message string) {
|
||||
if l == nil {
|
||||
return true, 0, ""
|
||||
}
|
||||
|
||||
l.totalRequests.Add(1)
|
||||
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
// If we can't parse the IP, allow the request but log
|
||||
fmt.Printf("[RATELIMIT] Failed to parse remote addr %s: %v\n", remoteAddr, err)
|
||||
return true, 0, ""
|
||||
}
|
||||
|
||||
// Check connection limit for streaming endpoint
|
||||
if l.config.MaxConnectionsPerIP > 0 {
|
||||
l.connMu.RLock()
|
||||
counter, exists := l.ipConnections[ip]
|
||||
l.connMu.RUnlock()
|
||||
|
||||
if exists && counter.Load() >= int32(l.config.MaxConnectionsPerIP) {
|
||||
l.blockedRequests.Add(1)
|
||||
statusCode = l.config.ResponseCode
|
||||
if statusCode == 0 {
|
||||
statusCode = 429
|
||||
}
|
||||
message = "Connection limit exceeded"
|
||||
return false, statusCode, message
|
||||
}
|
||||
}
|
||||
|
||||
// Check rate limit
|
||||
allowed = l.checkLimit(ip)
|
||||
if !allowed {
|
||||
l.blockedRequests.Add(1)
|
||||
statusCode = l.config.ResponseCode
|
||||
if statusCode == 0 {
|
||||
statusCode = 429
|
||||
}
|
||||
message = l.config.ResponseMessage
|
||||
if message == "" {
|
||||
message = "Rate limit exceeded"
|
||||
}
|
||||
}
|
||||
|
||||
return allowed, statusCode, message
|
||||
}
|
||||
|
||||
// Checks if a TCP connection should be allowed
|
||||
func (l *Limiter) CheckTCP(remoteAddr net.Addr) bool {
|
||||
if l == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
l.totalRequests.Add(1)
|
||||
|
||||
// Extract IP from TCP addr
|
||||
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := tcpAddr.IP.String()
|
||||
allowed := l.checkLimit(ip)
|
||||
if !allowed {
|
||||
l.blockedRequests.Add(1)
|
||||
}
|
||||
|
||||
return allowed
|
||||
}
|
||||
|
||||
// Tracks a new connection for an IP
|
||||
func (l *Limiter) AddConnection(remoteAddr string) {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.connMu.Lock()
|
||||
counter, exists := l.ipConnections[ip]
|
||||
if !exists {
|
||||
counter = &atomic.Int32{}
|
||||
l.ipConnections[ip] = counter
|
||||
}
|
||||
l.connMu.Unlock()
|
||||
|
||||
counter.Add(1)
|
||||
}
|
||||
|
||||
// Removes a connection for an IP
|
||||
func (l *Limiter) RemoveConnection(remoteAddr string) {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.connMu.RLock()
|
||||
counter, exists := l.ipConnections[ip]
|
||||
l.connMu.RUnlock()
|
||||
|
||||
if exists {
|
||||
newCount := counter.Add(-1)
|
||||
if newCount <= 0 {
|
||||
// Clean up if no more connections
|
||||
l.connMu.Lock()
|
||||
if counter.Load() <= 0 {
|
||||
delete(l.ipConnections, ip)
|
||||
}
|
||||
l.connMu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns rate limiter statistics
|
||||
func (l *Limiter) GetStats() map[string]interface{} {
|
||||
if l == nil {
|
||||
return map[string]interface{}{
|
||||
"enabled": false,
|
||||
}
|
||||
}
|
||||
|
||||
l.ipMu.RLock()
|
||||
activeIPs := len(l.ipLimiters)
|
||||
l.ipMu.RUnlock()
|
||||
|
||||
l.connMu.RLock()
|
||||
totalConnections := 0
|
||||
for _, counter := range l.ipConnections {
|
||||
totalConnections += int(counter.Load())
|
||||
}
|
||||
l.connMu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"enabled": true,
|
||||
"total_requests": l.totalRequests.Load(),
|
||||
"blocked_requests": l.blockedRequests.Load(),
|
||||
"active_ips": activeIPs,
|
||||
"total_connections": totalConnections,
|
||||
"config": map[string]interface{}{
|
||||
"requests_per_second": l.config.RequestsPerSecond,
|
||||
"burst_size": l.config.BurstSize,
|
||||
"limit_by": l.config.LimitBy,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Performs the actual rate limit check
|
||||
func (l *Limiter) checkLimit(ip string) bool {
|
||||
// Maybe run cleanup
|
||||
l.maybeCleanup()
|
||||
|
||||
switch l.config.LimitBy {
|
||||
case "global":
|
||||
return l.globalLimiter.Allow()
|
||||
|
||||
case "ip", "":
|
||||
// Default to per-IP limiting
|
||||
l.ipMu.Lock()
|
||||
limiter, exists := l.ipLimiters[ip]
|
||||
if !exists {
|
||||
// Create new limiter for this IP
|
||||
limiter = &ipLimiter{
|
||||
bucket: NewTokenBucket(
|
||||
float64(l.config.BurstSize),
|
||||
l.config.RequestsPerSecond,
|
||||
),
|
||||
lastSeen: time.Now(),
|
||||
}
|
||||
l.ipLimiters[ip] = limiter
|
||||
l.uniqueIPs.Add(1)
|
||||
} else {
|
||||
limiter.lastSeen = time.Now()
|
||||
}
|
||||
l.ipMu.Unlock()
|
||||
|
||||
// Check connection limit if configured
|
||||
if l.config.MaxConnectionsPerIP > 0 {
|
||||
l.connMu.RLock()
|
||||
counter, exists := l.ipConnections[ip]
|
||||
l.connMu.RUnlock()
|
||||
|
||||
if exists && counter.Load() >= int32(l.config.MaxConnectionsPerIP) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return limiter.bucket.Allow()
|
||||
|
||||
default:
|
||||
// Unknown limit_by value, allow by default
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Runs cleanup if enough time has passed
|
||||
func (l *Limiter) maybeCleanup() {
|
||||
l.cleanupMu.Lock()
|
||||
defer l.cleanupMu.Unlock()
|
||||
|
||||
if time.Since(l.lastCleanup) < 30*time.Second {
|
||||
return
|
||||
}
|
||||
|
||||
l.lastCleanup = time.Now()
|
||||
go l.cleanup()
|
||||
}
|
||||
|
||||
// Removes stale IP limiters
|
||||
func (l *Limiter) cleanup() {
|
||||
staleTimeout := 5 * time.Minute
|
||||
now := time.Now()
|
||||
|
||||
l.ipMu.Lock()
|
||||
defer l.ipMu.Unlock()
|
||||
|
||||
for ip, limiter := range l.ipLimiters {
|
||||
if now.Sub(limiter.lastSeen) > staleTimeout {
|
||||
delete(l.ipLimiters, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Runs periodic cleanup
|
||||
func (l *Limiter) cleanupLoop() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
l.cleanup()
|
||||
}
|
||||
}
|
||||
53
src/internal/ratelimit/ratelimit.go
Normal file
53
src/internal/ratelimit/ratelimit.go
Normal file
@ -0,0 +1,53 @@
|
||||
// FILE: src/internal/ratelimit/ratelimit.go
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenBucket implements a token bucket rate limiter
|
||||
type TokenBucket struct {
|
||||
capacity float64
|
||||
tokens float64
|
||||
refillRate float64
|
||||
lastRefill time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTokenBucket creates a new token bucket with given capacity and refill rate
|
||||
func NewTokenBucket(capacity float64, refillRate float64) *TokenBucket {
|
||||
return &TokenBucket{
|
||||
capacity: capacity,
|
||||
tokens: capacity,
|
||||
refillRate: refillRate,
|
||||
lastRefill: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow attempts to consume one token, returns true if allowed
|
||||
func (tb *TokenBucket) Allow() bool {
|
||||
return tb.AllowN(1)
|
||||
}
|
||||
|
||||
// AllowN attempts to consume n tokens, returns true if allowed
|
||||
func (tb *TokenBucket) AllowN(n float64) bool {
|
||||
tb.mu.Lock()
|
||||
defer tb.mu.Unlock()
|
||||
|
||||
// Refill tokens based on time elapsed
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(tb.lastRefill).Seconds()
|
||||
tb.tokens += elapsed * tb.refillRate
|
||||
if tb.tokens > tb.capacity {
|
||||
tb.tokens = tb.capacity
|
||||
}
|
||||
tb.lastRefill = now
|
||||
|
||||
// Check if we have enough tokens
|
||||
if tb.tokens >= n {
|
||||
tb.tokens -= n
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@ -14,6 +14,7 @@ import (
|
||||
"github.com/valyala/fasthttp"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/monitor"
|
||||
"logwisp/src/internal/ratelimit"
|
||||
"logwisp/src/internal/version"
|
||||
)
|
||||
|
||||
@ -33,6 +34,9 @@ type HTTPStreamer struct {
|
||||
|
||||
// For router integration
|
||||
standalone bool
|
||||
|
||||
// Rate limiting
|
||||
rateLimiter *ratelimit.Limiter
|
||||
}
|
||||
|
||||
func NewHTTPStreamer(logChan chan monitor.LogEntry, cfg config.HTTPConfig) *HTTPStreamer {
|
||||
@ -46,7 +50,7 @@ func NewHTTPStreamer(logChan chan monitor.LogEntry, cfg config.HTTPConfig) *HTTP
|
||||
statusPath = "/status"
|
||||
}
|
||||
|
||||
return &HTTPStreamer{
|
||||
h := &HTTPStreamer{
|
||||
logChan: logChan,
|
||||
config: cfg,
|
||||
startTime: time.Now(),
|
||||
@ -55,9 +59,16 @@ func NewHTTPStreamer(logChan chan monitor.LogEntry, cfg config.HTTPConfig) *HTTP
|
||||
statusPath: statusPath,
|
||||
standalone: true, // Default to standalone mode
|
||||
}
|
||||
|
||||
// Initialize rate limiter if configured
|
||||
if cfg.RateLimit != nil && cfg.RateLimit.Enabled {
|
||||
h.rateLimiter = ratelimit.New(*cfg.RateLimit)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// SetRouterMode configures the streamer for use with a router
|
||||
// Configures the streamer for use with a router
|
||||
func (h *HTTPStreamer) SetRouterMode() {
|
||||
h.standalone = false
|
||||
}
|
||||
@ -116,6 +127,18 @@ func (h *HTTPStreamer) RouteRequest(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
|
||||
func (h *HTTPStreamer) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
// Check rate limit first
|
||||
remoteAddr := ctx.RemoteAddr().String()
|
||||
if allowed, statusCode, message := h.rateLimiter.CheckHTTP(remoteAddr); !allowed {
|
||||
ctx.SetStatusCode(statusCode)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]interface{}{
|
||||
"error": message,
|
||||
"retry_after": "60", // seconds
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
path := string(ctx.Path())
|
||||
|
||||
switch path {
|
||||
@ -135,6 +158,13 @@ func (h *HTTPStreamer) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
|
||||
func (h *HTTPStreamer) handleStream(ctx *fasthttp.RequestCtx) {
|
||||
// Track connection for rate limiting
|
||||
remoteAddr := ctx.RemoteAddr().String()
|
||||
if h.rateLimiter != nil {
|
||||
h.rateLimiter.AddConnection(remoteAddr)
|
||||
defer h.rateLimiter.RemoveConnection(remoteAddr)
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
ctx.Response.Header.Set("Content-Type", "text/event-stream")
|
||||
ctx.Response.Header.Set("Cache-Control", "no-cache")
|
||||
@ -285,6 +315,15 @@ func (h *HTTPStreamer) formatHeartbeat() string {
|
||||
func (h *HTTPStreamer) handleStatus(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
var rateLimitStats interface{}
|
||||
if h.rateLimiter != nil {
|
||||
rateLimitStats = h.rateLimiter.GetStats()
|
||||
} else {
|
||||
rateLimitStats = map[string]interface{}{
|
||||
"enabled": false,
|
||||
}
|
||||
}
|
||||
|
||||
status := map[string]interface{}{
|
||||
"service": "LogWisp",
|
||||
"version": version.Short(),
|
||||
@ -309,9 +348,7 @@ func (h *HTTPStreamer) handleStatus(ctx *fasthttp.RequestCtx) {
|
||||
"ssl": map[string]bool{
|
||||
"enabled": h.config.SSL != nil && h.config.SSL.Enabled,
|
||||
},
|
||||
"rate_limit": map[string]bool{
|
||||
"enabled": h.config.RateLimit != nil && h.config.RateLimit.Enabled,
|
||||
},
|
||||
"rate_limit": rateLimitStats,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ package stream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/panjf2000/gnet/v2"
|
||||
@ -17,10 +18,34 @@ type tcpServer struct {
|
||||
func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action {
|
||||
// Store engine reference for shutdown
|
||||
s.streamer.engine = &eng
|
||||
fmt.Printf("[TCP DEBUG] Server booted on port %d\n", s.streamer.config.Port)
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
// Debug: Log all connection attempts
|
||||
fmt.Printf("[TCP DEBUG] Connection attempt from %s\n", c.RemoteAddr())
|
||||
|
||||
// Check rate limit
|
||||
if s.streamer.rateLimiter != nil {
|
||||
// Parse the remote address to get proper net.Addr
|
||||
remoteStr := c.RemoteAddr().String()
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", remoteStr)
|
||||
if err != nil {
|
||||
fmt.Printf("[TCP DEBUG] Failed to parse address %s: %v\n", remoteStr, err)
|
||||
return nil, gnet.Close
|
||||
}
|
||||
|
||||
if !s.streamer.rateLimiter.CheckTCP(tcpAddr) {
|
||||
fmt.Printf("[TCP DEBUG] Rate limited connection from %s\n", remoteStr)
|
||||
// Silently close connection when rate limited
|
||||
return nil, gnet.Close
|
||||
}
|
||||
|
||||
// Track connection
|
||||
s.streamer.rateLimiter.AddConnection(remoteStr)
|
||||
}
|
||||
|
||||
s.connections.Store(c, struct{}{})
|
||||
|
||||
oldCount := s.streamer.activeConns.Load()
|
||||
@ -34,6 +59,11 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
s.connections.Delete(c)
|
||||
|
||||
// Remove connection tracking
|
||||
if s.streamer.rateLimiter != nil {
|
||||
s.streamer.rateLimiter.RemoveConnection(c.RemoteAddr().String())
|
||||
}
|
||||
|
||||
oldCount := s.streamer.activeConns.Load()
|
||||
newCount := s.streamer.activeConns.Add(-1)
|
||||
fmt.Printf("[TCP ATOMIC] OnClose: %d -> %d (expected: %d)\n", oldCount, newCount, oldCount-1)
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"github.com/panjf2000/gnet/v2"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/monitor"
|
||||
"logwisp/src/internal/ratelimit"
|
||||
)
|
||||
|
||||
type TCPStreamer struct {
|
||||
@ -23,15 +24,22 @@ type TCPStreamer struct {
|
||||
startTime time.Time
|
||||
engine *gnet.Engine
|
||||
wg sync.WaitGroup
|
||||
rateLimiter *ratelimit.Limiter
|
||||
}
|
||||
|
||||
func NewTCPStreamer(logChan chan monitor.LogEntry, cfg config.TCPConfig) *TCPStreamer {
|
||||
return &TCPStreamer{
|
||||
t := &TCPStreamer{
|
||||
logChan: logChan,
|
||||
config: cfg,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
if cfg.RateLimit != nil && cfg.RateLimit.Enabled {
|
||||
t.rateLimiter = ratelimit.New(*cfg.RateLimit)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TCPStreamer) Start() error {
|
||||
|
||||
Reference in New Issue
Block a user