v0.4.1 authentication impelemented, not tested and docs not updated
This commit is contained in:
1
go.mod
1
go.mod
@ -10,6 +10,7 @@ require (
|
|||||||
github.com/valyala/fasthttp v1.65.0
|
github.com/valyala/fasthttp v1.65.0
|
||||||
golang.org/x/crypto v0.42.0
|
golang.org/x/crypto v0.42.0
|
||||||
golang.org/x/term v0.35.0
|
golang.org/x/term v0.35.0
|
||||||
|
golang.org/x/time v0.13.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|||||||
2
go.sum
2
go.sum
@ -42,6 +42,8 @@ golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
|||||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
|
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
|
||||||
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
|
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
|
||||||
|
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
|
||||||
|
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||||
|
|||||||
@ -3,8 +3,10 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -15,8 +17,12 @@ import (
|
|||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/lixenwraith/log"
|
"github.com/lixenwraith/log"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Prevent unbounded map growth
|
||||||
|
const maxAuthTrackedIPs = 10000
|
||||||
|
|
||||||
// Authenticator handles all authentication methods for a pipeline
|
// Authenticator handles all authentication methods for a pipeline
|
||||||
type Authenticator struct {
|
type Authenticator struct {
|
||||||
config *config.AuthConfig
|
config *config.AuthConfig
|
||||||
@ -30,6 +36,18 @@ type Authenticator struct {
|
|||||||
// Session tracking
|
// Session tracking
|
||||||
sessions map[string]*Session
|
sessions map[string]*Session
|
||||||
sessionMu sync.RWMutex
|
sessionMu sync.RWMutex
|
||||||
|
|
||||||
|
// Brute-force protection
|
||||||
|
ipAuthAttempts map[string]*ipAuthState
|
||||||
|
authMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// ADDED: Per-IP auth attempt tracking
|
||||||
|
type ipAuthState struct {
|
||||||
|
limiter *rate.Limiter
|
||||||
|
failCount int
|
||||||
|
lastAttempt time.Time
|
||||||
|
blockedUntil time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session represents an authenticated connection
|
// Session represents an authenticated connection
|
||||||
@ -50,11 +68,12 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
a := &Authenticator{
|
a := &Authenticator{
|
||||||
config: cfg,
|
config: cfg,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
basicUsers: make(map[string]string),
|
basicUsers: make(map[string]string),
|
||||||
bearerTokens: make(map[string]bool),
|
bearerTokens: make(map[string]bool),
|
||||||
sessions: make(map[string]*Session),
|
sessions: make(map[string]*Session),
|
||||||
|
ipAuthAttempts: make(map[string]*ipAuthState),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize Basic Auth users
|
// Initialize Basic Auth users
|
||||||
@ -82,6 +101,7 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) {
|
|||||||
a.jwtParser = jwt.NewParser(
|
a.jwtParser = jwt.NewParser(
|
||||||
jwt.WithValidMethods([]string{"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}),
|
jwt.WithValidMethods([]string{"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}),
|
||||||
jwt.WithLeeway(5*time.Second),
|
jwt.WithLeeway(5*time.Second),
|
||||||
|
jwt.WithExpirationRequired(),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Setup key function
|
// Setup key function
|
||||||
@ -102,6 +122,9 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) {
|
|||||||
// Start session cleanup
|
// Start session cleanup
|
||||||
go a.sessionCleanup()
|
go a.sessionCleanup()
|
||||||
|
|
||||||
|
// Start auth attempt cleanup
|
||||||
|
go a.authAttemptCleanup()
|
||||||
|
|
||||||
logger.Info("msg", "Authenticator initialized",
|
logger.Info("msg", "Authenticator initialized",
|
||||||
"component", "auth",
|
"component", "auth",
|
||||||
"type", cfg.Type)
|
"type", cfg.Type)
|
||||||
@ -109,6 +132,129 @@ func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) {
|
|||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check and enforce rate limits
|
||||||
|
func (a *Authenticator) checkRateLimit(remoteAddr string) error {
|
||||||
|
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
ip = remoteAddr // Fallback for malformed addresses
|
||||||
|
}
|
||||||
|
|
||||||
|
a.authMu.Lock()
|
||||||
|
defer a.authMu.Unlock()
|
||||||
|
|
||||||
|
state, exists := a.ipAuthAttempts[ip]
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
// Check map size limit before creating new entry
|
||||||
|
if len(a.ipAuthAttempts) >= maxAuthTrackedIPs {
|
||||||
|
// Evict an old entry using simplified LRU
|
||||||
|
// Sample 20 random entries and evict the oldest
|
||||||
|
const sampleSize = 20
|
||||||
|
var oldestIP string
|
||||||
|
oldestTime := now
|
||||||
|
|
||||||
|
// Build sample
|
||||||
|
sampled := 0
|
||||||
|
for sampledIP, sampledState := range a.ipAuthAttempts {
|
||||||
|
if sampledState.lastAttempt.Before(oldestTime) {
|
||||||
|
oldestIP = sampledIP
|
||||||
|
oldestTime = sampledState.lastAttempt
|
||||||
|
}
|
||||||
|
sampled++
|
||||||
|
if sampled >= sampleSize {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict the oldest from our sample
|
||||||
|
if oldestIP != "" {
|
||||||
|
delete(a.ipAuthAttempts, oldestIP)
|
||||||
|
a.logger.Debug("msg", "Evicted old auth attempt state",
|
||||||
|
"component", "auth",
|
||||||
|
"evicted_ip", oldestIP,
|
||||||
|
"last_seen", oldestTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new state for this IP
|
||||||
|
// 5 attempts per minute, burst of 3
|
||||||
|
state = &ipAuthState{
|
||||||
|
limiter: rate.NewLimiter(rate.Every(12*time.Second), 3),
|
||||||
|
lastAttempt: now,
|
||||||
|
}
|
||||||
|
a.ipAuthAttempts[ip] = state
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if IP is temporarily blocked
|
||||||
|
if now.Before(state.blockedUntil) {
|
||||||
|
remaining := state.blockedUntil.Sub(now)
|
||||||
|
a.logger.Warn("msg", "IP temporarily blocked",
|
||||||
|
"component", "auth",
|
||||||
|
"ip", ip,
|
||||||
|
"remaining", remaining)
|
||||||
|
// Sleep to slow down even blocked attempts
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
return fmt.Errorf("temporarily blocked, try again in %v", remaining.Round(time.Second))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check rate limit
|
||||||
|
if !state.limiter.Allow() {
|
||||||
|
state.failCount++
|
||||||
|
|
||||||
|
// Only set new blockedUntil if not already blocked
|
||||||
|
// This prevents indefinite block extension
|
||||||
|
if state.blockedUntil.IsZero() || now.After(state.blockedUntil) {
|
||||||
|
// Progressive blocking: 2^failCount minutes
|
||||||
|
blockMinutes := 1 << min(state.failCount, 6) // Cap at 64 minutes
|
||||||
|
state.blockedUntil = now.Add(time.Duration(blockMinutes) * time.Minute)
|
||||||
|
|
||||||
|
a.logger.Warn("msg", "Rate limit exceeded, blocking IP",
|
||||||
|
"component", "auth",
|
||||||
|
"ip", ip,
|
||||||
|
"fail_count", state.failCount,
|
||||||
|
"block_duration", time.Duration(blockMinutes)*time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("rate limit exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
state.lastAttempt = now
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record failed attempt
|
||||||
|
func (a *Authenticator) recordFailure(remoteAddr string) {
|
||||||
|
ip, _, _ := net.SplitHostPort(remoteAddr)
|
||||||
|
if ip == "" {
|
||||||
|
ip = remoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
a.authMu.Lock()
|
||||||
|
defer a.authMu.Unlock()
|
||||||
|
|
||||||
|
if state, exists := a.ipAuthAttempts[ip]; exists {
|
||||||
|
state.failCount++
|
||||||
|
state.lastAttempt = time.Now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset failure count on success
|
||||||
|
func (a *Authenticator) recordSuccess(remoteAddr string) {
|
||||||
|
ip, _, _ := net.SplitHostPort(remoteAddr)
|
||||||
|
if ip == "" {
|
||||||
|
ip = remoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
a.authMu.Lock()
|
||||||
|
defer a.authMu.Unlock()
|
||||||
|
|
||||||
|
if state, exists := a.ipAuthAttempts[ip]; exists {
|
||||||
|
state.failCount = 0
|
||||||
|
state.blockedUntil = time.Time{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// AuthenticateHTTP handles HTTP authentication headers
|
// AuthenticateHTTP handles HTTP authentication headers
|
||||||
func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) {
|
func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) {
|
||||||
if a == nil || a.config.Type == "none" {
|
if a == nil || a.config.Type == "none" {
|
||||||
@ -120,14 +266,31 @@ func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Sessio
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit
|
||||||
|
if err := a.checkRateLimit(remoteAddr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var session *Session
|
||||||
|
var err error
|
||||||
|
|
||||||
switch a.config.Type {
|
switch a.config.Type {
|
||||||
case "basic":
|
case "basic":
|
||||||
return a.authenticateBasic(authHeader, remoteAddr)
|
session, err = a.authenticateBasic(authHeader, remoteAddr)
|
||||||
case "bearer":
|
case "bearer":
|
||||||
return a.authenticateBearer(authHeader, remoteAddr)
|
session, err = a.authenticateBearer(authHeader, remoteAddr)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported auth type: %s", a.config.Type)
|
err = fmt.Errorf("unsupported auth type: %s", a.config.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
a.recordFailure(remoteAddr)
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
a.recordSuccess(remoteAddr)
|
||||||
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthenticateTCP handles TCP connection authentication
|
// AuthenticateTCP handles TCP connection authentication
|
||||||
@ -141,32 +304,54 @@ func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string)
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit first
|
||||||
|
if err := a.checkRateLimit(remoteAddr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var session *Session
|
||||||
|
var err error
|
||||||
|
|
||||||
// TCP auth protocol: AUTH <method> <credentials>
|
// TCP auth protocol: AUTH <method> <credentials>
|
||||||
switch strings.ToLower(method) {
|
switch strings.ToLower(method) {
|
||||||
case "token":
|
case "token":
|
||||||
if a.config.Type != "bearer" {
|
if a.config.Type != "bearer" {
|
||||||
return nil, fmt.Errorf("token auth not configured")
|
err = fmt.Errorf("token auth not configured")
|
||||||
|
} else {
|
||||||
|
session, err = a.validateToken(credentials, remoteAddr)
|
||||||
}
|
}
|
||||||
return a.validateToken(credentials, remoteAddr)
|
|
||||||
|
|
||||||
case "basic":
|
case "basic":
|
||||||
if a.config.Type != "basic" {
|
if a.config.Type != "basic" {
|
||||||
return nil, fmt.Errorf("basic auth not configured")
|
err = fmt.Errorf("basic auth not configured")
|
||||||
|
} else {
|
||||||
|
// Expect base64(username:password)
|
||||||
|
decoded, decErr := base64.StdEncoding.DecodeString(credentials)
|
||||||
|
if decErr != nil {
|
||||||
|
err = fmt.Errorf("invalid credentials encoding")
|
||||||
|
} else {
|
||||||
|
parts := strings.SplitN(string(decoded), ":", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
err = fmt.Errorf("invalid credentials format")
|
||||||
|
} else {
|
||||||
|
session, err = a.validateBasicAuth(parts[0], parts[1], remoteAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Expect base64(username:password)
|
|
||||||
decoded, err := base64.StdEncoding.DecodeString(credentials)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid credentials encoding")
|
|
||||||
}
|
|
||||||
parts := strings.SplitN(string(decoded), ":", 2)
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return nil, fmt.Errorf("invalid credentials format")
|
|
||||||
}
|
|
||||||
return a.validateBasicAuth(parts[0], parts[1], remoteAddr)
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported auth method: %s", method)
|
err = fmt.Errorf("unsupported auth method: %s", method)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
a.recordFailure(remoteAddr)
|
||||||
|
// Add delay on failure
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
a.recordSuccess(remoteAddr)
|
||||||
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticator) authenticateBasic(authHeader, remoteAddr string) (*Session, error) {
|
func (a *Authenticator) authenticateBasic(authHeader, remoteAddr string) (*Session, error) {
|
||||||
@ -255,6 +440,23 @@ func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error
|
|||||||
return nil, fmt.Errorf("invalid JWT token")
|
return nil, fmt.Errorf("invalid JWT token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Explicit expiration check
|
||||||
|
if exp, ok := claims["exp"].(float64); ok {
|
||||||
|
if time.Now().Unix() > int64(exp) {
|
||||||
|
return nil, fmt.Errorf("token expired")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Reject tokens without expiration
|
||||||
|
return nil, fmt.Errorf("token missing expiration claim")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check not-before claim
|
||||||
|
if nbf, ok := claims["nbf"].(float64); ok {
|
||||||
|
if time.Now().Unix() < int64(nbf) {
|
||||||
|
return nil, fmt.Errorf("token not yet valid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check issuer if configured
|
// Check issuer if configured
|
||||||
if a.config.BearerAuth.JWT.Issuer != "" {
|
if a.config.BearerAuth.JWT.Issuer != "" {
|
||||||
if iss, ok := claims["iss"].(string); !ok || iss != a.config.BearerAuth.JWT.Issuer {
|
if iss, ok := claims["iss"].(string); !ok || iss != a.config.BearerAuth.JWT.Issuer {
|
||||||
@ -264,7 +466,20 @@ func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error
|
|||||||
|
|
||||||
// Check audience if configured
|
// Check audience if configured
|
||||||
if a.config.BearerAuth.JWT.Audience != "" {
|
if a.config.BearerAuth.JWT.Audience != "" {
|
||||||
if aud, ok := claims["aud"].(string); !ok || aud != a.config.BearerAuth.JWT.Audience {
|
// Handle both string and []string audience formats
|
||||||
|
audValid := false
|
||||||
|
switch aud := claims["aud"].(type) {
|
||||||
|
case string:
|
||||||
|
audValid = aud == a.config.BearerAuth.JWT.Audience
|
||||||
|
case []interface{}:
|
||||||
|
for _, aa := range aud {
|
||||||
|
if audStr, ok := aa.(string); ok && audStr == a.config.BearerAuth.JWT.Audience {
|
||||||
|
audValid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !audValid {
|
||||||
return nil, fmt.Errorf("invalid token audience")
|
return nil, fmt.Errorf("invalid token audience")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -322,6 +537,27 @@ func (a *Authenticator) sessionCleanup() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cleanup old auth attempts
|
||||||
|
func (a *Authenticator) authAttemptCleanup() {
|
||||||
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
a.authMu.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
for ip, state := range a.ipAuthAttempts {
|
||||||
|
// Remove entries older than 1 hour with no recent activity
|
||||||
|
if now.Sub(state.lastAttempt) > time.Hour {
|
||||||
|
delete(a.ipAuthAttempts, ip)
|
||||||
|
a.logger.Debug("msg", "Cleaned up auth attempt state",
|
||||||
|
"component", "auth",
|
||||||
|
"ip", ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
a.authMu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Authenticator) loadUsersFile(path string) error {
|
func (a *Authenticator) loadUsersFile(path string) error {
|
||||||
file, err := os.Open(path)
|
file, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -366,7 +602,12 @@ func (a *Authenticator) loadUsersFile(path string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func generateSessionID() string {
|
func generateSessionID() string {
|
||||||
return fmt.Sprintf("%d-%d", time.Now().UnixNano(), time.Now().Unix())
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
// Fallback to a less secure method if crypto/rand fails
|
||||||
|
return fmt.Sprintf("fallback-%d", time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
return base64.URLEncoding.EncodeToString(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSession checks if a session is still valid
|
// ValidateSession checks if a session is still valid
|
||||||
|
|||||||
@ -3,8 +3,6 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
|
|||||||
@ -3,7 +3,6 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -29,37 +28,6 @@ type RateLimitConfig struct {
|
|||||||
MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"`
|
MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNetAccess(pipelineName string, cfg *NetAccessConfig) error {
|
|
||||||
if cfg == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate CIDR notation
|
|
||||||
for _, cidr := range cfg.IPWhitelist {
|
|
||||||
if !strings.Contains(cidr, "/") {
|
|
||||||
cidr = cidr + "/32"
|
|
||||||
}
|
|
||||||
if _, _, err := net.ParseCIDR(cidr); err != nil {
|
|
||||||
if net.ParseIP(cidr) == nil {
|
|
||||||
return fmt.Errorf("pipeline '%s': invalid IP whitelist entry: %s", pipelineName, cidr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, cidr := range cfg.IPBlacklist {
|
|
||||||
if !strings.Contains(cidr, "/") {
|
|
||||||
cidr = cidr + "/32"
|
|
||||||
}
|
|
||||||
if _, _, err := net.ParseCIDR(cidr); err != nil {
|
|
||||||
if net.ParseIP(cidr) == nil {
|
|
||||||
return fmt.Errorf("pipeline '%s': invalid IP blacklist entry: %s", pipelineName, cidr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error {
|
func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -20,9 +20,6 @@ type PipelineConfig struct {
|
|||||||
// Rate limiting
|
// Rate limiting
|
||||||
RateLimit *RateLimitConfig `toml:"rate_limit"`
|
RateLimit *RateLimitConfig `toml:"rate_limit"`
|
||||||
|
|
||||||
// Network access control (IP filtering)
|
|
||||||
NetAccess *NetAccessConfig `toml:"net_access"`
|
|
||||||
|
|
||||||
// Filter configuration
|
// Filter configuration
|
||||||
Filters []FilterConfig `toml:"filters"`
|
Filters []FilterConfig `toml:"filters"`
|
||||||
|
|
||||||
@ -37,12 +34,6 @@ type PipelineConfig struct {
|
|||||||
Auth *AuthConfig `toml:"auth"`
|
Auth *AuthConfig `toml:"auth"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NetAccessConfig defines IP-based access control lists
|
|
||||||
type NetAccessConfig struct {
|
|
||||||
IPWhitelist []string `toml:"ip_whitelist"`
|
|
||||||
IPBlacklist []string `toml:"ip_blacklist"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SourceConfig represents an input data source
|
// SourceConfig represents an input data source
|
||||||
type SourceConfig struct {
|
type SourceConfig struct {
|
||||||
// Source type: "directory", "file", "stdin", etc.
|
// Source type: "directory", "file", "stdin", etc.
|
||||||
@ -132,6 +123,13 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHANGED: Validate SSL if present
|
||||||
|
if ssl, ok := cfg.Options["ssl"].(map[string]any); ok {
|
||||||
|
if err := validateSSLOptions("HTTP source", pipelineName, sourceIndex, ssl); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
case "tcp":
|
case "tcp":
|
||||||
// Validate TCP source options
|
// Validate TCP source options
|
||||||
port, ok := cfg.Options["port"].(int64)
|
port, ok := cfg.Options["port"].(int64)
|
||||||
@ -147,6 +145,13 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHANGED: Validate SSL if present
|
||||||
|
if ssl, ok := cfg.Options["ssl"].(map[string]any); ok {
|
||||||
|
if err := validateSSLOptions("TCP source", pipelineName, sourceIndex, ssl); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("pipeline '%s' source[%d]: unknown source type '%s'",
|
return fmt.Errorf("pipeline '%s' source[%d]: unknown source type '%s'",
|
||||||
pipelineName, sourceIndex, cfg.Type)
|
pipelineName, sourceIndex, cfg.Type)
|
||||||
|
|||||||
@ -1,7 +1,11 @@
|
|||||||
// FILE: logwisp/src/internal/config/server.go
|
// FILE: logwisp/src/internal/config/server.go
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
type TCPConfig struct {
|
type TCPConfig struct {
|
||||||
Enabled bool `toml:"enabled"`
|
Enabled bool `toml:"enabled"`
|
||||||
@ -49,6 +53,10 @@ type NetLimitConfig struct {
|
|||||||
// Enable net limiting
|
// Enable net limiting
|
||||||
Enabled bool `toml:"enabled"`
|
Enabled bool `toml:"enabled"`
|
||||||
|
|
||||||
|
// IP Access Control Lists
|
||||||
|
IPWhitelist []string `toml:"ip_whitelist"`
|
||||||
|
IPBlacklist []string `toml:"ip_blacklist"`
|
||||||
|
|
||||||
// Requests per second per client
|
// Requests per second per client
|
||||||
RequestsPerSecond float64 `toml:"requests_per_second"`
|
RequestsPerSecond float64 `toml:"requests_per_second"`
|
||||||
|
|
||||||
@ -90,6 +98,33 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate IP lists if present
|
||||||
|
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
||||||
|
for i, entry := range ipWhitelist {
|
||||||
|
entryStr, ok := entry.(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := validateIPv4Entry(entryStr); err != nil {
|
||||||
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: whitelist[%d] %v",
|
||||||
|
pipelineName, sinkIndex, serverType, i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
||||||
|
for i, entry := range ipBlacklist {
|
||||||
|
entryStr, ok := entry.(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := validateIPv4Entry(entryStr); err != nil {
|
||||||
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: blacklist[%d] %v",
|
||||||
|
pipelineName, sinkIndex, serverType, i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Validate requests per second
|
// Validate requests per second
|
||||||
rps, ok := rl["requests_per_second"].(float64)
|
rps, ok := rl["requests_per_second"].(float64)
|
||||||
if !ok || rps <= 0 {
|
if !ok || rps <= 0 {
|
||||||
@ -132,5 +167,39 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateIPv4Entry ensures an IP or CIDR is IPv4
|
||||||
|
func validateIPv4Entry(entry string) error {
|
||||||
|
// Handle single IP
|
||||||
|
if !strings.Contains(entry, "/") {
|
||||||
|
ip := net.ParseIP(entry)
|
||||||
|
if ip == nil {
|
||||||
|
return fmt.Errorf("invalid IP address: %s", entry)
|
||||||
|
}
|
||||||
|
if ip.To4() == nil {
|
||||||
|
return fmt.Errorf("IPv6 not supported (IPv4-only): %s", entry)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle CIDR
|
||||||
|
ipAddr, ipNet, err := net.ParseCIDR(entry)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid CIDR: %s", entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the IP is IPv4
|
||||||
|
if ipAddr.To4() == nil {
|
||||||
|
return fmt.Errorf("IPv6 CIDR not supported (IPv4-only): %s", entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the network mask is appropriate for IPv4
|
||||||
|
_, bits := ipNet.Mask.Size()
|
||||||
|
if bits != 32 {
|
||||||
|
return fmt.Errorf("invalid IPv4 CIDR mask (got %d bits, expected 32): %s", bits, entry)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -1,7 +1,10 @@
|
|||||||
// FILE: logwisp/src/internal/config/ssl.go
|
// FILE: logwisp/src/internal/config/ssl.go
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
type SSLConfig struct {
|
type SSLConfig struct {
|
||||||
Enabled bool `toml:"enabled"`
|
Enabled bool `toml:"enabled"`
|
||||||
@ -13,6 +16,9 @@ type SSLConfig struct {
|
|||||||
ClientCAFile string `toml:"client_ca_file"`
|
ClientCAFile string `toml:"client_ca_file"`
|
||||||
VerifyClientCert bool `toml:"verify_client_cert"`
|
VerifyClientCert bool `toml:"verify_client_cert"`
|
||||||
|
|
||||||
|
// Option to skip verification for clients
|
||||||
|
InsecureSkipVerify bool `toml:"insecure_skip_verify"`
|
||||||
|
|
||||||
// TLS version constraints
|
// TLS version constraints
|
||||||
MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3"
|
MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3"
|
||||||
MaxVersion string `toml:"max_version"`
|
MaxVersion string `toml:"max_version"`
|
||||||
@ -31,11 +37,27 @@ func validateSSLOptions(serverType, pipelineName string, sinkIndex int, ssl map[
|
|||||||
pipelineName, sinkIndex, serverType)
|
pipelineName, sinkIndex, serverType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that certificate files exist and are readable
|
||||||
|
if _, err := os.Stat(certFile); err != nil {
|
||||||
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: cert_file is not accessible: %w",
|
||||||
|
pipelineName, sinkIndex, serverType, err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(keyFile); err != nil {
|
||||||
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: key_file is not accessible: %w",
|
||||||
|
pipelineName, sinkIndex, serverType, err)
|
||||||
|
}
|
||||||
|
|
||||||
if clientAuth, ok := ssl["client_auth"].(bool); ok && clientAuth {
|
if clientAuth, ok := ssl["client_auth"].(bool); ok && clientAuth {
|
||||||
if caFile, ok := ssl["client_ca_file"].(string); !ok || caFile == "" {
|
caFile, caOk := ssl["client_ca_file"].(string)
|
||||||
|
if !caOk || caFile == "" {
|
||||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: client auth enabled but CA file not specified",
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: client auth enabled but CA file not specified",
|
||||||
pipelineName, sinkIndex, serverType)
|
pipelineName, sinkIndex, serverType)
|
||||||
}
|
}
|
||||||
|
// Validate that the client CA file exists and is readable
|
||||||
|
if _, err := os.Stat(caFile); err != nil {
|
||||||
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: client_ca_file is not accessible: %w",
|
||||||
|
pipelineName, sinkIndex, serverType, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate TLS versions
|
// Validate TLS versions
|
||||||
|
|||||||
@ -72,11 +72,6 @@ func (c *Config) validate() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate net access if present
|
|
||||||
if err := validateNetAccess(pipeline.Name, pipeline.NetAccess); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate auth if present
|
// Validate auth if present
|
||||||
if err := validateAuth(pipeline.Name, pipeline.Auth); err != nil {
|
if err := validateAuth(pipeline.Name, pipeline.Auth); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@ -1,173 +0,0 @@
|
|||||||
// FILE: src/internal/limit/ip.go
|
|
||||||
package limit
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"logwisp/src/internal/config"
|
|
||||||
|
|
||||||
"github.com/lixenwraith/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
// IPChecker handles IP-based access control lists
|
|
||||||
type IPChecker struct {
|
|
||||||
ipWhitelist []*net.IPNet
|
|
||||||
ipBlacklist []*net.IPNet
|
|
||||||
logger *log.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewIPChecker creates a new IPChecker. Returns nil if no rules are defined.
|
|
||||||
func NewIPChecker(cfg *config.NetAccessConfig, logger *log.Logger) *IPChecker {
|
|
||||||
if cfg == nil || (len(cfg.IPWhitelist) == 0 && len(cfg.IPBlacklist) == 0) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &IPChecker{
|
|
||||||
ipWhitelist: make([]*net.IPNet, 0),
|
|
||||||
ipBlacklist: make([]*net.IPNet, 0),
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse whitelist entries
|
|
||||||
for _, cidr := range cfg.IPWhitelist {
|
|
||||||
if !strings.Contains(cidr, "/") {
|
|
||||||
cidr = cidr + "/32"
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ipNet, err := net.ParseCIDR(cidr)
|
|
||||||
if err != nil {
|
|
||||||
// Try parsing as plain IP
|
|
||||||
if ip := net.ParseIP(cidr); ip != nil {
|
|
||||||
if ip.To4() != nil {
|
|
||||||
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)}
|
|
||||||
} else {
|
|
||||||
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Warn("msg", "Skipping invalid IP whitelist entry",
|
|
||||||
"component", "ip_checker",
|
|
||||||
"entry", cidr,
|
|
||||||
"error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.ipWhitelist = append(c.ipWhitelist, ipNet)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse blacklist entries
|
|
||||||
for _, cidr := range cfg.IPBlacklist {
|
|
||||||
if !strings.Contains(cidr, "/") {
|
|
||||||
cidr = cidr + "/32"
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ipNet, err := net.ParseCIDR(cidr)
|
|
||||||
if err != nil {
|
|
||||||
// Try parsing as plain IP
|
|
||||||
if ip := net.ParseIP(cidr); ip != nil {
|
|
||||||
if ip.To4() != nil {
|
|
||||||
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)}
|
|
||||||
} else {
|
|
||||||
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Warn("msg", "Skipping invalid IP blacklist entry",
|
|
||||||
"component", "ip_checker",
|
|
||||||
"entry", cidr,
|
|
||||||
"error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.ipBlacklist = append(c.ipBlacklist, ipNet)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("msg", "IP checker initialized",
|
|
||||||
"component", "ip_checker",
|
|
||||||
"whitelist_rules", len(c.ipWhitelist),
|
|
||||||
"blacklist_rules", len(c.ipBlacklist))
|
|
||||||
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsAllowed validates if a remote address is permitted
|
|
||||||
func (c *IPChecker) IsAllowed(remoteAddr net.Addr) bool {
|
|
||||||
if c == nil {
|
|
||||||
return true // No checker = allow all
|
|
||||||
}
|
|
||||||
|
|
||||||
// No rules = allow all
|
|
||||||
if len(c.ipWhitelist) == 0 && len(c.ipBlacklist) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract IP from address
|
|
||||||
var ipStr string
|
|
||||||
switch addr := remoteAddr.(type) {
|
|
||||||
case *net.TCPAddr:
|
|
||||||
ipStr = addr.IP.String()
|
|
||||||
case *net.UDPAddr:
|
|
||||||
ipStr = addr.IP.String()
|
|
||||||
default:
|
|
||||||
// Try string parsing
|
|
||||||
addrStr := remoteAddr.String()
|
|
||||||
host, _, err := net.SplitHostPort(addrStr)
|
|
||||||
if err != nil {
|
|
||||||
ipStr = addrStr
|
|
||||||
} else {
|
|
||||||
ipStr = host
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP(ipStr)
|
|
||||||
if ip == nil {
|
|
||||||
c.logger.Warn("msg", "Could not parse remote address to IP",
|
|
||||||
"component", "ip_checker",
|
|
||||||
"remote_addr", remoteAddr.String())
|
|
||||||
return false // Deny unparseable addresses
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check blacklist first (deny takes precedence)
|
|
||||||
for _, ipNet := range c.ipBlacklist {
|
|
||||||
if ipNet.Contains(ip) {
|
|
||||||
c.logger.Warn("msg", "Blacklisted IP denied",
|
|
||||||
"component", "ip_checker",
|
|
||||||
"ip", ipStr,
|
|
||||||
"rule", ipNet.String())
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If whitelist is configured, IP must be in it
|
|
||||||
if len(c.ipWhitelist) > 0 {
|
|
||||||
for _, ipNet := range c.ipWhitelist {
|
|
||||||
if ipNet.Contains(ip) {
|
|
||||||
c.logger.Debug("msg", "IP allowed by whitelist",
|
|
||||||
"component", "ip_checker",
|
|
||||||
"ip", ipStr,
|
|
||||||
"rule", ipNet.String())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// No whitelist match = deny
|
|
||||||
c.logger.Warn("msg", "IP not in whitelist",
|
|
||||||
"component", "ip_checker",
|
|
||||||
"ip", ipStr)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// No blacklist match + no whitelist configured = allow
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetStats returns IP checker statistics
|
|
||||||
func (c *IPChecker) GetStats() map[string]any {
|
|
||||||
if c == nil {
|
|
||||||
return map[string]any{"enabled": false}
|
|
||||||
}
|
|
||||||
|
|
||||||
return map[string]any{
|
|
||||||
"enabled": true,
|
|
||||||
"whitelist_rules": len(c.ipWhitelist),
|
|
||||||
"blacklist_rules": len(c.ipBlacklist),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -14,11 +14,32 @@ import (
|
|||||||
"github.com/lixenwraith/log"
|
"github.com/lixenwraith/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DenialReason indicates why a request was denied
|
||||||
|
type DenialReason string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// IPv4Only is the enforcement message for IPv6 rejection
|
||||||
|
IPv4Only = "IPv4-only (IPv6 not supported)"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ReasonAllowed DenialReason = ""
|
||||||
|
ReasonBlacklisted DenialReason = "IP denied by blacklist"
|
||||||
|
ReasonNotWhitelisted DenialReason = "IP not in whitelist"
|
||||||
|
ReasonRateLimited DenialReason = "Rate limit exceeded"
|
||||||
|
ReasonConnectionLimited DenialReason = "Connection limit exceeded"
|
||||||
|
ReasonInvalidIP DenialReason = "Invalid IP address"
|
||||||
|
)
|
||||||
|
|
||||||
// NetLimiter manages net limiting for a transport
|
// NetLimiter manages net limiting for a transport
|
||||||
type NetLimiter struct {
|
type NetLimiter struct {
|
||||||
config config.NetLimitConfig
|
config config.NetLimitConfig
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
|
|
||||||
|
// IP Access Control Lists
|
||||||
|
ipWhitelist []*net.IPNet
|
||||||
|
ipBlacklist []*net.IPNet
|
||||||
|
|
||||||
// Per-IP limiters
|
// Per-IP limiters
|
||||||
ipLimiters map[string]*ipLimiter
|
ipLimiters map[string]*ipLimiter
|
||||||
ipMu sync.RWMutex
|
ipMu sync.RWMutex
|
||||||
@ -27,17 +48,22 @@ type NetLimiter struct {
|
|||||||
globalLimiter *TokenBucket
|
globalLimiter *TokenBucket
|
||||||
|
|
||||||
// Connection tracking
|
// Connection tracking
|
||||||
ipConnections map[string]*atomic.Int64
|
ipConnections map[string]*connTracker
|
||||||
connMu sync.RWMutex
|
connMu sync.RWMutex
|
||||||
|
|
||||||
// Statistics
|
// Statistics
|
||||||
totalRequests atomic.Uint64
|
totalRequests atomic.Uint64
|
||||||
blockedRequests atomic.Uint64
|
blockedByBlacklist atomic.Uint64
|
||||||
uniqueIPs atomic.Uint64
|
blockedByWhitelist atomic.Uint64
|
||||||
|
blockedByRateLimit atomic.Uint64
|
||||||
|
blockedByConnLimit atomic.Uint64
|
||||||
|
blockedByInvalidIP atomic.Uint64
|
||||||
|
uniqueIPs atomic.Uint64
|
||||||
|
|
||||||
// Cleanup
|
// Cleanup
|
||||||
lastCleanup time.Time
|
lastCleanup time.Time
|
||||||
cleanupMu sync.Mutex
|
cleanupMu sync.Mutex
|
||||||
|
cleanupActive atomic.Bool
|
||||||
|
|
||||||
// Lifecycle management
|
// Lifecycle management
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@ -51,9 +77,20 @@ type ipLimiter struct {
|
|||||||
connections atomic.Int64
|
connections atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Connection tracking with activity timestamp
|
||||||
|
type connTracker struct {
|
||||||
|
connections atomic.Int64
|
||||||
|
lastSeen time.Time
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
// Creates a new net limiter
|
// Creates a new net limiter
|
||||||
func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
||||||
if !cfg.Enabled {
|
// Return nil only if nothing is configured
|
||||||
|
hasACL := len(cfg.IPWhitelist) > 0 || len(cfg.IPBlacklist) > 0
|
||||||
|
hasRateLimit := cfg.Enabled
|
||||||
|
|
||||||
|
if !hasACL && !hasRateLimit {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -65,28 +102,39 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
|||||||
|
|
||||||
l := &NetLimiter{
|
l := &NetLimiter{
|
||||||
config: cfg,
|
config: cfg,
|
||||||
ipLimiters: make(map[string]*ipLimiter),
|
|
||||||
ipConnections: make(map[string]*atomic.Int64),
|
|
||||||
lastCleanup: time.Now(),
|
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
ipWhitelist: make([]*net.IPNet, 0),
|
||||||
|
ipBlacklist: make([]*net.IPNet, 0),
|
||||||
|
ipLimiters: make(map[string]*ipLimiter),
|
||||||
|
ipConnections: make(map[string]*connTracker),
|
||||||
|
lastCleanup: time.Now(),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
cleanupDone: make(chan struct{}),
|
cleanupDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create global limiter if not using per-IP limiting
|
// Parse IP lists
|
||||||
if cfg.LimitBy == "global" {
|
l.parseIPLists(cfg)
|
||||||
|
|
||||||
|
// Create global limiter if configured
|
||||||
|
if cfg.Enabled && cfg.LimitBy == "global" {
|
||||||
l.globalLimiter = NewTokenBucket(
|
l.globalLimiter = NewTokenBucket(
|
||||||
float64(cfg.BurstSize),
|
float64(cfg.BurstSize),
|
||||||
cfg.RequestsPerSecond,
|
cfg.RequestsPerSecond,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup goroutine
|
// Start cleanup goroutine only if rate limiting is enabled
|
||||||
go l.cleanupLoop()
|
if cfg.Enabled {
|
||||||
|
go l.cleanupLoop()
|
||||||
|
}
|
||||||
|
|
||||||
l.logger.Info("msg", "Net limiter initialized",
|
logger.Info("msg", "Net limiter initialized",
|
||||||
"component", "netlimit",
|
"component", "netlimit",
|
||||||
|
"acl_enabled", hasACL,
|
||||||
|
"rate_limiting", cfg.Enabled,
|
||||||
|
"whitelist_rules", len(l.ipWhitelist),
|
||||||
|
"blacklist_rules", len(l.ipBlacklist),
|
||||||
"requests_per_second", cfg.RequestsPerSecond,
|
"requests_per_second", cfg.RequestsPerSecond,
|
||||||
"burst_size", cfg.BurstSize,
|
"burst_size", cfg.BurstSize,
|
||||||
"limit_by", cfg.LimitBy)
|
"limit_by", cfg.LimitBy)
|
||||||
@ -94,6 +142,120 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
|||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseIPLists parses and validates IP whitelist/blacklist
|
||||||
|
func (l *NetLimiter) parseIPLists(cfg config.NetLimitConfig) {
|
||||||
|
// Parse whitelist
|
||||||
|
for _, entry := range cfg.IPWhitelist {
|
||||||
|
if ipNet := l.parseIPEntry(entry, "whitelist"); ipNet != nil {
|
||||||
|
l.ipWhitelist = append(l.ipWhitelist, ipNet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse blacklist
|
||||||
|
for _, entry := range cfg.IPBlacklist {
|
||||||
|
if ipNet := l.parseIPEntry(entry, "blacklist"); ipNet != nil {
|
||||||
|
l.ipBlacklist = append(l.ipBlacklist, ipNet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseIPEntry parses a single IP or CIDR entry
|
||||||
|
func (l *NetLimiter) parseIPEntry(entry, listType string) *net.IPNet {
|
||||||
|
// Handle single IP
|
||||||
|
if !strings.Contains(entry, "/") {
|
||||||
|
ip := net.ParseIP(entry)
|
||||||
|
if ip == nil {
|
||||||
|
l.logger.Warn("msg", "Invalid IP entry",
|
||||||
|
"component", "netlimit",
|
||||||
|
"list", listType,
|
||||||
|
"entry", entry)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject IPv6
|
||||||
|
if ip.To4() == nil {
|
||||||
|
l.logger.Warn("msg", "IPv6 address rejected",
|
||||||
|
"component", "netlimit",
|
||||||
|
"list", listType,
|
||||||
|
"entry", entry,
|
||||||
|
"reason", IPv4Only)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &net.IPNet{IP: ip.To4(), Mask: net.CIDRMask(32, 32)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse CIDR
|
||||||
|
ipAddr, ipNet, err := net.ParseCIDR(entry)
|
||||||
|
if err != nil {
|
||||||
|
l.logger.Warn("msg", "Invalid CIDR entry",
|
||||||
|
"component", "netlimit",
|
||||||
|
"list", listType,
|
||||||
|
"entry", entry,
|
||||||
|
"error", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject IPv6 CIDR
|
||||||
|
if ipAddr.To4() == nil {
|
||||||
|
l.logger.Warn("msg", "IPv6 CIDR rejected",
|
||||||
|
"component", "netlimit",
|
||||||
|
"list", listType,
|
||||||
|
"entry", entry,
|
||||||
|
"reason", IPv4Only)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure mask is IPv4
|
||||||
|
_, bits := ipNet.Mask.Size()
|
||||||
|
if bits != 32 {
|
||||||
|
l.logger.Warn("msg", "Non-IPv4 CIDR mask rejected",
|
||||||
|
"component", "netlimit",
|
||||||
|
"list", listType,
|
||||||
|
"entry", entry,
|
||||||
|
"mask_bits", bits,
|
||||||
|
"reason", IPv4Only)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &net.IPNet{IP: ipAddr.To4(), Mask: ipNet.Mask}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkIPAccess checks if an IP is allowed by ACLs
|
||||||
|
func (l *NetLimiter) checkIPAccess(ip net.IP) DenialReason {
|
||||||
|
// 1. Check blacklist first (deny takes precedence)
|
||||||
|
for _, ipNet := range l.ipBlacklist {
|
||||||
|
if ipNet.Contains(ip) {
|
||||||
|
l.blockedByBlacklist.Add(1)
|
||||||
|
l.logger.Debug("msg", "IP denied by blacklist",
|
||||||
|
"component", "netlimit",
|
||||||
|
"ip", ip.String(),
|
||||||
|
"rule", ipNet.String())
|
||||||
|
return ReasonBlacklisted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. If whitelist is configured, IP must be in it
|
||||||
|
if len(l.ipWhitelist) > 0 {
|
||||||
|
for _, ipNet := range l.ipWhitelist {
|
||||||
|
if ipNet.Contains(ip) {
|
||||||
|
l.logger.Debug("msg", "IP allowed by whitelist",
|
||||||
|
"component", "netlimit",
|
||||||
|
"ip", ip.String(),
|
||||||
|
"rule", ipNet.String())
|
||||||
|
return ReasonAllowed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.blockedByWhitelist.Add(1)
|
||||||
|
l.logger.Debug("msg", "IP not in whitelist",
|
||||||
|
"component", "netlimit",
|
||||||
|
"ip", ip.String())
|
||||||
|
return ReasonNotWhitelisted
|
||||||
|
}
|
||||||
|
|
||||||
|
return ReasonAllowed
|
||||||
|
}
|
||||||
|
|
||||||
func (l *NetLimiter) Shutdown() {
|
func (l *NetLimiter) Shutdown() {
|
||||||
if l == nil {
|
if l == nil {
|
||||||
return
|
return
|
||||||
@ -121,9 +283,9 @@ func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int6
|
|||||||
|
|
||||||
l.totalRequests.Add(1)
|
l.totalRequests.Add(1)
|
||||||
|
|
||||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
// Parse IP address
|
||||||
|
ipStr, _, err := net.SplitHostPort(remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If we can't parse the IP, allow the request but log
|
|
||||||
l.logger.Warn("msg", "Failed to parse remote addr",
|
l.logger.Warn("msg", "Failed to parse remote addr",
|
||||||
"component", "netlimit",
|
"component", "netlimit",
|
||||||
"remote_addr", remoteAddr,
|
"remote_addr", remoteAddr,
|
||||||
@ -131,56 +293,82 @@ func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int6
|
|||||||
return true, 0, ""
|
return true, 0, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only supporting ipv4
|
ip := net.ParseIP(ipStr)
|
||||||
if !isIPv4(ip) {
|
if ip == nil {
|
||||||
// Block non-IPv4 addresses to prevent complications
|
l.blockedByInvalidIP.Add(1)
|
||||||
l.blockedRequests.Add(1)
|
l.logger.Warn("msg", "Failed to parse IP",
|
||||||
l.logger.Warn("msg", "Non-IPv4 address blocked",
|
|
||||||
"component", "netlimit",
|
"component", "netlimit",
|
||||||
"ip", ip)
|
"ip", ipStr)
|
||||||
return false, 403, "IPv4 only"
|
return false, 403, string(ReasonInvalidIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check connection limit for streaming endpoint
|
// Reject IPv6 connections
|
||||||
|
if !isIPv4(ip) {
|
||||||
|
l.blockedByInvalidIP.Add(1)
|
||||||
|
l.logger.Warn("msg", "IPv6 connection rejected",
|
||||||
|
"component", "netlimit",
|
||||||
|
"ip", ipStr,
|
||||||
|
"reason", IPv4Only)
|
||||||
|
return false, 403, IPv4Only
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize to IPv4 representation
|
||||||
|
ip = ip.To4()
|
||||||
|
|
||||||
|
// Check IP access control
|
||||||
|
if reason := l.checkIPAccess(ip); reason != ReasonAllowed {
|
||||||
|
return false, 403, string(reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If rate limiting is not enabled, allow
|
||||||
|
if !l.config.Enabled {
|
||||||
|
return true, 0, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check connection limits
|
||||||
if l.config.MaxConnectionsPerIP > 0 {
|
if l.config.MaxConnectionsPerIP > 0 {
|
||||||
l.connMu.RLock()
|
l.connMu.RLock()
|
||||||
counter, exists := l.ipConnections[ip]
|
tracker, exists := l.ipConnections[ipStr]
|
||||||
l.connMu.RUnlock()
|
l.connMu.RUnlock()
|
||||||
|
|
||||||
if exists && counter.Load() >= l.config.MaxConnectionsPerIP {
|
if exists && tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
|
||||||
l.blockedRequests.Add(1)
|
l.blockedByConnLimit.Add(1)
|
||||||
statusCode = l.config.ResponseCode
|
statusCode = l.config.ResponseCode
|
||||||
if statusCode == 0 {
|
if statusCode == 0 {
|
||||||
statusCode = 429
|
statusCode = 429
|
||||||
}
|
}
|
||||||
message = "Connection limit exceeded"
|
return false, statusCode, string(ReasonConnectionLimited)
|
||||||
|
|
||||||
l.logger.Warn("msg", "Connection limit exceeded",
|
|
||||||
"component", "netlimit",
|
|
||||||
"ip", ip,
|
|
||||||
"connections", counter.Load(),
|
|
||||||
"limit", l.config.MaxConnectionsPerIP)
|
|
||||||
|
|
||||||
return false, statusCode, message
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check net limit
|
// Check rate limit
|
||||||
allowed = l.checkLimit(ip)
|
if !l.checkLimit(ipStr) {
|
||||||
if !allowed {
|
l.blockedByRateLimit.Add(1)
|
||||||
l.blockedRequests.Add(1)
|
|
||||||
statusCode = l.config.ResponseCode
|
statusCode = l.config.ResponseCode
|
||||||
if statusCode == 0 {
|
if statusCode == 0 {
|
||||||
statusCode = 429
|
statusCode = 429
|
||||||
}
|
}
|
||||||
message = l.config.ResponseMessage
|
message = l.config.ResponseMessage
|
||||||
if message == "" {
|
if message == "" {
|
||||||
message = "Net limit exceeded"
|
message = string(ReasonRateLimited)
|
||||||
}
|
}
|
||||||
l.logger.Debug("msg", "Request net limited", "ip", ip)
|
return false, statusCode, message
|
||||||
}
|
}
|
||||||
|
|
||||||
return allowed, statusCode, message
|
return true, 0, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update connection activity
|
||||||
|
func (l *NetLimiter) updateConnectionActivity(ip string) {
|
||||||
|
l.connMu.RLock()
|
||||||
|
tracker, exists := l.ipConnections[ip]
|
||||||
|
l.connMu.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
tracker.mu.Lock()
|
||||||
|
tracker.lastSeen = time.Now()
|
||||||
|
tracker.mu.Unlock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks if a TCP connection should be allowed
|
// Checks if a TCP connection should be allowed
|
||||||
@ -194,32 +382,45 @@ func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool {
|
|||||||
// Extract IP from TCP addr
|
// Extract IP from TCP addr
|
||||||
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
|
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
return true
|
l.blockedByInvalidIP.Add(1)
|
||||||
}
|
|
||||||
|
|
||||||
ip := tcpAddr.IP.String()
|
|
||||||
|
|
||||||
// Only supporting ipv4
|
|
||||||
if !isIPv4(ip) {
|
|
||||||
l.blockedRequests.Add(1)
|
|
||||||
l.logger.Warn("msg", "Non-IPv4 TCP connection blocked",
|
|
||||||
"component", "netlimit",
|
|
||||||
"ip", ip)
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
allowed := l.checkLimit(ip)
|
// Reject IPv6 connections
|
||||||
if !allowed {
|
if !isIPv4(tcpAddr.IP) {
|
||||||
l.blockedRequests.Add(1)
|
l.blockedByInvalidIP.Add(1)
|
||||||
l.logger.Debug("msg", "TCP connection net limited", "ip", ip)
|
l.logger.Warn("msg", "IPv6 TCP connection rejected",
|
||||||
|
"component", "netlimit",
|
||||||
|
"ip", tcpAddr.IP.String(),
|
||||||
|
"reason", IPv4Only)
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return allowed
|
// Normalize to IPv4 representation
|
||||||
|
ip := tcpAddr.IP.To4()
|
||||||
|
|
||||||
|
// Check IP access control
|
||||||
|
if reason := l.checkIPAccess(ip); reason != ReasonAllowed {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// If rate limiting is not enabled, allow
|
||||||
|
if !l.config.Enabled {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check rate limit
|
||||||
|
ipStr := tcpAddr.IP.String()
|
||||||
|
if !l.checkLimit(ipStr) {
|
||||||
|
l.blockedByRateLimit.Add(1)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func isIPv4(ip string) bool {
|
func isIPv4(ip net.IP) bool {
|
||||||
// Simple check: IPv4 addresses contain dots, IPv6 contain colons
|
return ip.To4() != nil
|
||||||
return strings.Contains(ip, ".") && !strings.Contains(ip, ":")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tracks a new connection for an IP
|
// Tracks a new connection for an IP
|
||||||
@ -230,23 +431,44 @@ func (l *NetLimiter) AddConnection(remoteAddr string) {
|
|||||||
|
|
||||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
l.logger.Warn("msg", "Failed to parse remote address in AddConnection",
|
||||||
|
"component", "netlimit",
|
||||||
|
"remote_addr", remoteAddr,
|
||||||
|
"error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP validation
|
||||||
|
parsedIP := net.ParseIP(ip)
|
||||||
|
if parsedIP == nil {
|
||||||
|
l.logger.Warn("msg", "Failed to parse IP in AddConnection",
|
||||||
|
"component", "netlimit",
|
||||||
|
"ip", ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only supporting ipv4
|
// Only supporting ipv4
|
||||||
if !isIPv4(ip) {
|
if !isIPv4(parsedIP) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l.connMu.Lock()
|
l.connMu.Lock()
|
||||||
counter, exists := l.ipConnections[ip]
|
tracker, exists := l.ipConnections[ip]
|
||||||
if !exists {
|
if !exists {
|
||||||
counter = &atomic.Int64{}
|
// Create new tracker with timestamp
|
||||||
l.ipConnections[ip] = counter
|
tracker = &connTracker{
|
||||||
|
lastSeen: time.Now(),
|
||||||
|
}
|
||||||
|
l.ipConnections[ip] = tracker
|
||||||
}
|
}
|
||||||
l.connMu.Unlock()
|
l.connMu.Unlock()
|
||||||
|
|
||||||
newCount := counter.Add(1)
|
newCount := tracker.connections.Add(1)
|
||||||
|
// Update activity timestamp
|
||||||
|
tracker.mu.Lock()
|
||||||
|
tracker.lastSeen = time.Now()
|
||||||
|
tracker.mu.Unlock()
|
||||||
|
|
||||||
l.logger.Debug("msg", "Connection added",
|
l.logger.Debug("msg", "Connection added",
|
||||||
"ip", ip,
|
"ip", ip,
|
||||||
"connections", newCount)
|
"connections", newCount)
|
||||||
@ -260,20 +482,33 @@ func (l *NetLimiter) RemoveConnection(remoteAddr string) {
|
|||||||
|
|
||||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
l.logger.Warn("msg", "Failed to parse remote address in RemoveConnection",
|
||||||
|
"component", "netlimit",
|
||||||
|
"remote_addr", remoteAddr,
|
||||||
|
"error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP validation
|
||||||
|
parsedIP := net.ParseIP(ip)
|
||||||
|
if parsedIP == nil {
|
||||||
|
l.logger.Warn("msg", "Failed to parse IP in RemoveConnection",
|
||||||
|
"component", "netlimit",
|
||||||
|
"ip", ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only supporting ipv4
|
// Only supporting ipv4
|
||||||
if !isIPv4(ip) {
|
if !isIPv4(parsedIP) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l.connMu.RLock()
|
l.connMu.RLock()
|
||||||
counter, exists := l.ipConnections[ip]
|
tracker, exists := l.ipConnections[ip]
|
||||||
l.connMu.RUnlock()
|
l.connMu.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if exists {
|
||||||
newCount := counter.Add(-1)
|
newCount := tracker.connections.Add(-1)
|
||||||
l.logger.Debug("msg", "Connection removed",
|
l.logger.Debug("msg", "Connection removed",
|
||||||
"ip", ip,
|
"ip", ip,
|
||||||
"connections", newCount)
|
"connections", newCount)
|
||||||
@ -281,7 +516,7 @@ func (l *NetLimiter) RemoveConnection(remoteAddr string) {
|
|||||||
if newCount <= 0 {
|
if newCount <= 0 {
|
||||||
// Clean up if no more connections
|
// Clean up if no more connections
|
||||||
l.connMu.Lock()
|
l.connMu.Lock()
|
||||||
if counter.Load() <= 0 {
|
if tracker.connections.Load() <= 0 {
|
||||||
delete(l.ipConnections, ip)
|
delete(l.ipConnections, ip)
|
||||||
}
|
}
|
||||||
l.connMu.Unlock()
|
l.connMu.Unlock()
|
||||||
@ -292,9 +527,7 @@ func (l *NetLimiter) RemoveConnection(remoteAddr string) {
|
|||||||
// Returns net limiter statistics
|
// Returns net limiter statistics
|
||||||
func (l *NetLimiter) GetStats() map[string]any {
|
func (l *NetLimiter) GetStats() map[string]any {
|
||||||
if l == nil {
|
if l == nil {
|
||||||
return map[string]any{
|
return map[string]any{"enabled": false}
|
||||||
"enabled": false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
l.ipMu.RLock()
|
l.ipMu.RLock()
|
||||||
@ -303,18 +536,36 @@ func (l *NetLimiter) GetStats() map[string]any {
|
|||||||
|
|
||||||
l.connMu.RLock()
|
l.connMu.RLock()
|
||||||
totalConnections := 0
|
totalConnections := 0
|
||||||
for _, counter := range l.ipConnections {
|
for _, tracker := range l.ipConnections {
|
||||||
totalConnections += int(counter.Load())
|
totalConnections += int(tracker.connections.Load())
|
||||||
}
|
}
|
||||||
l.connMu.RUnlock()
|
l.connMu.RUnlock()
|
||||||
|
|
||||||
|
totalBlocked := l.blockedByBlacklist.Load() +
|
||||||
|
l.blockedByWhitelist.Load() +
|
||||||
|
l.blockedByRateLimit.Load() +
|
||||||
|
l.blockedByConnLimit.Load() +
|
||||||
|
l.blockedByInvalidIP.Load()
|
||||||
|
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"total_requests": l.totalRequests.Load(),
|
"total_requests": l.totalRequests.Load(),
|
||||||
"blocked_requests": l.blockedRequests.Load(),
|
"total_blocked": totalBlocked,
|
||||||
|
"blocked_breakdown": map[string]uint64{
|
||||||
|
"blacklist": l.blockedByBlacklist.Load(),
|
||||||
|
"whitelist": l.blockedByWhitelist.Load(),
|
||||||
|
"rate_limit": l.blockedByRateLimit.Load(),
|
||||||
|
"conn_limit": l.blockedByConnLimit.Load(),
|
||||||
|
"invalid_ip": l.blockedByInvalidIP.Load(),
|
||||||
|
},
|
||||||
"active_ips": activeIPs,
|
"active_ips": activeIPs,
|
||||||
"total_connections": totalConnections,
|
"total_connections": totalConnections,
|
||||||
"config": map[string]any{
|
"acl": map[string]int{
|
||||||
|
"whitelist_rules": len(l.ipWhitelist),
|
||||||
|
"blacklist_rules": len(l.ipBlacklist),
|
||||||
|
},
|
||||||
|
"rate_limit": map[string]any{
|
||||||
|
"enabled": l.config.Enabled,
|
||||||
"requests_per_second": l.config.RequestsPerSecond,
|
"requests_per_second": l.config.RequestsPerSecond,
|
||||||
"burst_size": l.config.BurstSize,
|
"burst_size": l.config.BurstSize,
|
||||||
"limit_by": l.config.LimitBy,
|
"limit_by": l.config.LimitBy,
|
||||||
@ -324,6 +575,15 @@ func (l *NetLimiter) GetStats() map[string]any {
|
|||||||
|
|
||||||
// Performs the actual net limit check
|
// Performs the actual net limit check
|
||||||
func (l *NetLimiter) checkLimit(ip string) bool {
|
func (l *NetLimiter) checkLimit(ip string) bool {
|
||||||
|
// Validate IP format
|
||||||
|
parsedIP := net.ParseIP(ip)
|
||||||
|
if parsedIP == nil || !isIPv4(parsedIP) {
|
||||||
|
l.logger.Warn("msg", "Invalid or non-IPv4 address in rate limiter",
|
||||||
|
"component", "netlimit",
|
||||||
|
"ip", ip)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Maybe run cleanup
|
// Maybe run cleanup
|
||||||
l.maybeCleanup()
|
l.maybeCleanup()
|
||||||
|
|
||||||
@ -358,10 +618,10 @@ func (l *NetLimiter) checkLimit(ip string) bool {
|
|||||||
// Check connection limit if configured
|
// Check connection limit if configured
|
||||||
if l.config.MaxConnectionsPerIP > 0 {
|
if l.config.MaxConnectionsPerIP > 0 {
|
||||||
l.connMu.RLock()
|
l.connMu.RLock()
|
||||||
counter, exists := l.ipConnections[ip]
|
tracker, exists := l.ipConnections[ip]
|
||||||
l.connMu.RUnlock()
|
l.connMu.RUnlock()
|
||||||
|
|
||||||
if exists && counter.Load() >= l.config.MaxConnectionsPerIP {
|
if exists && tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -379,14 +639,27 @@ func (l *NetLimiter) checkLimit(ip string) bool {
|
|||||||
// Runs cleanup if enough time has passed
|
// Runs cleanup if enough time has passed
|
||||||
func (l *NetLimiter) maybeCleanup() {
|
func (l *NetLimiter) maybeCleanup() {
|
||||||
l.cleanupMu.Lock()
|
l.cleanupMu.Lock()
|
||||||
defer l.cleanupMu.Unlock()
|
|
||||||
|
|
||||||
|
// Check if enough time has passed
|
||||||
if time.Since(l.lastCleanup) < 30*time.Second {
|
if time.Since(l.lastCleanup) < 30*time.Second {
|
||||||
|
l.cleanupMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if cleanup already running
|
||||||
|
if !l.cleanupActive.CompareAndSwap(false, true) {
|
||||||
|
l.cleanupMu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l.lastCleanup = time.Now()
|
l.lastCleanup = time.Now()
|
||||||
go l.cleanup()
|
l.cleanupMu.Unlock()
|
||||||
|
|
||||||
|
// Run cleanup async
|
||||||
|
go func() {
|
||||||
|
defer l.cleanupActive.Store(false)
|
||||||
|
l.cleanup()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Removes stale IP limiters
|
// Removes stale IP limiters
|
||||||
@ -397,6 +670,8 @@ func (l *NetLimiter) cleanup() {
|
|||||||
l.ipMu.Lock()
|
l.ipMu.Lock()
|
||||||
defer l.ipMu.Unlock()
|
defer l.ipMu.Unlock()
|
||||||
|
|
||||||
|
// Clean up rate limiters
|
||||||
|
l.ipMu.Lock()
|
||||||
cleaned := 0
|
cleaned := 0
|
||||||
for ip, lim := range l.ipLimiters {
|
for ip, lim := range l.ipLimiters {
|
||||||
if now.Sub(lim.lastSeen) > staleTimeout {
|
if now.Sub(lim.lastSeen) > staleTimeout {
|
||||||
@ -404,12 +679,37 @@ func (l *NetLimiter) cleanup() {
|
|||||||
cleaned++
|
cleaned++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
l.ipMu.Unlock()
|
||||||
|
|
||||||
if cleaned > 0 {
|
if cleaned > 0 {
|
||||||
l.logger.Debug("msg", "Cleaned up stale IP limiters",
|
l.logger.Debug("msg", "Cleaned up stale IP limiters",
|
||||||
|
"component", "netlimit",
|
||||||
"cleaned", cleaned,
|
"cleaned", cleaned,
|
||||||
"remaining", len(l.ipLimiters))
|
"remaining", len(l.ipLimiters))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clean up stale connection trackers
|
||||||
|
l.connMu.Lock()
|
||||||
|
connCleaned := 0
|
||||||
|
for ip, tracker := range l.ipConnections {
|
||||||
|
tracker.mu.Lock()
|
||||||
|
lastSeen := tracker.lastSeen
|
||||||
|
tracker.mu.Unlock()
|
||||||
|
|
||||||
|
// Remove if no activity for 5 minutes AND no active connections
|
||||||
|
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
|
||||||
|
delete(l.ipConnections, ip)
|
||||||
|
connCleaned++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.connMu.Unlock()
|
||||||
|
|
||||||
|
if connCleaned > 0 {
|
||||||
|
l.logger.Debug("msg", "Cleaned up stale connection trackers",
|
||||||
|
"component", "netlimit",
|
||||||
|
"cleaned", connCleaned,
|
||||||
|
"remaining", len(l.ipConnections))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Runs periodic cleanup
|
// Runs periodic cleanup
|
||||||
|
|||||||
@ -139,9 +139,6 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
|
|||||||
|
|
||||||
// Configure authentication for sinks that support it
|
// Configure authentication for sinks that support it
|
||||||
for _, sinkInst := range pipeline.Sinks {
|
for _, sinkInst := range pipeline.Sinks {
|
||||||
if setter, ok := sinkInst.(sink.NetAccessSetter); ok {
|
|
||||||
setter.SetNetAccessConfig(cfg.NetAccess)
|
|
||||||
}
|
|
||||||
if setter, ok := sinkInst.(sink.AuthSetter); ok {
|
if setter, ok := sinkInst.(sink.AuthSetter); ok {
|
||||||
setter.SetAuthConfig(cfg.Auth)
|
setter.SetAuthConfig(cfg.Auth)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -48,7 +48,6 @@ type HTTPSink struct {
|
|||||||
|
|
||||||
// Net limiting
|
// Net limiting
|
||||||
netLimiter *limit.NetLimiter
|
netLimiter *limit.NetLimiter
|
||||||
ipChecker *limit.IPChecker
|
|
||||||
|
|
||||||
// Statistics
|
// Statistics
|
||||||
totalProcessed atomic.Uint64
|
totalProcessed atomic.Uint64
|
||||||
@ -156,6 +155,22 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
|
|||||||
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
||||||
cfg.NetLimit.MaxTotalConnections = maxTotal
|
cfg.NetLimit.MaxTotalConnections = maxTotal
|
||||||
}
|
}
|
||||||
|
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
||||||
|
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
||||||
|
for _, entry := range ipWhitelist {
|
||||||
|
if str, ok := entry.(string); ok {
|
||||||
|
cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
||||||
|
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
||||||
|
for _, entry := range ipBlacklist {
|
||||||
|
if str, ok := entry.(string); ok {
|
||||||
|
cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h := &HTTPSink{
|
h := &HTTPSink{
|
||||||
@ -195,8 +210,10 @@ func (h *HTTPSink) Start(ctx context.Context) error {
|
|||||||
|
|
||||||
// Configure TLS if enabled
|
// Configure TLS if enabled
|
||||||
if h.tlsManager != nil {
|
if h.tlsManager != nil {
|
||||||
tlsConfig := h.tlsManager.GetHTTPConfig()
|
h.server.TLSConfig = h.tlsManager.GetHTTPConfig()
|
||||||
h.server.TLSConfig = tlsConfig
|
h.logger.Info("msg", "TLS enabled for HTTP sink",
|
||||||
|
"component", "http_sink",
|
||||||
|
"port", h.config.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := fmt.Sprintf(":%d", h.config.Port)
|
addr := fmt.Sprintf(":%d", h.config.Port)
|
||||||
@ -208,12 +225,13 @@ func (h *HTTPSink) Start(ctx context.Context) error {
|
|||||||
"component", "http_sink",
|
"component", "http_sink",
|
||||||
"port", h.config.Port,
|
"port", h.config.Port,
|
||||||
"stream_path", h.streamPath,
|
"stream_path", h.streamPath,
|
||||||
"status_path", h.statusPath)
|
"status_path", h.statusPath,
|
||||||
|
"tls_enabled", h.tlsManager != nil)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if h.tlsManager != nil {
|
if h.tlsManager != nil {
|
||||||
// HTTPS server
|
// HTTPS server
|
||||||
err = h.server.ListenAndServeTLS(addr, "", "")
|
err = h.server.ListenAndServeTLS(addr, h.config.SSL.CertFile, h.config.SSL.KeyFile)
|
||||||
} else {
|
} else {
|
||||||
// HTTP server
|
// HTTP server
|
||||||
err = h.server.ListenAndServe(addr)
|
err = h.server.ListenAndServe(addr)
|
||||||
@ -306,24 +324,18 @@ func (h *HTTPSink) GetStats() SinkStats {
|
|||||||
func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||||
remoteAddr := ctx.RemoteAddr().String()
|
remoteAddr := ctx.RemoteAddr().String()
|
||||||
|
|
||||||
// Check IP access control
|
|
||||||
if h.ipChecker != nil {
|
|
||||||
if !h.ipChecker.IsAllowed(ctx.RemoteAddr()) {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
|
||||||
ctx.SetContentType("text/plain")
|
|
||||||
ctx.SetBodyString("Forbidden")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check net limit
|
// Check net limit
|
||||||
if h.netLimiter != nil {
|
if h.netLimiter != nil {
|
||||||
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
|
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
|
||||||
ctx.SetStatusCode(int(statusCode))
|
ctx.SetStatusCode(int(statusCode))
|
||||||
ctx.SetContentType("application/json")
|
ctx.SetContentType("application/json")
|
||||||
|
h.logger.Warn("msg", "Net limited",
|
||||||
|
"component", "http_sink",
|
||||||
|
"remote_addr", remoteAddr,
|
||||||
|
"status_code", statusCode,
|
||||||
|
"error", message)
|
||||||
json.NewEncoder(ctx).Encode(map[string]any{
|
json.NewEncoder(ctx).Encode(map[string]any{
|
||||||
"error": message,
|
"error": "Too many requests",
|
||||||
"retry_after": "60", // seconds
|
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -355,7 +367,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
|||||||
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
|
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
|
||||||
realm := h.authConfig.BasicAuth.Realm
|
realm := h.authConfig.BasicAuth.Realm
|
||||||
if realm == "" {
|
if realm == "" {
|
||||||
realm = "LogWisp"
|
realm = "Restricted"
|
||||||
}
|
}
|
||||||
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm))
|
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm))
|
||||||
} else if h.authConfig.Type == "bearer" {
|
} else if h.authConfig.Type == "bearer" {
|
||||||
@ -364,7 +376,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
|||||||
|
|
||||||
ctx.SetContentType("application/json")
|
ctx.SetContentType("application/json")
|
||||||
json.NewEncoder(ctx).Encode(map[string]string{
|
json.NewEncoder(ctx).Encode(map[string]string{
|
||||||
"error": "Authentication required",
|
"error": "Unauthorized",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -379,8 +391,6 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
|||||||
ctx.SetContentType("application/json")
|
ctx.SetContentType("application/json")
|
||||||
json.NewEncoder(ctx).Encode(map[string]any{
|
json.NewEncoder(ctx).Encode(map[string]any{
|
||||||
"error": "Not Found",
|
"error": "Not Found",
|
||||||
"message": fmt.Sprintf("Available endpoints: %s (SSE transport), %s (status)",
|
|
||||||
h.streamPath, h.statusPath),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -656,14 +666,6 @@ func (h *HTTPSink) GetStatusPath() string {
|
|||||||
return h.statusPath
|
return h.statusPath
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HTTPSink) SetNetAccessConfig(cfg *config.NetAccessConfig) {
|
|
||||||
h.ipChecker = limit.NewIPChecker(cfg, h.logger)
|
|
||||||
if h.ipChecker != nil {
|
|
||||||
h.logger.Info("msg", "IP access control configured for HTTP sink",
|
|
||||||
"component", "http_sink")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAuthConfig configures http sink authentication
|
// SetAuthConfig configures http sink authentication
|
||||||
func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
||||||
if authCfg == nil || authCfg.Type == "none" {
|
if authCfg == nil || authCfg.Type == "none" {
|
||||||
|
|||||||
@ -4,8 +4,12 @@ package sink
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@ -55,6 +59,7 @@ type HTTPClientConfig struct {
|
|||||||
|
|
||||||
// TLS configuration
|
// TLS configuration
|
||||||
InsecureSkipVerify bool
|
InsecureSkipVerify bool
|
||||||
|
CAFile string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHTTPClientSink creates a new HTTP client sink
|
// NewHTTPClientSink creates a new HTTP client sink
|
||||||
@ -126,6 +131,11 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
|
|||||||
cfg.Headers["Content-Type"] = "application/json"
|
cfg.Headers["Content-Type"] = "application/json"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract TLS options
|
||||||
|
if caFile, ok := options["ca_file"].(string); ok && caFile != "" {
|
||||||
|
cfg.CAFile = caFile
|
||||||
|
}
|
||||||
|
|
||||||
h := &HTTPClientSink{
|
h := &HTTPClientSink{
|
||||||
input: make(chan core.LogEntry, cfg.BufferSize),
|
input: make(chan core.LogEntry, cfg.BufferSize),
|
||||||
config: cfg,
|
config: cfg,
|
||||||
@ -147,17 +157,28 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
|
|||||||
DisableHeaderNamesNormalizing: true,
|
DisableHeaderNamesNormalizing: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Implement custom TLS configuration, including InsecureSkipVerify,
|
// Configure TLS if using HTTPS
|
||||||
// by setting a custom dialer on the fasthttp.Client.
|
if strings.HasPrefix(cfg.URL, "https://") {
|
||||||
// For example:
|
tlsConfig := &tls.Config{
|
||||||
// if cfg.InsecureSkipVerify {
|
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||||
// h.client.Dial = func(addr string) (net.Conn, error) {
|
}
|
||||||
// return fasthttp.DialDualStackTimeout(addr, cfg.Timeout, &tls.Config{
|
|
||||||
// InsecureSkipVerify: true,
|
// Load custom CA if provided
|
||||||
// })
|
if cfg.CAFile != "" {
|
||||||
// }
|
caCert, err := os.ReadFile(cfg.CAFile)
|
||||||
// }
|
if err != nil {
|
||||||
// FIXED: Removed incorrect TLS configuration that referenced non-existent field
|
return nil, fmt.Errorf("failed to read CA file: %w", err)
|
||||||
|
}
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||||
|
return nil, fmt.Errorf("failed to parse CA certificate")
|
||||||
|
}
|
||||||
|
tlsConfig.RootCAs = caCertPool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set TLS config directly on the client
|
||||||
|
h.client.TLSConfig = tlsConfig
|
||||||
|
}
|
||||||
|
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
@ -341,15 +362,22 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
|||||||
if attempt > 0 {
|
if attempt > 0 {
|
||||||
// Wait before retry
|
// Wait before retry
|
||||||
time.Sleep(retryDelay)
|
time.Sleep(retryDelay)
|
||||||
retryDelay = time.Duration(float64(retryDelay) * h.config.RetryBackoff)
|
|
||||||
|
// Calculate new delay with overflow protection
|
||||||
|
newDelay := time.Duration(float64(retryDelay) * h.config.RetryBackoff)
|
||||||
|
|
||||||
|
// Cap at maximum to prevent integer overflow
|
||||||
|
if newDelay > h.config.Timeout || newDelay < retryDelay {
|
||||||
|
// Either exceeded max or overflowed (negative/wrapped)
|
||||||
|
retryDelay = h.config.Timeout
|
||||||
|
} else {
|
||||||
|
retryDelay = newDelay
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: defer placement issue
|
// Acquire resources inside loop, release immediately after use
|
||||||
// Create request
|
|
||||||
req := fasthttp.AcquireRequest()
|
req := fasthttp.AcquireRequest()
|
||||||
defer fasthttp.ReleaseRequest(req)
|
|
||||||
resp := fasthttp.AcquireResponse()
|
resp := fasthttp.AcquireResponse()
|
||||||
defer fasthttp.ReleaseResponse(resp)
|
|
||||||
|
|
||||||
req.SetRequestURI(h.config.URL)
|
req.SetRequestURI(h.config.URL)
|
||||||
req.Header.SetMethod("POST")
|
req.Header.SetMethod("POST")
|
||||||
@ -362,35 +390,50 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
|||||||
|
|
||||||
// Send request
|
// Send request
|
||||||
err := h.client.DoTimeout(req, resp, h.config.Timeout)
|
err := h.client.DoTimeout(req, resp, h.config.Timeout)
|
||||||
|
|
||||||
|
// Capture response before releasing
|
||||||
|
statusCode := resp.StatusCode()
|
||||||
|
var responseBody []byte
|
||||||
|
if len(resp.Body()) > 0 {
|
||||||
|
responseBody = make([]byte, len(resp.Body()))
|
||||||
|
copy(responseBody, resp.Body())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release immediately, not deferred
|
||||||
|
fasthttp.ReleaseRequest(req)
|
||||||
|
fasthttp.ReleaseResponse(resp)
|
||||||
|
|
||||||
|
// Handle errors
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = fmt.Errorf("request failed: %w", err)
|
lastErr = fmt.Errorf("request failed: %w", err)
|
||||||
h.logger.Warn("msg", "HTTP request failed",
|
h.logger.Warn("msg", "HTTP request failed",
|
||||||
"component", "http_client_sink",
|
"component", "http_client_sink",
|
||||||
"attempt", attempt+1,
|
"attempt", attempt+1,
|
||||||
|
"max_retries", h.config.MaxRetries,
|
||||||
"error", err)
|
"error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check response status
|
// Check response status
|
||||||
statusCode := resp.StatusCode()
|
|
||||||
if statusCode >= 200 && statusCode < 300 {
|
if statusCode >= 200 && statusCode < 300 {
|
||||||
// Success
|
// Success
|
||||||
h.logger.Debug("msg", "Batch sent successfully",
|
h.logger.Debug("msg", "Batch sent successfully",
|
||||||
"component", "http_client_sink",
|
"component", "http_client_sink",
|
||||||
"batch_size", len(batch),
|
"batch_size", len(batch),
|
||||||
"status_code", statusCode)
|
"status_code", statusCode,
|
||||||
|
"attempt", attempt+1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Non-2xx status
|
// Non-2xx status
|
||||||
lastErr = fmt.Errorf("server returned status %d: %s", statusCode, resp.Body())
|
lastErr = fmt.Errorf("server returned status %d: %s", statusCode, responseBody)
|
||||||
|
|
||||||
// Don't retry on 4xx errors (client errors)
|
// Don't retry on 4xx errors (client errors)
|
||||||
if statusCode >= 400 && statusCode < 500 {
|
if statusCode >= 400 && statusCode < 500 {
|
||||||
h.logger.Error("msg", "Batch rejected by server",
|
h.logger.Error("msg", "Batch rejected by server",
|
||||||
"component", "http_client_sink",
|
"component", "http_client_sink",
|
||||||
"status_code", statusCode,
|
"status_code", statusCode,
|
||||||
"response", string(resp.Body()),
|
"response", string(responseBody),
|
||||||
"batch_size", len(batch))
|
"batch_size", len(batch))
|
||||||
h.failedBatches.Add(1)
|
h.failedBatches.Add(1)
|
||||||
return
|
return
|
||||||
@ -400,13 +443,14 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
|||||||
"component", "http_client_sink",
|
"component", "http_client_sink",
|
||||||
"attempt", attempt+1,
|
"attempt", attempt+1,
|
||||||
"status_code", statusCode,
|
"status_code", statusCode,
|
||||||
"response", string(resp.Body()))
|
"response", string(responseBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
// All retries failed
|
// All retries exhausted
|
||||||
h.logger.Error("msg", "Failed to send batch after retries",
|
h.logger.Error("msg", "Failed to send batch after all retries",
|
||||||
"component", "http_client_sink",
|
"component", "http_client_sink",
|
||||||
"batch_size", len(batch),
|
"batch_size", len(batch),
|
||||||
|
"retries", h.config.MaxRetries,
|
||||||
"last_error", lastErr)
|
"last_error", lastErr)
|
||||||
h.failedBatches.Add(1)
|
h.failedBatches.Add(1)
|
||||||
}
|
}
|
||||||
@ -34,11 +34,6 @@ type SinkStats struct {
|
|||||||
Details map[string]any
|
Details map[string]any
|
||||||
}
|
}
|
||||||
|
|
||||||
// NetAccessSetter is an interface for sinks that can accept network access configuration
|
|
||||||
type NetAccessSetter interface {
|
|
||||||
SetNetAccessConfig(cfg *config.NetAccessConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthSetter is an interface for sinks that can accept an AuthConfig.
|
// AuthSetter is an interface for sinks that can accept an AuthConfig.
|
||||||
type AuthSetter interface {
|
type AuthSetter interface {
|
||||||
SetAuthConfig(auth *config.AuthConfig)
|
SetAuthConfig(auth *config.AuthConfig)
|
||||||
|
|||||||
@ -36,7 +36,6 @@ type TCPSink struct {
|
|||||||
engineMu sync.Mutex
|
engineMu sync.Mutex
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
netLimiter *limit.NetLimiter
|
netLimiter *limit.NetLimiter
|
||||||
ipChecker *limit.IPChecker
|
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
formatter format.Formatter
|
formatter format.Formatter
|
||||||
|
|
||||||
@ -50,6 +49,11 @@ type TCPSink struct {
|
|||||||
lastProcessed atomic.Value // time.Time
|
lastProcessed atomic.Value // time.Time
|
||||||
authFailures atomic.Uint64
|
authFailures atomic.Uint64
|
||||||
authSuccesses atomic.Uint64
|
authSuccesses atomic.Uint64
|
||||||
|
|
||||||
|
// Write error tracking
|
||||||
|
writeErrors atomic.Uint64
|
||||||
|
consecutiveWriteErrors map[gnet.Conn]int
|
||||||
|
errorMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// TCPConfig holds TCP sink configuration
|
// TCPConfig holds TCP sink configuration
|
||||||
@ -141,6 +145,22 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
|
|||||||
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
||||||
cfg.NetLimit.MaxTotalConnections = maxTotal
|
cfg.NetLimit.MaxTotalConnections = maxTotal
|
||||||
}
|
}
|
||||||
|
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
||||||
|
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
||||||
|
for _, entry := range ipWhitelist {
|
||||||
|
if str, ok := entry.(string); ok {
|
||||||
|
cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
||||||
|
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
||||||
|
for _, entry := range ipBlacklist {
|
||||||
|
if str, ok := entry.(string); ok {
|
||||||
|
cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &TCPSink{
|
t := &TCPSink{
|
||||||
@ -191,17 +211,6 @@ func (t *TCPSink) Start(ctx context.Context) error {
|
|||||||
gnet.WithReusePort(true),
|
gnet.WithReusePort(true),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Add TLS if configured
|
|
||||||
if t.tlsManager != nil {
|
|
||||||
// tlsConfig := t.tlsManager.GetTCPConfig()
|
|
||||||
// TODO: tlsConfig is not used, wrapper to be implemented, non-TLS stream to be available without wrapper
|
|
||||||
// ☢ SECURITY: gnet doesn't support TLS natively - would need wrapper
|
|
||||||
// This is a limitation that requires implementing TLS at application layer
|
|
||||||
t.logger.Warn("msg", "TLS configured but gnet doesn't support native TLS",
|
|
||||||
"component", "tcp_sink",
|
|
||||||
"workaround", "Use stunnel or nginx TCP proxy for TLS termination")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start gnet server
|
// Start gnet server
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@ -338,7 +347,29 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
|||||||
t.server.mu.RLock()
|
t.server.mu.RLock()
|
||||||
for conn, client := range t.server.clients {
|
for conn, client := range t.server.clients {
|
||||||
if client.authenticated {
|
if client.authenticated {
|
||||||
conn.AsyncWrite(data, nil)
|
// Send through TLS bridge if present
|
||||||
|
if client.tlsBridge != nil {
|
||||||
|
if _, err := client.tlsBridge.Write(data); err != nil {
|
||||||
|
// TLS write failed, connection likely dead
|
||||||
|
t.logger.Debug("msg", "TLS write failed",
|
||||||
|
"component", "tcp_sink",
|
||||||
|
"error", err)
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
t.writeErrors.Add(1)
|
||||||
|
t.handleWriteError(c, err)
|
||||||
|
} else {
|
||||||
|
// Reset consecutive error count on success
|
||||||
|
t.errorMu.Lock()
|
||||||
|
delete(t.consecutiveWriteErrors, c)
|
||||||
|
t.errorMu.Unlock()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
t.server.mu.RUnlock()
|
t.server.mu.RUnlock()
|
||||||
@ -364,7 +395,22 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
conn.AsyncWrite(data, nil)
|
if client.tlsBridge != nil {
|
||||||
|
if _, err := client.tlsBridge.Write(data); err != nil {
|
||||||
|
t.logger.Debug("msg", "TLS heartbeat write failed",
|
||||||
|
"component", "tcp_sink",
|
||||||
|
"error", err)
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
t.writeErrors.Add(1)
|
||||||
|
t.handleWriteError(c, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
t.server.mu.RUnlock()
|
t.server.mu.RUnlock()
|
||||||
@ -375,6 +421,36 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle write errors with threshold-based connection termination
|
||||||
|
func (t *TCPSink) handleWriteError(c gnet.Conn, err error) {
|
||||||
|
t.errorMu.Lock()
|
||||||
|
defer t.errorMu.Unlock()
|
||||||
|
|
||||||
|
// Track consecutive errors per connection
|
||||||
|
if t.consecutiveWriteErrors == nil {
|
||||||
|
t.consecutiveWriteErrors = make(map[gnet.Conn]int)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.consecutiveWriteErrors[c]++
|
||||||
|
errorCount := t.consecutiveWriteErrors[c]
|
||||||
|
|
||||||
|
t.logger.Debug("msg", "AsyncWrite error",
|
||||||
|
"component", "tcp_sink",
|
||||||
|
"remote_addr", c.RemoteAddr(),
|
||||||
|
"error", err,
|
||||||
|
"consecutive_errors", errorCount)
|
||||||
|
|
||||||
|
// Close connection after 3 consecutive write errors
|
||||||
|
if errorCount >= 3 {
|
||||||
|
t.logger.Warn("msg", "Closing connection due to repeated write errors",
|
||||||
|
"component", "tcp_sink",
|
||||||
|
"remote_addr", c.RemoteAddr(),
|
||||||
|
"error_count", errorCount)
|
||||||
|
delete(t.consecutiveWriteErrors, c)
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create heartbeat as a proper LogEntry
|
// Create heartbeat as a proper LogEntry
|
||||||
func (t *TCPSink) createHeartbeatEntry() core.LogEntry {
|
func (t *TCPSink) createHeartbeatEntry() core.LogEntry {
|
||||||
message := "heartbeat"
|
message := "heartbeat"
|
||||||
@ -406,11 +482,13 @@ func (t *TCPSink) GetActiveConnections() int64 {
|
|||||||
|
|
||||||
// tcpClient represents a connected TCP client with auth state
|
// tcpClient represents a connected TCP client with auth state
|
||||||
type tcpClient struct {
|
type tcpClient struct {
|
||||||
conn gnet.Conn
|
conn gnet.Conn
|
||||||
buffer bytes.Buffer
|
buffer bytes.Buffer
|
||||||
authenticated bool
|
authenticated bool
|
||||||
session *auth.Session
|
session *auth.Session
|
||||||
authTimeout time.Time
|
authTimeout time.Time
|
||||||
|
tlsBridge *tls.GNetTLSConn
|
||||||
|
authTimeoutSet bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// tcpServer handles gnet events with authentication
|
// tcpServer handles gnet events with authentication
|
||||||
@ -434,15 +512,13 @@ func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||||
remoteAddr := c.RemoteAddr().String()
|
remoteAddr := c.RemoteAddr()
|
||||||
s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr)
|
s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr)
|
||||||
|
|
||||||
// Check IP access control first
|
// Reject IPv6 connections immediately
|
||||||
if s.sink.ipChecker != nil {
|
if tcpAddr, ok := remoteAddr.(*net.TCPAddr); ok {
|
||||||
if !s.sink.ipChecker.IsAllowed(c.RemoteAddr()) {
|
if tcpAddr.IP.To4() == nil {
|
||||||
s.sink.logger.Warn("msg", "TCP connection denied by IP filter",
|
return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close
|
||||||
"remote_addr", remoteAddr)
|
|
||||||
return nil, gnet.Close
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -467,11 +543,26 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
|||||||
s.sink.netLimiter.AddConnection(remoteStr)
|
s.sink.netLimiter.AddConnection(remoteStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create client state
|
// Create client state without auth timeout initially
|
||||||
client := &tcpClient{
|
client := &tcpClient{
|
||||||
conn: c,
|
conn: c,
|
||||||
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
|
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
|
||||||
authTimeout: time.Now().Add(30 * time.Second), // 30s to authenticate
|
authTimeoutSet: false, // Auth timeout not started yet
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize TLS bridge if enabled
|
||||||
|
if s.sink.tlsManager != nil {
|
||||||
|
tlsConfig := s.sink.tlsManager.GetTCPConfig()
|
||||||
|
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
|
||||||
|
client.tlsBridge.Handshake() // Start async handshake
|
||||||
|
|
||||||
|
s.sink.logger.Debug("msg", "TLS handshake initiated",
|
||||||
|
"component", "tcp_sink",
|
||||||
|
"remote_addr", remoteAddr)
|
||||||
|
} else if s.sink.authenticator != nil {
|
||||||
|
// Only set auth timeout if no TLS (plain connection)
|
||||||
|
client.authTimeout = time.Now().Add(30 * time.Second) // TODO: configurable or non-hardcoded timer
|
||||||
|
client.authTimeoutSet = true
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
@ -485,7 +576,7 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
|||||||
"requires_auth", s.sink.authenticator != nil)
|
"requires_auth", s.sink.authenticator != nil)
|
||||||
|
|
||||||
// Send auth prompt if authentication is required
|
// Send auth prompt if authentication is required
|
||||||
if s.sink.authenticator != nil {
|
if s.sink.authenticator != nil && s.sink.tlsManager == nil {
|
||||||
authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH <method> <credentials>\nMethods: basic, token\n")
|
authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH <method> <credentials>\nMethods: basic, token\n")
|
||||||
return authPrompt, gnet.None
|
return authPrompt, gnet.None
|
||||||
}
|
}
|
||||||
@ -498,9 +589,22 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
|||||||
|
|
||||||
// Remove client state
|
// Remove client state
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
|
client := s.clients[c]
|
||||||
delete(s.clients, c)
|
delete(s.clients, c)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// Clean up TLS bridge if present
|
||||||
|
if client != nil && client.tlsBridge != nil {
|
||||||
|
client.tlsBridge.Close()
|
||||||
|
s.sink.logger.Debug("msg", "TLS connection closed",
|
||||||
|
"remote_addr", remoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up write error tracking
|
||||||
|
s.sink.errorMu.Lock()
|
||||||
|
delete(s.sink.consecutiveWriteErrors, c)
|
||||||
|
s.sink.errorMu.Unlock()
|
||||||
|
|
||||||
// Remove connection tracking
|
// Remove connection tracking
|
||||||
if s.sink.netLimiter != nil {
|
if s.sink.netLimiter != nil {
|
||||||
s.sink.netLimiter.RemoveConnection(remoteAddr)
|
s.sink.netLimiter.RemoveConnection(remoteAddr)
|
||||||
@ -523,13 +627,18 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
|||||||
return gnet.Close
|
return gnet.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check auth timeout
|
// // Check auth timeout
|
||||||
if !client.authenticated && time.Now().After(client.authTimeout) {
|
// if !client.authenticated && time.Now().After(client.authTimeout) {
|
||||||
s.sink.logger.Warn("msg", "Authentication timeout",
|
// s.sink.logger.Warn("msg", "Authentication timeout",
|
||||||
"remote_addr", c.RemoteAddr().String())
|
// "component", "tcp_sink",
|
||||||
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
// "remote_addr", c.RemoteAddr().String())
|
||||||
return gnet.Close
|
// if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() {
|
||||||
}
|
// client.tlsBridge.Write([]byte("AUTH TIMEOUT\n"))
|
||||||
|
// } else if client.tlsBridge == nil {
|
||||||
|
// c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
||||||
|
// }
|
||||||
|
// return gnet.Close
|
||||||
|
// }
|
||||||
|
|
||||||
// Read all available data
|
// Read all available data
|
||||||
data, err := c.Next(-1)
|
data, err := c.Next(-1)
|
||||||
@ -540,6 +649,70 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
|||||||
return gnet.Close
|
return gnet.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process through TLS bridge if present
|
||||||
|
if client.tlsBridge != nil {
|
||||||
|
// Feed encrypted data into TLS engine
|
||||||
|
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
|
||||||
|
s.sink.logger.Error("msg", "TLS processing error",
|
||||||
|
"component", "tcp_sink",
|
||||||
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
|
"error", err)
|
||||||
|
return gnet.Close
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if handshake is complete
|
||||||
|
if !client.tlsBridge.IsHandshakeDone() {
|
||||||
|
// Still handshaking, wait for more data
|
||||||
|
return gnet.None
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check handshake result
|
||||||
|
_, hsErr := client.tlsBridge.HandshakeComplete()
|
||||||
|
if hsErr != nil {
|
||||||
|
s.sink.logger.Error("msg", "TLS handshake failed",
|
||||||
|
"component", "tcp_sink",
|
||||||
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
|
"error", hsErr)
|
||||||
|
return gnet.Close
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set auth timeout only after TLS handshake completes
|
||||||
|
if !client.authTimeoutSet && s.sink.authenticator != nil && !client.authenticated {
|
||||||
|
client.authTimeout = time.Now().Add(30 * time.Second)
|
||||||
|
client.authTimeoutSet = true
|
||||||
|
s.sink.logger.Debug("msg", "Auth timeout started after TLS handshake",
|
||||||
|
"component", "tcp_sink",
|
||||||
|
"remote_addr", c.RemoteAddr().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read decrypted plaintext
|
||||||
|
data = client.tlsBridge.Read()
|
||||||
|
if data == nil || len(data) == 0 {
|
||||||
|
// No plaintext available yet
|
||||||
|
return gnet.None
|
||||||
|
}
|
||||||
|
|
||||||
|
// First data after TLS handshake - send auth prompt if needed
|
||||||
|
if s.sink.authenticator != nil && !client.authenticated &&
|
||||||
|
len(client.buffer.Bytes()) == 0 {
|
||||||
|
authPrompt := []byte("AUTH REQUIRED\n")
|
||||||
|
client.tlsBridge.Write(authPrompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only check auth timeout if it has been set
|
||||||
|
if !client.authenticated && client.authTimeoutSet && time.Now().After(client.authTimeout) {
|
||||||
|
s.sink.logger.Warn("msg", "Authentication timeout",
|
||||||
|
"component", "tcp_sink",
|
||||||
|
"remote_addr", c.RemoteAddr().String())
|
||||||
|
if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() {
|
||||||
|
client.tlsBridge.Write([]byte("AUTH TIMEOUT\n"))
|
||||||
|
} else if client.tlsBridge == nil {
|
||||||
|
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
||||||
|
}
|
||||||
|
return gnet.Close
|
||||||
|
}
|
||||||
|
|
||||||
// If not authenticated, expect auth command
|
// If not authenticated, expect auth command
|
||||||
if !client.authenticated {
|
if !client.authenticated {
|
||||||
client.buffer.Write(data)
|
client.buffer.Write(data)
|
||||||
@ -551,7 +724,13 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
|||||||
// Parse AUTH command: AUTH <method> <credentials>
|
// Parse AUTH command: AUTH <method> <credentials>
|
||||||
parts := strings.SplitN(string(line), " ", 3)
|
parts := strings.SplitN(string(line), " ", 3)
|
||||||
if len(parts) != 3 || parts[0] != "AUTH" {
|
if len(parts) != 3 || parts[0] != "AUTH" {
|
||||||
c.AsyncWrite([]byte("ERROR: Invalid auth format\n"), nil)
|
// Send error through TLS if enabled
|
||||||
|
errMsg := []byte("AUTH FAILED\n")
|
||||||
|
if client.tlsBridge != nil {
|
||||||
|
client.tlsBridge.Write(errMsg)
|
||||||
|
} else {
|
||||||
|
c.AsyncWrite(errMsg, nil)
|
||||||
|
}
|
||||||
return gnet.None
|
return gnet.None
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -563,7 +742,13 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
|||||||
"remote_addr", c.RemoteAddr().String(),
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
"method", parts[1],
|
"method", parts[1],
|
||||||
"error", err)
|
"error", err)
|
||||||
c.AsyncWrite([]byte(fmt.Sprintf("AUTH FAILED: %v\n", err)), nil)
|
// Send error through TLS if enabled
|
||||||
|
errMsg := []byte("AUTH FAILED\n")
|
||||||
|
if client.tlsBridge != nil {
|
||||||
|
client.tlsBridge.Write(errMsg)
|
||||||
|
} else {
|
||||||
|
c.AsyncWrite(errMsg, nil)
|
||||||
|
}
|
||||||
return gnet.Close
|
return gnet.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -575,11 +760,19 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
s.sink.logger.Info("msg", "TCP client authenticated",
|
s.sink.logger.Info("msg", "TCP client authenticated",
|
||||||
|
"component", "tcp_sink",
|
||||||
"remote_addr", c.RemoteAddr().String(),
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
"username", session.Username,
|
"username", session.Username,
|
||||||
"method", session.Method)
|
"method", session.Method,
|
||||||
|
"tls", client.tlsBridge != nil)
|
||||||
|
|
||||||
c.AsyncWrite([]byte("AUTH OK\n"), nil)
|
// Send success through TLS if enabled
|
||||||
|
successMsg := []byte("AUTH OK\n")
|
||||||
|
if client.tlsBridge != nil {
|
||||||
|
client.tlsBridge.Write(successMsg)
|
||||||
|
} else {
|
||||||
|
c.AsyncWrite(successMsg, nil)
|
||||||
|
}
|
||||||
|
|
||||||
// Clear buffer after auth
|
// Clear buffer after auth
|
||||||
client.buffer.Reset()
|
client.buffer.Reset()
|
||||||
@ -610,7 +803,7 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
|||||||
|
|
||||||
// Initialize TLS manager if SSL is configured
|
// Initialize TLS manager if SSL is configured
|
||||||
if t.config.SSL != nil && t.config.SSL.Enabled {
|
if t.config.SSL != nil && t.config.SSL.Enabled {
|
||||||
tlsManager, err := tls.New(t.config.SSL, t.logger)
|
tlsManager, err := tls.NewManager(t.config.SSL, t.logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.logger.Error("msg", "Failed to create TLS manager",
|
t.logger.Error("msg", "Failed to create TLS manager",
|
||||||
"component", "tcp_sink",
|
"component", "tcp_sink",
|
||||||
@ -624,5 +817,6 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
|||||||
t.logger.Info("msg", "Authentication configured for TCP sink",
|
t.logger.Info("msg", "Authentication configured for TCP sink",
|
||||||
"component", "tcp_sink",
|
"component", "tcp_sink",
|
||||||
"auth_type", authCfg.Type,
|
"auth_type", authCfg.Type,
|
||||||
"tls_enabled", t.tlsManager != nil)
|
"tls_enabled", t.tlsManager != nil,
|
||||||
|
"tls_bridge", t.tlsManager != nil)
|
||||||
}
|
}
|
||||||
@ -3,6 +3,7 @@ package sink
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@ -10,8 +11,10 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"logwisp/src/internal/config"
|
||||||
"logwisp/src/internal/core"
|
"logwisp/src/internal/core"
|
||||||
"logwisp/src/internal/format"
|
"logwisp/src/internal/format"
|
||||||
|
tlspkg "logwisp/src/internal/tls"
|
||||||
|
|
||||||
"github.com/lixenwraith/log"
|
"github.com/lixenwraith/log"
|
||||||
)
|
)
|
||||||
@ -28,6 +31,10 @@ type TCPClientSink struct {
|
|||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
formatter format.Formatter
|
formatter format.Formatter
|
||||||
|
|
||||||
|
// TLS support
|
||||||
|
tlsManager *tlspkg.Manager
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
|
||||||
// Reconnection state
|
// Reconnection state
|
||||||
reconnecting atomic.Bool
|
reconnecting atomic.Bool
|
||||||
lastConnectErr error
|
lastConnectErr error
|
||||||
@ -53,6 +60,9 @@ type TCPClientConfig struct {
|
|||||||
ReconnectDelay time.Duration
|
ReconnectDelay time.Duration
|
||||||
MaxReconnectDelay time.Duration
|
MaxReconnectDelay time.Duration
|
||||||
ReconnectBackoff float64
|
ReconnectBackoff float64
|
||||||
|
|
||||||
|
// TLS config
|
||||||
|
SSL *config.SSLConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPClientSink creates a new TCP client sink
|
// NewTCPClientSink creates a new TCP client sink
|
||||||
@ -103,6 +113,25 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
|
|||||||
cfg.ReconnectBackoff = backoff
|
cfg.ReconnectBackoff = backoff
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract SSL config
|
||||||
|
if ssl, ok := options["ssl"].(map[string]any); ok {
|
||||||
|
cfg.SSL = &config.SSLConfig{}
|
||||||
|
cfg.SSL.Enabled, _ = ssl["enabled"].(bool)
|
||||||
|
if certFile, ok := ssl["cert_file"].(string); ok {
|
||||||
|
cfg.SSL.CertFile = certFile
|
||||||
|
}
|
||||||
|
if keyFile, ok := ssl["key_file"].(string); ok {
|
||||||
|
cfg.SSL.KeyFile = keyFile
|
||||||
|
}
|
||||||
|
cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool)
|
||||||
|
if caFile, ok := ssl["client_ca_file"].(string); ok {
|
||||||
|
cfg.SSL.ClientCAFile = caFile
|
||||||
|
}
|
||||||
|
if insecure, ok := ssl["insecure_skip_verify"].(bool); ok {
|
||||||
|
cfg.SSL.InsecureSkipVerify = insecure
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
t := &TCPClientSink{
|
t := &TCPClientSink{
|
||||||
input: make(chan core.LogEntry, cfg.BufferSize),
|
input: make(chan core.LogEntry, cfg.BufferSize),
|
||||||
config: cfg,
|
config: cfg,
|
||||||
@ -114,6 +143,34 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
|
|||||||
t.lastProcessed.Store(time.Time{})
|
t.lastProcessed.Store(time.Time{})
|
||||||
t.connectionUptime.Store(time.Duration(0))
|
t.connectionUptime.Store(time.Duration(0))
|
||||||
|
|
||||||
|
// Initialize TLS manager if SSL is configured
|
||||||
|
if cfg.SSL != nil && cfg.SSL.Enabled {
|
||||||
|
tlsManager, err := tlspkg.NewManager(cfg.SSL, logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
|
||||||
|
}
|
||||||
|
t.tlsManager = tlsManager
|
||||||
|
|
||||||
|
// Get client TLS config
|
||||||
|
t.tlsConfig = tlsManager.GetTCPConfig()
|
||||||
|
|
||||||
|
// ADDED: Client-specific TLS config adjustments
|
||||||
|
t.tlsConfig.InsecureSkipVerify = cfg.SSL.InsecureSkipVerify
|
||||||
|
|
||||||
|
// Extract server name from address for SNI
|
||||||
|
host, _, err := net.SplitHostPort(cfg.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse address for SNI: %w", err)
|
||||||
|
}
|
||||||
|
t.tlsConfig.ServerName = host
|
||||||
|
|
||||||
|
logger.Info("msg", "TLS enabled for TCP client",
|
||||||
|
"component", "tcp_client_sink",
|
||||||
|
"address", cfg.Address,
|
||||||
|
"server_name", host,
|
||||||
|
"insecure", cfg.SSL.InsecureSkipVerify)
|
||||||
|
}
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -280,6 +337,35 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
|
|||||||
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
|
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wrap with TLS if configured
|
||||||
|
if t.tlsConfig != nil {
|
||||||
|
t.logger.Debug("msg", "Initiating TLS handshake",
|
||||||
|
"component", "tcp_client_sink",
|
||||||
|
"address", t.config.Address)
|
||||||
|
|
||||||
|
tlsConn := tls.Client(conn, t.tlsConfig)
|
||||||
|
|
||||||
|
// Perform handshake with timeout
|
||||||
|
handshakeCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := tlsConn.HandshakeContext(handshakeCtx); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log connection details
|
||||||
|
state := tlsConn.ConnectionState()
|
||||||
|
t.logger.Info("msg", "TLS connection established",
|
||||||
|
"component", "tcp_client_sink",
|
||||||
|
"address", t.config.Address,
|
||||||
|
"tls_version", tlsVersionString(state.Version),
|
||||||
|
"cipher_suite", tls.CipherSuiteName(state.CipherSuite),
|
||||||
|
"server_name", state.ServerName)
|
||||||
|
|
||||||
|
return tlsConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,7 +381,7 @@ func (t *TCPClientSink) monitorConnection(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
// Set read deadline
|
// Set read deadline
|
||||||
// TODO: Add t.config.ReadTimeout instead of static value
|
// TODO: Add t.config.ReadTimeout and after addition use it instead of static value
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
|
if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
|
||||||
t.logger.Debug("msg", "Failed to set read deadline", "error", err)
|
t.logger.Debug("msg", "Failed to set read deadline", "error", err)
|
||||||
return
|
return
|
||||||
@ -378,4 +464,20 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tlsVersionString returns human-readable TLS version
|
||||||
|
func tlsVersionString(version uint16) string {
|
||||||
|
switch version {
|
||||||
|
case tls.VersionTLS10:
|
||||||
|
return "TLS1.0"
|
||||||
|
case tls.VersionTLS11:
|
||||||
|
return "TLS1.1"
|
||||||
|
case tls.VersionTLS12:
|
||||||
|
return "TLS1.2"
|
||||||
|
case tls.VersionTLS13:
|
||||||
|
return "TLS1.3"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("0x%04x", version)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@ -4,6 +4,8 @@ package source
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"logwisp/src/internal/tls"
|
||||||
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@ -29,6 +31,10 @@ type HTTPSource struct {
|
|||||||
netLimiter *limit.NetLimiter
|
netLimiter *limit.NetLimiter
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
|
|
||||||
|
// CHANGED: Add TLS support
|
||||||
|
tlsManager *tls.Manager
|
||||||
|
sslConfig *config.SSLConfig
|
||||||
|
|
||||||
// Statistics
|
// Statistics
|
||||||
totalEntries atomic.Uint64
|
totalEntries atomic.Uint64
|
||||||
droppedEntries atomic.Uint64
|
droppedEntries atomic.Uint64
|
||||||
@ -94,6 +100,28 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract SSL config after existing options
|
||||||
|
if ssl, ok := options["ssl"].(map[string]any); ok {
|
||||||
|
h.sslConfig = &config.SSLConfig{}
|
||||||
|
h.sslConfig.Enabled, _ = ssl["enabled"].(bool)
|
||||||
|
if certFile, ok := ssl["cert_file"].(string); ok {
|
||||||
|
h.sslConfig.CertFile = certFile
|
||||||
|
}
|
||||||
|
if keyFile, ok := ssl["key_file"].(string); ok {
|
||||||
|
h.sslConfig.KeyFile = keyFile
|
||||||
|
}
|
||||||
|
// TODO: extract other SSL options similar to tcp_client_sink
|
||||||
|
|
||||||
|
// Create TLS manager
|
||||||
|
if h.sslConfig.Enabled {
|
||||||
|
tlsManager, err := tls.NewManager(h.sslConfig, logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
|
||||||
|
}
|
||||||
|
h.tlsManager = tlsManager
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,9 +151,19 @@ func (h *HTTPSource) Start() error {
|
|||||||
h.logger.Info("msg", "HTTP source server starting",
|
h.logger.Info("msg", "HTTP source server starting",
|
||||||
"component", "http_source",
|
"component", "http_source",
|
||||||
"port", h.port,
|
"port", h.port,
|
||||||
"ingest_path", h.ingestPath)
|
"ingest_path", h.ingestPath,
|
||||||
|
"tls_enabled", h.tlsManager != nil)
|
||||||
|
|
||||||
if err := h.server.ListenAndServe(addr); err != nil {
|
var err error
|
||||||
|
// Check for TLS manager and start the appropriate server type
|
||||||
|
if h.tlsManager != nil {
|
||||||
|
h.server.TLSConfig = h.tlsManager.GetHTTPConfig()
|
||||||
|
err = h.server.ListenAndServeTLS(addr, h.sslConfig.CertFile, h.sslConfig.KeyFile)
|
||||||
|
} else {
|
||||||
|
err = h.server.ListenAndServe(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
h.logger.Error("msg", "HTTP source server failed",
|
h.logger.Error("msg", "HTTP source server failed",
|
||||||
"component", "http_source",
|
"component", "http_source",
|
||||||
"port", h.port,
|
"port", h.port,
|
||||||
@ -134,7 +172,7 @@ func (h *HTTPSource) Start() error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Give server time to start
|
// Give server time to start
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond) // TODO: standardize and better manage timers
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -202,8 +240,21 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check net limit
|
// Extract and validate IP
|
||||||
remoteAddr := ctx.RemoteAddr().String()
|
remoteAddr := ctx.RemoteAddr().String()
|
||||||
|
ipStr, _, err := net.SplitHostPort(remoteAddr)
|
||||||
|
if err == nil {
|
||||||
|
if ip := net.ParseIP(ipStr); ip != nil && ip.To4() == nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
|
ctx.SetContentType("application/json")
|
||||||
|
json.NewEncoder(ctx).Encode(map[string]string{
|
||||||
|
"error": "IPv4-only (IPv6 not supported)",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check net limit
|
||||||
if h.netLimiter != nil {
|
if h.netLimiter != nil {
|
||||||
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
|
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
|
||||||
ctx.SetStatusCode(int(statusCode))
|
ctx.SetStatusCode(int(statusCode))
|
||||||
@ -280,7 +331,7 @@ func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) {
|
|||||||
// Try to parse as JSON array
|
// Try to parse as JSON array
|
||||||
var array []core.LogEntry
|
var array []core.LogEntry
|
||||||
if err := json.Unmarshal(body, &array); err == nil {
|
if err := json.Unmarshal(body, &array); err == nil {
|
||||||
// TODO: Placeholder; For array, divide total size by entry count as approximation
|
// NOTE: Placeholder; For array, divide total size by entry count as approximation
|
||||||
approxSizePerEntry := int64(len(body) / len(array))
|
approxSizePerEntry := int64(len(body) / len(array))
|
||||||
for i, entry := range array {
|
for i, entry := range array {
|
||||||
if entry.Message == "" {
|
if entry.Message == "" {
|
||||||
@ -292,7 +343,7 @@ func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) {
|
|||||||
if entry.Source == "" {
|
if entry.Source == "" {
|
||||||
array[i].Source = "http"
|
array[i].Source = "http"
|
||||||
}
|
}
|
||||||
// TODO: Placeholder
|
// NOTE: Placeholder
|
||||||
array[i].RawSize = approxSizePerEntry
|
array[i].RawSize = approxSizePerEntry
|
||||||
}
|
}
|
||||||
return array, nil
|
return array, nil
|
||||||
|
|||||||
@ -5,21 +5,31 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"logwisp/src/internal/auth"
|
||||||
"logwisp/src/internal/config"
|
"logwisp/src/internal/config"
|
||||||
"logwisp/src/internal/core"
|
"logwisp/src/internal/core"
|
||||||
"logwisp/src/internal/limit"
|
"logwisp/src/internal/limit"
|
||||||
|
"logwisp/src/internal/tls"
|
||||||
|
|
||||||
"github.com/lixenwraith/log"
|
"github.com/lixenwraith/log"
|
||||||
"github.com/lixenwraith/log/compat"
|
"github.com/lixenwraith/log/compat"
|
||||||
"github.com/panjf2000/gnet/v2"
|
"github.com/panjf2000/gnet/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxClientBufferSize = 10 * 1024 * 1024 // 10MB max per client
|
||||||
|
maxLineLength = 1 * 1024 * 1024 // 1MB max per log line
|
||||||
|
maxEncryptedDataPerRead = 1 * 1024 * 1024 // 1MB max encrypted data per read
|
||||||
|
maxCumulativeEncrypted = 20 * 1024 * 1024 // 20MB total encrypted before processing
|
||||||
|
)
|
||||||
|
|
||||||
// TCPSource receives log entries via TCP connections
|
// TCPSource receives log entries via TCP connections
|
||||||
type TCPSource struct {
|
type TCPSource struct {
|
||||||
port int64
|
port int64
|
||||||
@ -32,6 +42,8 @@ type TCPSource struct {
|
|||||||
engineMu sync.Mutex
|
engineMu sync.Mutex
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
netLimiter *limit.NetLimiter
|
netLimiter *limit.NetLimiter
|
||||||
|
tlsManager *tls.Manager
|
||||||
|
sslConfig *config.SSLConfig
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
|
|
||||||
// Statistics
|
// Statistics
|
||||||
@ -91,6 +103,32 @@ func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract SSL config and initialize TLS manager
|
||||||
|
if ssl, ok := options["ssl"].(map[string]any); ok {
|
||||||
|
t.sslConfig = &config.SSLConfig{}
|
||||||
|
t.sslConfig.Enabled, _ = ssl["enabled"].(bool)
|
||||||
|
if certFile, ok := ssl["cert_file"].(string); ok {
|
||||||
|
t.sslConfig.CertFile = certFile
|
||||||
|
}
|
||||||
|
if keyFile, ok := ssl["key_file"].(string); ok {
|
||||||
|
t.sslConfig.KeyFile = keyFile
|
||||||
|
}
|
||||||
|
t.sslConfig.ClientAuth, _ = ssl["client_auth"].(bool)
|
||||||
|
if caFile, ok := ssl["client_ca_file"].(string); ok {
|
||||||
|
t.sslConfig.ClientCAFile = caFile
|
||||||
|
}
|
||||||
|
t.sslConfig.VerifyClientCert, _ = ssl["verify_client_cert"].(bool)
|
||||||
|
|
||||||
|
// Create TLS manager if enabled
|
||||||
|
if t.sslConfig.Enabled {
|
||||||
|
tlsManager, err := tls.NewManager(t.sslConfig, logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
|
||||||
|
}
|
||||||
|
t.tlsManager = tlsManager
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,7 +159,8 @@ func (t *TCPSource) Start() error {
|
|||||||
defer t.wg.Done()
|
defer t.wg.Done()
|
||||||
t.logger.Info("msg", "TCP source server starting",
|
t.logger.Info("msg", "TCP source server starting",
|
||||||
"component", "tcp_source",
|
"component", "tcp_source",
|
||||||
"port", t.port)
|
"port", t.port,
|
||||||
|
"tls_enabled", t.tlsManager != nil)
|
||||||
|
|
||||||
err := gnet.Run(t.server, addr,
|
err := gnet.Run(t.server, addr,
|
||||||
gnet.WithLogger(gnetLogger),
|
gnet.WithLogger(gnetLogger),
|
||||||
@ -233,8 +272,14 @@ func (t *TCPSource) publish(entry core.LogEntry) bool {
|
|||||||
|
|
||||||
// tcpClient represents a connected TCP client
|
// tcpClient represents a connected TCP client
|
||||||
type tcpClient struct {
|
type tcpClient struct {
|
||||||
conn gnet.Conn
|
conn gnet.Conn
|
||||||
buffer bytes.Buffer
|
buffer bytes.Buffer
|
||||||
|
authenticated bool
|
||||||
|
session *auth.Session
|
||||||
|
authTimeout time.Time
|
||||||
|
tlsBridge *tls.GNetTLSConn
|
||||||
|
maxBufferSeen int
|
||||||
|
cumulativeEncrypted int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// tcpSourceServer handles gnet events
|
// tcpSourceServer handles gnet events
|
||||||
@ -265,8 +310,7 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
|||||||
|
|
||||||
// Check net limit
|
// Check net limit
|
||||||
if s.source.netLimiter != nil {
|
if s.source.netLimiter != nil {
|
||||||
remoteStr := c.RemoteAddr().String()
|
tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr)
|
||||||
tcpAddr, err := net.ResolveTCPAddr("tcp", remoteStr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.source.logger.Warn("msg", "Failed to parse TCP address",
|
s.source.logger.Warn("msg", "Failed to parse TCP address",
|
||||||
"component", "tcp_source",
|
"component", "tcp_source",
|
||||||
@ -283,7 +327,21 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Track connection
|
// Track connection
|
||||||
s.source.netLimiter.AddConnection(remoteStr)
|
s.source.netLimiter.AddConnection(remoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create client state
|
||||||
|
client := &tcpClient{conn: c}
|
||||||
|
|
||||||
|
// Initialize TLS bridge if enabled
|
||||||
|
if s.source.tlsManager != nil {
|
||||||
|
tlsConfig := s.source.tlsManager.GetTCPConfig()
|
||||||
|
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
|
||||||
|
client.tlsBridge.Handshake() // Start async handshake
|
||||||
|
|
||||||
|
s.source.logger.Debug("msg", "TLS handshake initiated",
|
||||||
|
"component", "tcp_source",
|
||||||
|
"remote_addr", remoteAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create client state
|
// Create client state
|
||||||
@ -295,7 +353,8 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
|||||||
s.source.logger.Debug("msg", "TCP connection opened",
|
s.source.logger.Debug("msg", "TCP connection opened",
|
||||||
"component", "tcp_source",
|
"component", "tcp_source",
|
||||||
"remote_addr", remoteAddr,
|
"remote_addr", remoteAddr,
|
||||||
"active_connections", newCount)
|
"active_connections", newCount,
|
||||||
|
"tls_enabled", s.source.tlsManager != nil)
|
||||||
|
|
||||||
return nil, gnet.None
|
return nil, gnet.None
|
||||||
}
|
}
|
||||||
@ -305,9 +364,18 @@ func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
|||||||
|
|
||||||
// Remove client state
|
// Remove client state
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
|
client := s.clients[c]
|
||||||
delete(s.clients, c)
|
delete(s.clients, c)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// Clean up TLS bridge if present
|
||||||
|
if client != nil && client.tlsBridge != nil {
|
||||||
|
client.tlsBridge.Close()
|
||||||
|
s.source.logger.Debug("msg", "TLS connection closed",
|
||||||
|
"component", "tcp_source",
|
||||||
|
"remote_addr", remoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
// Remove connection tracking
|
// Remove connection tracking
|
||||||
if s.source.netLimiter != nil {
|
if s.source.netLimiter != nil {
|
||||||
s.source.netLimiter.RemoveConnection(remoteAddr)
|
s.source.netLimiter.RemoveConnection(remoteAddr)
|
||||||
@ -340,9 +408,113 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
|||||||
return gnet.Close
|
return gnet.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check encrypted data size BEFORE processing through TLS
|
||||||
|
if len(data) > maxEncryptedDataPerRead {
|
||||||
|
s.source.logger.Warn("msg", "Encrypted data per read limit exceeded",
|
||||||
|
"component", "tcp_source",
|
||||||
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
|
"data_size", len(data),
|
||||||
|
"limit", maxEncryptedDataPerRead)
|
||||||
|
s.source.invalidEntries.Add(1)
|
||||||
|
return gnet.Close
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track cumulative encrypted data to prevent slow accumulation
|
||||||
|
client.cumulativeEncrypted += int64(len(data))
|
||||||
|
if client.cumulativeEncrypted > maxCumulativeEncrypted {
|
||||||
|
s.source.logger.Warn("msg", "Cumulative encrypted data limit exceeded",
|
||||||
|
"component", "tcp_source",
|
||||||
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
|
"total_encrypted", client.cumulativeEncrypted,
|
||||||
|
"limit", maxCumulativeEncrypted)
|
||||||
|
s.source.invalidEntries.Add(1)
|
||||||
|
return gnet.Close
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process through TLS bridge if present
|
||||||
|
if client.tlsBridge != nil {
|
||||||
|
// Feed encrypted data into TLS engine
|
||||||
|
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
|
||||||
|
if errors.Is(err, tls.ErrTLSBackpressure) {
|
||||||
|
s.source.logger.Warn("msg", "TLS backpressure, closing slow client",
|
||||||
|
"component", "tcp_source",
|
||||||
|
"remote_addr", c.RemoteAddr().String())
|
||||||
|
} else {
|
||||||
|
s.source.logger.Error("msg", "TLS processing error",
|
||||||
|
"component", "tcp_source",
|
||||||
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
|
"error", err)
|
||||||
|
}
|
||||||
|
return gnet.Close
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if handshake is complete
|
||||||
|
if !client.tlsBridge.IsHandshakeDone() {
|
||||||
|
// Still handshaking, wait for more data
|
||||||
|
return gnet.None
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check handshake result
|
||||||
|
_, hsErr := client.tlsBridge.HandshakeComplete()
|
||||||
|
if hsErr != nil {
|
||||||
|
s.source.logger.Error("msg", "TLS handshake failed",
|
||||||
|
"component", "tcp_source",
|
||||||
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
|
"error", hsErr)
|
||||||
|
return gnet.Close
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read decrypted plaintext
|
||||||
|
data = client.tlsBridge.Read()
|
||||||
|
if data == nil || len(data) == 0 {
|
||||||
|
// No plaintext available yet
|
||||||
|
return gnet.None
|
||||||
|
}
|
||||||
|
// Reset cumulative counter after successful decryption and processing
|
||||||
|
client.cumulativeEncrypted = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check buffer size before appending
|
||||||
|
if client.buffer.Len()+len(data) > maxClientBufferSize {
|
||||||
|
s.source.logger.Warn("msg", "Client buffer limit exceeded",
|
||||||
|
"component", "tcp_source",
|
||||||
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
|
"buffer_size", client.buffer.Len(),
|
||||||
|
"incoming_size", len(data))
|
||||||
|
s.source.invalidEntries.Add(1)
|
||||||
|
return gnet.Close
|
||||||
|
}
|
||||||
|
|
||||||
// Append to client buffer
|
// Append to client buffer
|
||||||
client.buffer.Write(data)
|
client.buffer.Write(data)
|
||||||
|
|
||||||
|
// Track high buffer
|
||||||
|
if client.buffer.Len() > client.maxBufferSeen {
|
||||||
|
client.maxBufferSeen = client.buffer.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for suspiciously long lines before attempting to read
|
||||||
|
if client.buffer.Len() > maxLineLength {
|
||||||
|
// Scan for newline in current buffer
|
||||||
|
bufBytes := client.buffer.Bytes()
|
||||||
|
hasNewline := false
|
||||||
|
for _, b := range bufBytes {
|
||||||
|
if b == '\n' {
|
||||||
|
hasNewline = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasNewline {
|
||||||
|
s.source.logger.Warn("msg", "Line too long without newline",
|
||||||
|
"component", "tcp_source",
|
||||||
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
|
"buffer_size", client.buffer.Len())
|
||||||
|
s.source.invalidEntries.Add(1)
|
||||||
|
return gnet.Close
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Process complete lines
|
// Process complete lines
|
||||||
for {
|
for {
|
||||||
line, err := client.buffer.ReadBytes('\n')
|
line, err := client.buffer.ReadBytes('\n')
|
||||||
|
|||||||
341
src/internal/tls/gnet_bridge.go
Normal file
341
src/internal/tls/gnet_bridge.go
Normal file
@ -0,0 +1,341 @@
|
|||||||
|
// FILE: src/internal/tls/gnet_bridge.go
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/panjf2000/gnet/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrTLSBackpressure = errors.New("TLS processing backpressure")
|
||||||
|
ErrConnectionClosed = errors.New("connection closed")
|
||||||
|
ErrPlaintextBufferExceeded = errors.New("plaintext buffer size exceeded")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Maximum plaintext buffer size to prevent memory exhaustion
|
||||||
|
const maxPlaintextBufferSize = 32 * 1024 * 1024 // 32MB
|
||||||
|
|
||||||
|
// GNetTLSConn bridges gnet.Conn with crypto/tls via io.Pipe
|
||||||
|
type GNetTLSConn struct {
|
||||||
|
gnetConn gnet.Conn
|
||||||
|
tlsConn *tls.Conn
|
||||||
|
config *tls.Config
|
||||||
|
|
||||||
|
// Buffered channels for non-blocking operation
|
||||||
|
incomingCipher chan []byte // Network → TLS (encrypted)
|
||||||
|
outgoingCipher chan []byte // TLS → Network (encrypted)
|
||||||
|
|
||||||
|
// Handshake state
|
||||||
|
handshakeOnce sync.Once
|
||||||
|
handshakeDone chan struct{}
|
||||||
|
handshakeErr error
|
||||||
|
|
||||||
|
// Decrypted data buffer
|
||||||
|
plainBuf []byte
|
||||||
|
plainMu sync.Mutex
|
||||||
|
|
||||||
|
// Lifecycle
|
||||||
|
closed atomic.Bool
|
||||||
|
closeOnce sync.Once
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Error tracking
|
||||||
|
lastErr atomic.Value // error
|
||||||
|
logger interface{ Warn(args ...any) } // Minimal logger interface
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServerConn creates a server-side TLS bridge
|
||||||
|
func NewServerConn(gnetConn gnet.Conn, config *tls.Config) *GNetTLSConn {
|
||||||
|
tc := &GNetTLSConn{
|
||||||
|
gnetConn: gnetConn,
|
||||||
|
config: config,
|
||||||
|
handshakeDone: make(chan struct{}),
|
||||||
|
// Buffered channels sized for throughput without blocking
|
||||||
|
incomingCipher: make(chan []byte, 128), // 128 packets buffer
|
||||||
|
outgoingCipher: make(chan []byte, 128),
|
||||||
|
plainBuf: make([]byte, 0, 65536), // 64KB initial capacity
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TLS conn with channel-based transport
|
||||||
|
rawConn := &channelConn{
|
||||||
|
incoming: tc.incomingCipher,
|
||||||
|
outgoing: tc.outgoingCipher,
|
||||||
|
localAddr: gnetConn.LocalAddr(),
|
||||||
|
remoteAddr: gnetConn.RemoteAddr(),
|
||||||
|
tc: tc,
|
||||||
|
}
|
||||||
|
tc.tlsConn = tls.Server(rawConn, config)
|
||||||
|
|
||||||
|
// Start pump goroutines
|
||||||
|
tc.wg.Add(2)
|
||||||
|
go tc.pumpCipherToNetwork()
|
||||||
|
go tc.pumpPlaintextFromTLS()
|
||||||
|
|
||||||
|
return tc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientConn creates a client-side TLS bridge (similar changes)
|
||||||
|
func NewClientConn(gnetConn gnet.Conn, config *tls.Config, serverName string) *GNetTLSConn {
|
||||||
|
tc := &GNetTLSConn{
|
||||||
|
gnetConn: gnetConn,
|
||||||
|
config: config,
|
||||||
|
handshakeDone: make(chan struct{}),
|
||||||
|
incomingCipher: make(chan []byte, 128),
|
||||||
|
outgoingCipher: make(chan []byte, 128),
|
||||||
|
plainBuf: make([]byte, 0, 65536),
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ServerName == "" {
|
||||||
|
config = config.Clone()
|
||||||
|
config.ServerName = serverName
|
||||||
|
}
|
||||||
|
|
||||||
|
rawConn := &channelConn{
|
||||||
|
incoming: tc.incomingCipher,
|
||||||
|
outgoing: tc.outgoingCipher,
|
||||||
|
localAddr: gnetConn.LocalAddr(),
|
||||||
|
remoteAddr: gnetConn.RemoteAddr(),
|
||||||
|
tc: tc,
|
||||||
|
}
|
||||||
|
tc.tlsConn = tls.Client(rawConn, config)
|
||||||
|
|
||||||
|
tc.wg.Add(2)
|
||||||
|
go tc.pumpCipherToNetwork()
|
||||||
|
go tc.pumpPlaintextFromTLS()
|
||||||
|
|
||||||
|
return tc
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessIncoming feeds encrypted data from network into TLS engine (non-blocking)
|
||||||
|
func (tc *GNetTLSConn) ProcessIncoming(encryptedData []byte) error {
|
||||||
|
if tc.closed.Load() {
|
||||||
|
return ErrConnectionClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-blocking send with backpressure detection
|
||||||
|
select {
|
||||||
|
case tc.incomingCipher <- encryptedData:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
// Channel full - TLS processing can't keep up
|
||||||
|
// Drop connection under backpressure vs blocking event loop
|
||||||
|
if tc.logger != nil {
|
||||||
|
tc.logger.Warn("msg", "TLS backpressure, dropping data",
|
||||||
|
"remote_addr", tc.gnetConn.RemoteAddr())
|
||||||
|
}
|
||||||
|
return ErrTLSBackpressure
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pumpCipherToNetwork sends TLS-encrypted data to network
|
||||||
|
func (tc *GNetTLSConn) pumpCipherToNetwork() {
|
||||||
|
defer tc.wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case data, ok := <-tc.outgoingCipher:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Send to network
|
||||||
|
if err := tc.gnetConn.AsyncWrite(data, nil); err != nil {
|
||||||
|
tc.lastErr.Store(err)
|
||||||
|
tc.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-time.After(30 * time.Second):
|
||||||
|
// Keepalive/timeout check
|
||||||
|
if tc.closed.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pumpPlaintextFromTLS reads decrypted data from TLS
|
||||||
|
func (tc *GNetTLSConn) pumpPlaintextFromTLS() {
|
||||||
|
defer tc.wg.Done()
|
||||||
|
buf := make([]byte, 32768) // 32KB read buffer
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := tc.tlsConn.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
tc.plainMu.Lock()
|
||||||
|
// Check buffer size limit before appending to prevent memory exhaustion
|
||||||
|
if len(tc.plainBuf)+n > maxPlaintextBufferSize {
|
||||||
|
tc.plainMu.Unlock()
|
||||||
|
// Log warning about buffer limit
|
||||||
|
if tc.logger != nil {
|
||||||
|
tc.logger.Warn("msg", "Plaintext buffer limit exceeded, closing connection",
|
||||||
|
"remote_addr", tc.gnetConn.RemoteAddr(),
|
||||||
|
"buffer_size", len(tc.plainBuf),
|
||||||
|
"incoming_size", n,
|
||||||
|
"limit", maxPlaintextBufferSize)
|
||||||
|
}
|
||||||
|
// Store error and close connection
|
||||||
|
tc.lastErr.Store(ErrPlaintextBufferExceeded)
|
||||||
|
tc.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tc.plainBuf = append(tc.plainBuf, buf[:n]...)
|
||||||
|
tc.plainMu.Unlock()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
tc.lastErr.Store(err)
|
||||||
|
}
|
||||||
|
tc.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read returns available decrypted plaintext (non-blocking)
|
||||||
|
func (tc *GNetTLSConn) Read() []byte {
|
||||||
|
tc.plainMu.Lock()
|
||||||
|
defer tc.plainMu.Unlock()
|
||||||
|
|
||||||
|
if len(tc.plainBuf) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Atomic buffer swap under mutex protection to prevent race condition
|
||||||
|
data := tc.plainBuf
|
||||||
|
tc.plainBuf = make([]byte, 0, cap(tc.plainBuf))
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write encrypts plaintext and queues for network transmission
|
||||||
|
func (tc *GNetTLSConn) Write(plaintext []byte) (int, error) {
|
||||||
|
if tc.closed.Load() {
|
||||||
|
return 0, ErrConnectionClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tc.IsHandshakeDone() {
|
||||||
|
return 0, errors.New("handshake not complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
return tc.tlsConn.Write(plaintext)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handshake initiates TLS handshake asynchronously
|
||||||
|
func (tc *GNetTLSConn) Handshake() {
|
||||||
|
tc.handshakeOnce.Do(func() {
|
||||||
|
go func() {
|
||||||
|
tc.handshakeErr = tc.tlsConn.Handshake()
|
||||||
|
close(tc.handshakeDone)
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsHandshakeDone checks if handshake is complete
|
||||||
|
func (tc *GNetTLSConn) IsHandshakeDone() bool {
|
||||||
|
select {
|
||||||
|
case <-tc.handshakeDone:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandshakeComplete waits for handshake completion
|
||||||
|
func (tc *GNetTLSConn) HandshakeComplete() (<-chan struct{}, error) {
|
||||||
|
<-tc.handshakeDone
|
||||||
|
return tc.handshakeDone, tc.handshakeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down the bridge
|
||||||
|
func (tc *GNetTLSConn) Close() error {
|
||||||
|
tc.closeOnce.Do(func() {
|
||||||
|
tc.closed.Store(true)
|
||||||
|
|
||||||
|
// Close TLS connection
|
||||||
|
tc.tlsConn.Close()
|
||||||
|
|
||||||
|
// Close channels to stop pumps
|
||||||
|
close(tc.incomingCipher)
|
||||||
|
close(tc.outgoingCipher)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Wait for pumps to finish
|
||||||
|
tc.wg.Wait()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnectionState returns TLS connection state
|
||||||
|
func (tc *GNetTLSConn) GetConnectionState() tls.ConnectionState {
|
||||||
|
return tc.tlsConn.ConnectionState()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetError returns last error
|
||||||
|
func (tc *GNetTLSConn) GetError() error {
|
||||||
|
if err, ok := tc.lastErr.Load().(error); ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// channelConn implements net.Conn over channels
|
||||||
|
type channelConn struct {
|
||||||
|
incoming <-chan []byte
|
||||||
|
outgoing chan<- []byte
|
||||||
|
localAddr net.Addr
|
||||||
|
remoteAddr net.Addr
|
||||||
|
tc *GNetTLSConn
|
||||||
|
readBuf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *channelConn) Read(b []byte) (int, error) {
|
||||||
|
// Use buffered read for efficiency
|
||||||
|
if len(c.readBuf) > 0 {
|
||||||
|
n := copy(b, c.readBuf)
|
||||||
|
c.readBuf = c.readBuf[n:]
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for new data
|
||||||
|
select {
|
||||||
|
case data, ok := <-c.incoming:
|
||||||
|
if !ok {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n := copy(b, data)
|
||||||
|
if n < len(data) {
|
||||||
|
c.readBuf = data[n:] // Buffer remainder
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
case <-time.After(30 * time.Second):
|
||||||
|
return 0, errors.New("read timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *channelConn) Write(b []byte) (int, error) {
|
||||||
|
if c.tc.closed.Load() {
|
||||||
|
return 0, ErrConnectionClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a copy since TLS may hold reference
|
||||||
|
data := make([]byte, len(b))
|
||||||
|
copy(data, b)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.outgoing <- data:
|
||||||
|
return len(b), nil
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
return 0, errors.New("write timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *channelConn) Close() error { return nil }
|
||||||
|
func (c *channelConn) LocalAddr() net.Addr { return c.localAddr }
|
||||||
|
func (c *channelConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||||
|
func (c *channelConn) SetDeadline(t time.Time) error { return nil }
|
||||||
|
func (c *channelConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
|
func (c *channelConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
@ -20,8 +20,8 @@ type Manager struct {
|
|||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a TLS configuration from SSL config
|
// NewManager creates a TLS configuration from SSL config
|
||||||
func New(cfg *config.SSLConfig, logger *log.Logger) (*Manager, error) {
|
func NewManager(cfg *config.SSLConfig, logger *log.Logger) (*Manager, error) {
|
||||||
if cfg == nil || !cfg.Enabled {
|
if cfg == nil || !cfg.Enabled {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user