v0.5.0 removed tcp tls, basic auth hash changed to argon2, refactor

This commit is contained in:
2025-09-29 05:42:22 -04:00
parent 15d72baafd
commit c33ec148ba
27 changed files with 985 additions and 1287 deletions

View File

@ -37,7 +37,6 @@ type = "directory"
path = "./" # Directory to monitor
pattern = "*.log" # Glob pattern
check_interval_ms = 100 # Scan interval
read_from_beginning = false # Start position
### Console Sources
# [[pipelines.sources]]
@ -74,11 +73,10 @@ read_from_beginning = false # Start position
# ip_blacklist = [] # Blocked IPs/CIDRs
# requests_per_second = 100.0 # Rate limit per client
# burst_size = 100 # Burst capacity
# limit_by = "ip" # ip|user|token|global
# response_code = 429 # HTTP status when limited
# response_message = "Rate limit exceeded"
# max_connections_per_ip = 10 # Max concurrent per IP
# max_total_connections = 1000 # Max total connections
# max_connections_total = 1000 # Max total connections
### TCP Sources
# [[pipelines.sources]]
@ -88,30 +86,18 @@ read_from_beginning = false # Start position
# host = "0.0.0.0" # Listen address
# port = 9091 # Listen port
# [pipelines.sources.options.tls]
# enabled = false # Enable TLS
# cert_file = "" # TLS certificate
# key_file = "" # TLS key
# client_auth = false # Require client certs
# client_ca_file = "" # Client CA cert
# verify_client_cert = false # Verify client certs
# insecure_skip_verify = false # Skip verification
# ca_file = "" # Custom CA file
# min_version = "TLS1.2" # Min TLS version
# max_version = "TLS1.3" # Max TLS version
# cipher_suites = "" # Comma-separated list
# [pipelines.sources.options.net_limit]
# enabled = false # Enable rate limiting
# ip_whitelist = [] # Allowed IPs/CIDRs
# ip_blacklist = [] # Blocked IPs/CIDRs
# requests_per_second = 100.0 # Rate limit per client
# burst_size = 100 # Burst capacity
# limit_by = "ip" # ip|user|token|global
# response_code = 429 # Response code when limited
# response_message = "Rate limit exceeded"
# max_connections_per_ip = 10 # Max concurrent per IP
# max_total_connections = 1000 # Max total connections
# max_connections_per_user = 10 # Max concurrent per user
# max_connections_per_token = 10 # Max concurrent per token
# max_connections_total = 1000 # Max total connections
### Rate limiting
# [pipelines.rate_limit]
@ -126,6 +112,21 @@ read_from_beginning = false # Start position
# logic = "or" # or|and
# patterns = [] # Regex patterns
## Examples of filter patterns:
## Include only error or fatal logs containing "database":
## type = "include"
## logic = "and"
## patterns = ["(?i)(error|fatal)", "database"]
##
## Exclude debug logs from test environment:
## type = "exclude"
## logic = "or"
## patterns = ["(?i)debug", "test-env"]
##
## Include only JSON formatted logs:
## type = "include"
## patterns = ["^\\{.*\\}$"]
### Format
### Raw formatter (default)
@ -174,7 +175,6 @@ format = "comment" # comment|message
# verify_client_cert = false # Verify client certs
# insecure_skip_verify = false # Skip verification
# ca_file = "" # Custom CA file
# server_name = "" # Expected server name
# min_version = "TLS1.2" # Min TLS version
# max_version = "TLS1.3" # Max TLS version
# cipher_suites = "" # Comma-separated list
@ -185,11 +185,10 @@ format = "comment" # comment|message
# ip_blacklist = [] # Blocked IPs/CIDRs
# requests_per_second = 100.0 # Rate limit per client
# burst_size = 100 # Burst capacity
# limit_by = "ip" # ip|user|token|global
# response_code = 429 # HTTP status when limited
# response_message = "Rate limit exceeded"
# max_connections_per_ip = 10 # Max concurrent per IP
# max_total_connections = 1000 # Max total connections
# max_connections_total = 1000 # Max total connections
### TCP Sinks
# [[pipelines.sinks]]
@ -207,31 +206,18 @@ format = "comment" # comment|message
# include_stats = false # Include statistics
# format = "comment" # comment|message
# [pipelines.sinks.options.tls]
# enabled = false # Enable TLS
# cert_file = "" # TLS certificate
# key_file = "" # TLS key
# client_auth = false # Require client certs
# client_ca_file = "" # Client CA cert
# verify_client_cert = false # Verify client certs
# insecure_skip_verify = false # Skip verification
# ca_file = "" # Custom CA file
# server_name = "" # Expected server name
# min_version = "TLS1.2" # Min TLS version
# max_version = "TLS1.3" # Max TLS version
# cipher_suites = "" # Comma-separated list
# [pipelines.sinks.options.net_limit]
# enabled = false # Enable rate limiting
# ip_whitelist = [] # Allowed IPs/CIDRs
# ip_blacklist = [] # Blocked IPs/CIDRs
# requests_per_second = 100.0 # Rate limit per client
# burst_size = 100 # Burst capacity
# limit_by = "ip" # ip|user|token|global
# response_code = 429 # HTTP status when limited
# response_message = "Rate limit exceeded"
# max_connections_per_ip = 10 # Max concurrent per IP
# max_total_connections = 1000 # Max total connections
# max_connections_per_user = 10 # Max concurrent per user
# max_connections_per_token = 10 # Max concurrent per token
# max_connections_total = 1000 # Max total connections
### HTTP Client Sinks
# [[pipelines.sinks]]
@ -283,6 +269,7 @@ format = "comment" # comment|message
# [pipelines.sinks.options]
# directory = "" # Output dir (required)
# name = "" # Base name (required)
# buffer_size = 1000 # Input channel buffer size
# max_size_mb = 100 # Rotation size
# max_total_size_mb = 0 # Total limit (0=unlimited)
# retention_hours = 0.0 # Retention (0=disabled)

6
go.mod
View File

@ -5,9 +5,9 @@ go 1.25.1
require (
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3
github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208
github.com/panjf2000/gnet/v2 v2.9.3
github.com/valyala/fasthttp v1.65.0
github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e
github.com/panjf2000/gnet/v2 v2.9.4
github.com/valyala/fasthttp v1.66.0
golang.org/x/crypto v0.42.0
golang.org/x/term v0.35.0
golang.org/x/time v0.13.0

12
go.sum
View File

@ -12,20 +12,20 @@ github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zt
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0=
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0=
github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208 h1:IB1O/HLv9VR/4mL1Tkjlr91lk+r8anP6bab7rYdS/oE=
github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0=
github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e h1:/XWCqFdSOiUf0/a5a63GHsvEdpglsYfn3qieNxTeyDc=
github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0=
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
github.com/panjf2000/ants/v2 v2.11.3/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek=
github.com/panjf2000/gnet/v2 v2.9.3 h1:auV3/A9Na3jiBDmYAAU00rPhFKnsAI+TnI1F7YUJMHQ=
github.com/panjf2000/gnet/v2 v2.9.3/go.mod h1:WQTxDWYuQ/hz3eccH0FN32IVuvZ19HewEWx0l62fx7E=
github.com/panjf2000/gnet/v2 v2.9.4 h1:XvPCcaFwO4XWg4IgSfZnNV4dfDy5g++HIEx7sH0ldHc=
github.com/panjf2000/gnet/v2 v2.9.4/go.mod h1:WQTxDWYuQ/hz3eccH0FN32IVuvZ19HewEWx0l62fx7E=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8=
github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4=
github.com/valyala/fasthttp v1.66.0 h1:M87A0Z7EayeyNaV6pfO3tUTUiYO0dZfEJnRGXTVNuyU=
github.com/valyala/fasthttp v1.66.0/go.mod h1:Y4eC+zwoocmXSVCB1JmhNbYtS7tZPRI2ztPB72EVObs=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=

View File

@ -60,7 +60,7 @@ func initializeLogger(cfg *config.Config) error {
// In quiet mode, disable ALL logging output
logCfg.Level = 255 // A level that disables all output
logCfg.DisableFile = true
logCfg.EnableStdout = false
logCfg.EnableConsole = false
return logger.ApplyConfig(logCfg)
}
@ -74,29 +74,24 @@ func initializeLogger(cfg *config.Config) error {
// Configure based on output mode
switch cfg.Logging.Output {
case "none":
logCfg.EnableStdout = false
logCfg.EnableConsole = false
case "stdout":
logCfg.EnableStdout = true
logCfg.StdoutTarget = "stdout"
logCfg.EnableConsole = true
logCfg.ConsoleTarget = "stdout"
case "stderr":
logCfg.EnableStdout = true
logCfg.StdoutTarget = "stderr"
logCfg.EnableConsole = true
logCfg.ConsoleTarget = "stderr"
case "split":
logCfg.EnableStdout = true
logCfg.StdoutTarget = "split"
logCfg.EnableConsole = true
logCfg.ConsoleTarget = "split"
case "file":
logCfg.DisableFile = false
logCfg.EnableStdout = false
configureFileLogging(logCfg, cfg)
case "both":
logCfg.DisableFile = false
logCfg.EnableStdout = true
logCfg.StdoutTarget = "stdout"
logCfg.EnableConsole = false
configureFileLogging(logCfg, cfg)
case "all":
logCfg.DisableFile = false
logCfg.EnableStdout = true
logCfg.StdoutTarget = "split"
logCfg.EnableConsole = true
logCfg.ConsoleTarget = "split"
configureFileLogging(logCfg, cfg)
default:
return fmt.Errorf("invalid log output mode: %s", cfg.Logging.Output)

View File

@ -127,13 +127,13 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
"listen", fmt.Sprintf("%s:%d", host, port))
// Display net limit info if configured
if rl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
if enabled, ok := rl["enabled"].(bool); ok && enabled {
if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
if enabled, ok := nl["enabled"].(bool); ok && enabled {
logger.Info("msg", "TCP net limiting enabled",
"pipeline", cfg.Name,
"sink_index", i,
"requests_per_second", rl["requests_per_second"],
"burst_size", rl["burst_size"])
"requests_per_second", nl["requests_per_second"],
"burst_size", nl["burst_size"])
}
}
}
@ -162,14 +162,13 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
"status_url", fmt.Sprintf("http://%s:%d%s", host, port, statusPath))
// Display net limit info if configured
if rl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
if enabled, ok := rl["enabled"].(bool); ok && enabled {
if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
if enabled, ok := nl["enabled"].(bool); ok && enabled {
logger.Info("msg", "HTTP net limiting enabled",
"pipeline", cfg.Name,
"sink_index", i,
"requests_per_second", rl["requests_per_second"],
"burst_size", rl["burst_size"],
"limit_by", rl["limit_by"])
"requests_per_second", nl["requests_per_second"],
"burst_size", nl["burst_size"])
}
}
}

View File

@ -4,6 +4,7 @@ package auth
import (
"bufio"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
"net"
@ -16,7 +17,7 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/lixenwraith/log"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/argon2"
"golang.org/x/time/rate"
)
@ -378,12 +379,14 @@ func (a *Authenticator) validateBasicAuth(username, password, remoteAddr string)
a.mu.RUnlock()
if !exists {
// Perform bcrypt anyway to prevent timing attacks
bcrypt.CompareHashAndPassword([]byte("$2a$10$dummy.hash.to.prevent.timing.attacks"), []byte(password))
// Perform argon2 anyway to prevent timing attacks
dummySalt := make([]byte, 16)
argon2.IDKey([]byte(password), dummySalt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)
return nil, fmt.Errorf("invalid credentials")
}
if err := bcrypt.CompareHashAndPassword([]byte(expectedHash), []byte(password)); err != nil {
// Parse PHC format hash
if !verifyArgon2idHash(password, expectedHash) {
return nil, fmt.Errorf("invalid credentials")
}
@ -400,6 +403,43 @@ func (a *Authenticator) validateBasicAuth(username, password, remoteAddr string)
return session, nil
}
// Verify Argon2id hashes
func verifyArgon2idHash(password, hash string) bool {
// Parse PHC format: $argon2id$v=19$m=65536,t=3,p=4$salt$hash
parts := strings.Split(hash, "$")
if len(parts) != 6 || parts[1] != "argon2id" {
return false
}
// Parse version
var version int
fmt.Sscanf(parts[2], "v=%d", &version)
if version != argon2.Version {
return false
}
// Parse parameters
var memory, time, threads uint32
fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
// Decode salt and hash
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
return false
}
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
return false
}
// Compute hash
computedHash := argon2.IDKey([]byte(password), salt, time, memory, uint8(threads), uint32(len(expectedHash)))
// Constant time comparison
return subtle.ConstantTimeCompare(computedHash, expectedHash) == 1
}
func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) {
if !strings.HasPrefix(authHeader, "Bearer ") {
return nil, fmt.Errorf("invalid bearer auth header")

View File

@ -10,17 +10,24 @@ import (
"os"
"syscall"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/argon2"
"golang.org/x/term"
)
// Handles auth credential generation
// Argon2id parameters
const (
argon2Time = 3
argon2Memory = 64 * 1024 // 64 MB
argon2Threads = 4
argon2SaltLen = 16
argon2KeyLen = 32
)
type GeneratorCommand struct {
output io.Writer
errOut io.Writer
}
// Creates a new auth generator command handler
func NewGeneratorCommand() *GeneratorCommand {
return &GeneratorCommand{
output: os.Stdout,
@ -28,7 +35,6 @@ func NewGeneratorCommand() *GeneratorCommand {
}
}
// Runs the auth generation command with provided arguments
func (g *GeneratorCommand) Execute(args []string) error {
cmd := flag.NewFlagSet("auth", flag.ContinueOnError)
cmd.SetOutput(g.errOut)
@ -36,7 +42,6 @@ func (g *GeneratorCommand) Execute(args []string) error {
var (
username = cmd.String("u", "", "Username for basic auth")
password = cmd.String("p", "", "Password to hash (will prompt if not provided)")
cost = cmd.Int("c", 10, "Bcrypt cost (10-31)")
genToken = cmd.Bool("t", false, "Generate random bearer token")
tokenLen = cmd.Int("l", 32, "Token length in bytes")
)
@ -45,7 +50,7 @@ func (g *GeneratorCommand) Execute(args []string) error {
fmt.Fprintln(g.errOut, "Generate authentication credentials for LogWisp")
fmt.Fprintln(g.errOut, "\nUsage: logwisp auth [options]")
fmt.Fprintln(g.errOut, "\nExamples:")
fmt.Fprintln(g.errOut, " # Generate bcrypt hash for user")
fmt.Fprintln(g.errOut, " # Generate Argon2id hash for user")
fmt.Fprintln(g.errOut, " logwisp auth -u admin")
fmt.Fprintln(g.errOut, " ")
fmt.Fprintln(g.errOut, " # Generate 64-byte bearer token")
@ -58,26 +63,19 @@ func (g *GeneratorCommand) Execute(args []string) error {
return err
}
// Token generation mode
if *genToken {
return g.generateToken(*tokenLen)
}
// Password hash generation mode
if *username == "" {
cmd.Usage()
return fmt.Errorf("username required for password hash generation")
}
return g.generatePasswordHash(*username, *password, *cost)
return g.generatePasswordHash(*username, *password)
}
func (g *GeneratorCommand) generatePasswordHash(username, password string, cost int) error {
// Validate cost
if cost < 10 || cost > 31 {
return fmt.Errorf("bcrypt cost must be between 10 and 31")
}
func (g *GeneratorCommand) generatePasswordHash(username, password string) error {
// Get password if not provided
if password == "" {
pass1 := g.promptPassword("Enter password: ")
@ -88,20 +86,29 @@ func (g *GeneratorCommand) generatePasswordHash(username, password string, cost
password = pass1
}
// Generate hash
hash, err := bcrypt.GenerateFromPassword([]byte(password), cost)
if err != nil {
return fmt.Errorf("failed to generate hash: %w", err)
// Generate salt
salt := make([]byte, argon2SaltLen)
if _, err := rand.Read(salt); err != nil {
return fmt.Errorf("failed to generate salt: %w", err)
}
// Generate Argon2id hash
hash := argon2.IDKey([]byte(password), salt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)
// Encode in PHC format: $argon2id$v=19$m=65536,t=3,p=4$salt$hash
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
hashB64 := base64.RawStdEncoding.EncodeToString(hash)
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
argon2.Version, argon2Memory, argon2Time, argon2Threads, saltB64, hashB64)
// Output configuration snippets
fmt.Fprintln(g.output, "\n# TOML Configuration (add to logwisp.toml):")
fmt.Fprintln(g.output, "[[pipelines.auth.basic_auth.users]]")
fmt.Fprintf(g.output, "username = %q\n", username)
fmt.Fprintf(g.output, "password_hash = %q\n\n", string(hash))
fmt.Fprintf(g.output, "password_hash = %q\n\n", phcHash)
fmt.Fprintln(g.output, "# Users File Format (for external auth file):")
fmt.Fprintf(g.output, "%s:%s\n", username, hash)
fmt.Fprintf(g.output, "%s:%s\n", username, phcHash)
return nil
}
@ -119,11 +126,9 @@ func (g *GeneratorCommand) generateToken(length int) error {
return fmt.Errorf("failed to generate random bytes: %w", err)
}
// Generate both encodings
b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token)
hex := fmt.Sprintf("%x", token)
// Output configuration
fmt.Fprintln(g.output, "\n# TOML Configuration (add to logwisp.toml):")
fmt.Fprintf(g.output, "tokens = [%q]\n\n", b64)
@ -139,7 +144,6 @@ func (g *GeneratorCommand) promptPassword(prompt string) string {
password, err := term.ReadPassword(int(syscall.Stdin))
fmt.Fprintln(g.errOut)
if err != nil {
// Fatal error - can't continue without password
fmt.Fprintf(g.errOut, "Failed to read password: %v\n", err)
os.Exit(1)
}

View File

@ -29,7 +29,7 @@ type BasicAuthConfig struct {
type BasicAuthUser struct {
Username string `toml:"username"`
// Password hash (bcrypt)
// Password hash (Argon2id)
PasswordHash string `toml:"password_hash"`
}

View File

@ -5,13 +5,13 @@ import "fmt"
// Represents logging configuration for LogWisp
type LogConfig struct {
// Output mode: "file", "stdout", "stderr", "both", "none"
// Output mode: "file", "stdout", "stderr", "split", "all", "none"
Output string `toml:"output"`
// Log level: "debug", "info", "warn", "error"
Level string `toml:"level"`
// File output settings (when Output includes "file" or "both")
// File output settings (when Output includes "file" or "all")
File *LogFileConfig `toml:"file"`
// Console output settings
@ -66,7 +66,7 @@ func DefaultLogConfig() *LogConfig {
func validateLogConfig(cfg *LogConfig) error {
validOutputs := map[string]bool{
"file": true, "stdout": true, "stderr": true,
"both": true, "all": true, "none": true,
"split": true, "all": true, "none": true,
}
if !validOutputs[cfg.Output] {
return fmt.Errorf("invalid log output mode: %s", cfg.Output)

View File

@ -131,8 +131,8 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err
}
// Validate net_limit
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
if err := validateNetLimitOptions("HTTP source", pipelineName, sourceIndex, rl); err != nil {
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
if err := validateNetLimitOptions("HTTP source", pipelineName, sourceIndex, nl); err != nil {
return err
}
}
@ -161,15 +161,8 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err
}
// Validate net_limit
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
if err := validateNetLimitOptions("TCP source", pipelineName, sourceIndex, rl); err != nil {
return err
}
}
// Validate TLS
if tls, ok := cfg.Options["tls"].(map[string]any); ok {
if err := validateTLSOptions("TCP source", pipelineName, sourceIndex, tls); err != nil {
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
if err := validateNetLimitOptions("TCP source", pipelineName, sourceIndex, nl); err != nil {
return err
}
}
@ -196,7 +189,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
pipelineName, sinkIndex)
}
// Validate host if provided
// Validate host
if host, ok := cfg.Options["host"].(string); ok && host != "" {
if net.ParseIP(host) == nil {
return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
@ -219,7 +212,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
}
}
// Validate paths if provided
// Validate paths
if streamPath, ok := cfg.Options["stream_path"].(string); ok {
if !strings.HasPrefix(streamPath, "/") {
return fmt.Errorf("pipeline '%s' sink[%d]: stream path must start with /: %s",
@ -234,7 +227,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
}
}
// Validate heartbeat if present
// Validate heartbeat
if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
if err := validateHeartbeatOptions("HTTP", pipelineName, sinkIndex, hb); err != nil {
return err
@ -248,9 +241,9 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
}
}
// Validate net limit if present
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
if err := validateNetLimitOptions("HTTP", pipelineName, sinkIndex, rl); err != nil {
// Validate net limit
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
if err := validateNetLimitOptions("HTTP", pipelineName, sinkIndex, nl); err != nil {
return err
}
}
@ -263,7 +256,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
pipelineName, sinkIndex)
}
// Validate host if provided
// Validate host
if host, ok := cfg.Options["host"].(string); ok && host != "" {
if net.ParseIP(host) == nil {
return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
@ -286,23 +279,16 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
}
}
// Validate heartbeat if present
// Validate heartbeat
if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
if err := validateHeartbeatOptions("TCP", pipelineName, sinkIndex, hb); err != nil {
return err
}
}
// Validate TLS if present
if tls, ok := cfg.Options["tls"].(map[string]any); ok {
if err := validateTLSOptions("TCP", pipelineName, sinkIndex, tls); err != nil {
return err
}
}
// Validate net limit if present
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
if err := validateNetLimitOptions("TCP", pipelineName, sinkIndex, rl); err != nil {
// Validate net limit
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
if err := validateNetLimitOptions("TCP", pipelineName, sinkIndex, nl); err != nil {
return err
}
}

View File

@ -12,9 +12,6 @@ type TCPConfig struct {
Port int64 `toml:"port"`
BufferSize int64 `toml:"buffer_size"`
// TLS Configuration
TLS *TLSConfig `toml:"tls"`
// Net limiting
NetLimit *NetLimitConfig `toml:"net_limit"`
@ -63,16 +60,15 @@ type NetLimitConfig struct {
// Burst size (token bucket)
BurstSize int64 `toml:"burst_size"`
// Net limit by: "ip", "user", "token", "global"
LimitBy string `toml:"limit_by"`
// Response when net limited
ResponseCode int64 `toml:"response_code"` // Default: 429
ResponseMessage string `toml:"response_message"` // Default: "Net limit exceeded"
// Connection limits
MaxConnectionsPerIP int64 `toml:"max_connections_per_ip"`
MaxTotalConnections int64 `toml:"max_total_connections"`
MaxConnectionsPerIP int64 `toml:"max_connections_per_ip"`
MaxConnectionsPerUser int64 `toml:"max_connections_per_user"`
MaxConnectionsPerToken int64 `toml:"max_connections_per_token"`
MaxConnectionsTotal int64 `toml:"max_connections_total"`
}
func validateHeartbeatOptions(serverType, pipelineName string, sinkIndex int, hb map[string]any) error {
@ -93,13 +89,13 @@ func validateHeartbeatOptions(serverType, pipelineName string, sinkIndex int, hb
return nil
}
func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl map[string]any) error {
if enabled, ok := rl["enabled"].(bool); !ok || !enabled {
func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, nl map[string]any) error {
if enabled, ok := nl["enabled"].(bool); !ok || !enabled {
return nil
}
// Validate IP lists if present
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
for i, entry := range ipWhitelist {
entryStr, ok := entry.(string)
if !ok {
@ -112,7 +108,7 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
}
}
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
for i, entry := range ipBlacklist {
entryStr, ok := entry.(string)
if !ok {
@ -126,30 +122,21 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
}
// Validate requests per second
rps, ok := rl["requests_per_second"].(float64)
rps, ok := nl["requests_per_second"].(float64)
if !ok || rps <= 0 {
return fmt.Errorf("pipeline '%s' sink[%d] %s: requests_per_second must be positive",
pipelineName, sinkIndex, serverType)
}
// Validate burst size
burst, ok := rl["burst_size"].(int64)
burst, ok := nl["burst_size"].(int64)
if !ok || burst < 1 {
return fmt.Errorf("pipeline '%s' sink[%d] %s: burst_size must be at least 1",
pipelineName, sinkIndex, serverType)
}
// Validate limit_by
if limitBy, ok := rl["limit_by"].(string); ok && limitBy != "" {
validLimitBy := map[string]bool{"ip": true, "global": true}
if !validLimitBy[limitBy] {
return fmt.Errorf("pipeline '%s' sink[%d] %s: invalid limit_by value: %s (must be 'ip' or 'global')",
pipelineName, sinkIndex, serverType, limitBy)
}
}
// Validate response code
if respCode, ok := rl["response_code"].(int64); ok {
if respCode, ok := nl["response_code"].(int64); ok {
if respCode > 0 && (respCode < 400 || respCode >= 600) {
return fmt.Errorf("pipeline '%s' sink[%d] %s: response_code must be 4xx or 5xx: %d",
pipelineName, sinkIndex, serverType, respCode)
@ -157,14 +144,25 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
}
// Validate connection limits
maxPerIP, perIPOk := rl["max_connections_per_ip"].(int64)
maxTotal, totalOk := rl["max_total_connections"].(int64)
maxPerIP, perIPOk := nl["max_connections_per_ip"].(int64)
maxPerUser, perUserOk := nl["max_connections_per_user"].(int64)
maxPerToken, perTokenOk := nl["max_connections_per_token"].(int64)
maxTotal, totalOk := nl["max_connections_total"].(int64)
if perIPOk && totalOk && maxPerIP > 0 && maxTotal > 0 {
if perIPOk && perUserOk && perTokenOk && totalOk &&
maxPerIP > 0 && maxPerUser > 0 && maxPerToken > 0 && maxTotal > 0 {
if maxPerIP > maxTotal {
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_ip (%d) cannot exceed max_total_connections (%d)",
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_ip (%d) cannot exceed max_connections_total (%d)",
pipelineName, sinkIndex, serverType, maxPerIP, maxTotal)
}
if maxPerUser > maxTotal {
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_user (%d) cannot exceed max_connections_total (%d)",
pipelineName, sinkIndex, serverType, maxPerUser, maxTotal)
}
if maxPerToken > maxTotal {
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_token (%d) cannot exceed max_connections_total (%d)",
pipelineName, sinkIndex, serverType, maxPerToken, maxTotal)
}
}
return nil

View File

@ -49,8 +49,11 @@ type NetLimiter struct {
globalLimiter *TokenBucket
// Connection tracking
ipConnections map[string]*connTracker
connMu sync.RWMutex
ipConnections map[string]*connTracker
userConnections map[string]*connTracker
tokenConnections map[string]*connTracker
totalConnections atomic.Int64
connMu sync.RWMutex
// Statistics
totalRequests atomic.Uint64
@ -102,29 +105,23 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
ctx, cancel := context.WithCancel(context.Background())
l := &NetLimiter{
config: cfg,
logger: logger,
ipWhitelist: make([]*net.IPNet, 0),
ipBlacklist: make([]*net.IPNet, 0),
ipLimiters: make(map[string]*ipLimiter),
ipConnections: make(map[string]*connTracker),
lastCleanup: time.Now(),
ctx: ctx,
cancel: cancel,
cleanupDone: make(chan struct{}),
config: cfg,
logger: logger,
ipWhitelist: make([]*net.IPNet, 0),
ipBlacklist: make([]*net.IPNet, 0),
ipLimiters: make(map[string]*ipLimiter),
ipConnections: make(map[string]*connTracker),
userConnections: make(map[string]*connTracker),
tokenConnections: make(map[string]*connTracker),
lastCleanup: time.Now(),
ctx: ctx,
cancel: cancel,
cleanupDone: make(chan struct{}),
}
// Parse IP lists
l.parseIPLists(cfg)
// Create global limiter if configured
if cfg.Enabled && cfg.LimitBy == "global" {
l.globalLimiter = NewTokenBucket(
float64(cfg.BurstSize),
cfg.RequestsPerSecond,
)
}
// Start cleanup goroutine only if rate limiting is enabled
if cfg.Enabled {
go l.cleanupLoop()
@ -138,7 +135,10 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
"blacklist_rules", len(l.ipBlacklist),
"requests_per_second", cfg.RequestsPerSecond,
"burst_size", cfg.BurstSize,
"limit_by", cfg.LimitBy)
"max_connections_per_ip", cfg.MaxConnectionsPerIP,
"max_connections_per_user", cfg.MaxConnectionsPerUser,
"max_connections_per_token", cfg.MaxConnectionsPerToken,
"max_connections_total", cfg.MaxConnectionsTotal)
return l
}
@ -276,7 +276,7 @@ func (l *NetLimiter) Shutdown() {
}
}
// Checks if an HTTP request should be allowed
// Checks if an HTTP request should be allowed: IP access control + connection limits (IP only) + calls
func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int64, message string) {
if l == nil {
return true, 0, ""
@ -343,7 +343,7 @@ func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int6
}
// Check rate limit
if !l.checkLimit(ipStr) {
if !l.checkIPLimit(ipStr) {
l.blockedByRateLimit.Add(1)
statusCode = l.config.ResponseCode
if statusCode == 0 {
@ -372,7 +372,7 @@ func (l *NetLimiter) updateConnectionActivity(ip string) {
}
}
// Checks if a TCP connection should be allowed
// Checks if a TCP connection should be allowed: IP access control + calls checkIPLimit()
func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool {
if l == nil {
return true
@ -412,7 +412,7 @@ func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool {
// Check rate limit
ipStr := tcpAddr.IP.String()
if !l.checkLimit(ipStr) {
if !l.checkIPLimit(ipStr) {
l.blockedByRateLimit.Add(1)
return false
}
@ -531,17 +531,40 @@ func (l *NetLimiter) GetStats() map[string]any {
return map[string]any{"enabled": false}
}
// Get active rate limiters count
l.ipMu.RLock()
activeIPs := len(l.ipLimiters)
l.ipMu.RUnlock()
// Get connection tracker counts and calculate total active connections
l.connMu.RLock()
totalConnections := 0
ipConnTrackers := len(l.ipConnections)
userConnTrackers := len(l.userConnections)
tokenConnTrackers := len(l.tokenConnections)
// Calculate actual connection count by summing all IP connections
// Potentially more accurate than totalConnections counter which might drift
// TODO: test and refactor if they match
actualIPConnections := 0
for _, tracker := range l.ipConnections {
totalConnections += int(tracker.connections.Load())
actualIPConnections += int(tracker.connections.Load())
}
actualUserConnections := 0
for _, tracker := range l.userConnections {
actualUserConnections += int(tracker.connections.Load())
}
actualTokenConnections := 0
for _, tracker := range l.tokenConnections {
actualTokenConnections += int(tracker.connections.Load())
}
// Use the counter for total (should match actualIPConnections in most cases)
totalConns := l.totalConnections.Load()
l.connMu.RUnlock()
// Calculate total blocked
totalBlocked := l.blockedByBlacklist.Load() +
l.blockedByWhitelist.Load() +
l.blockedByRateLimit.Load() +
@ -559,23 +582,39 @@ func (l *NetLimiter) GetStats() map[string]any {
"conn_limit": l.blockedByConnLimit.Load(),
"invalid_ip": l.blockedByInvalidIP.Load(),
},
"active_ips": activeIPs,
"total_connections": totalConnections,
"acl": map[string]int{
"whitelist_rules": len(l.ipWhitelist),
"blacklist_rules": len(l.ipBlacklist),
},
"rate_limit": map[string]any{
"rate_limiting": map[string]any{
"enabled": l.config.Enabled,
"requests_per_second": l.config.RequestsPerSecond,
"burst_size": l.config.BurstSize,
"limit_by": l.config.LimitBy,
"active_ip_limiters": activeIPs, // IPs being rate-limited
},
"access_control": map[string]any{
"whitelist_rules": len(l.ipWhitelist),
"blacklist_rules": len(l.ipBlacklist),
},
"connections": map[string]any{
// Actual counts
"total_active": totalConns, // Counter-based total
"active_ip_connections": actualIPConnections, // Sum of all IP connections
"active_user_connections": actualUserConnections, // Sum of all user connections
"active_token_connections": actualTokenConnections, // Sum of all token connections
// Tracker counts (number of unique IPs/users/tokens being tracked)
"tracked_ips": ipConnTrackers,
"tracked_users": userConnTrackers,
"tracked_tokens": tokenConnTrackers,
// Configuration limits (0 = disabled)
"limit_per_ip": l.config.MaxConnectionsPerIP,
"limit_per_user": l.config.MaxConnectionsPerUser,
"limit_per_token": l.config.MaxConnectionsPerToken,
"limit_total": l.config.MaxConnectionsTotal,
},
}
}
// Performs the actual net limit check
func (l *NetLimiter) checkLimit(ip string) bool {
// Performs IP net limit check (req/sec)
func (l *NetLimiter) checkIPLimit(ip string) bool {
// Validate IP format
parsedIP := net.ParseIP(ip)
if parsedIP == nil || !isIPv4(parsedIP) {
@ -588,53 +627,36 @@ func (l *NetLimiter) checkLimit(ip string) bool {
// Maybe run cleanup
l.maybeCleanup()
switch l.config.LimitBy {
case "global":
return l.globalLimiter.Allow()
case "ip", "":
// Default to per-IP limiting
l.ipMu.Lock()
lim, exists := l.ipLimiters[ip]
if !exists {
// Create new limiter for this IP
lim = &ipLimiter{
bucket: NewTokenBucket(
float64(l.config.BurstSize),
l.config.RequestsPerSecond,
),
lastSeen: time.Now(),
}
l.ipLimiters[ip] = lim
l.uniqueIPs.Add(1)
l.logger.Debug("msg", "Created new IP limiter",
"ip", ip,
"total_ips", l.uniqueIPs.Load())
} else {
lim.lastSeen = time.Now()
// IP limit
l.ipMu.Lock()
lim, exists := l.ipLimiters[ip]
if !exists {
// Create new limiter for this IP
lim = &ipLimiter{
bucket: NewTokenBucket(
float64(l.config.BurstSize),
l.config.RequestsPerSecond,
),
lastSeen: time.Now(),
}
l.ipMu.Unlock()
l.ipLimiters[ip] = lim
l.uniqueIPs.Add(1)
// Check connection limit if configured
if l.config.MaxConnectionsPerIP > 0 {
l.connMu.RLock()
tracker, exists := l.ipConnections[ip]
l.connMu.RUnlock()
if exists && tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
return false
}
}
return lim.bucket.Allow()
default:
// Unknown limit_by value, allow by default
l.logger.Warn("msg", "Unknown limit_by value",
"limit_by", l.config.LimitBy)
return true
l.logger.Debug("msg", "Created new IP limiter",
"ip", ip,
"total_ips", l.uniqueIPs.Load())
} else {
lim.lastSeen = time.Now()
}
l.ipMu.Unlock()
// Rate limit check
allowed := lim.bucket.Allow()
if !allowed {
l.blockedByRateLimit.Add(1)
}
return allowed
}
// Runs cleanup if enough time has passed
@ -691,25 +713,57 @@ func (l *NetLimiter) cleanup() {
// Clean up stale connection trackers
l.connMu.Lock()
connCleaned := 0
// Clean IP connections
ipCleaned := 0
for ip, tracker := range l.ipConnections {
tracker.mu.Lock()
lastSeen := tracker.lastSeen
tracker.mu.Unlock()
// Remove if no activity for 5 minutes AND no active connections
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
delete(l.ipConnections, ip)
connCleaned++
ipCleaned++
}
}
// Clean user connections
userCleaned := 0
for user, tracker := range l.userConnections {
tracker.mu.Lock()
lastSeen := tracker.lastSeen
tracker.mu.Unlock()
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
delete(l.userConnections, user)
userCleaned++
}
}
// Clean token connections
tokenCleaned := 0
for token, tracker := range l.tokenConnections {
tracker.mu.Lock()
lastSeen := tracker.lastSeen
tracker.mu.Unlock()
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
delete(l.tokenConnections, token)
tokenCleaned++
}
}
l.connMu.Unlock()
if connCleaned > 0 {
if ipCleaned > 0 || userCleaned > 0 || tokenCleaned > 0 {
l.logger.Debug("msg", "Cleaned up stale connection trackers",
"component", "netlimit",
"cleaned", connCleaned,
"remaining", len(l.ipConnections))
"ip_cleaned", ipCleaned,
"user_cleaned", userCleaned,
"token_cleaned", tokenCleaned,
"ip_remaining", len(l.ipConnections),
"user_remaining", len(l.userConnections),
"token_remaining", len(l.tokenConnections))
}
}
@ -730,4 +784,164 @@ func (l *NetLimiter) cleanupLoop() {
l.cleanup()
}
}
}
// Tracks a new connection with optional user/token info: Connection limits (IP/user/token/total) for TCP only
func (l *NetLimiter) TrackConnection(ip string, user string, token string) bool {
if l == nil {
return true
}
l.connMu.Lock()
defer l.connMu.Unlock()
// Check total connections limit (0 = disabled)
if l.config.MaxConnectionsTotal > 0 {
currentTotal := l.totalConnections.Load()
if currentTotal >= l.config.MaxConnectionsTotal {
l.blockedByConnLimit.Add(1)
l.logger.Debug("msg", "TCP connection blocked by total limit",
"component", "netlimit",
"current_total", currentTotal,
"max_total", l.config.MaxConnectionsTotal)
return false
}
}
// Check per-IP connection limit (0 = disabled)
if l.config.MaxConnectionsPerIP > 0 && ip != "" {
tracker, exists := l.ipConnections[ip]
if !exists {
tracker = &connTracker{lastSeen: time.Now()}
l.ipConnections[ip] = tracker
}
if tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
l.blockedByConnLimit.Add(1)
l.logger.Debug("msg", "TCP connection blocked by IP limit",
"component", "netlimit",
"ip", ip,
"current", tracker.connections.Load(),
"max", l.config.MaxConnectionsPerIP)
return false
}
}
// Check per-user connection limit (0 = disabled)
if l.config.MaxConnectionsPerUser > 0 && user != "" {
tracker, exists := l.userConnections[user]
if !exists {
tracker = &connTracker{lastSeen: time.Now()}
l.userConnections[user] = tracker
}
if tracker.connections.Load() >= l.config.MaxConnectionsPerUser {
l.blockedByConnLimit.Add(1)
l.logger.Debug("msg", "TCP connection blocked by user limit",
"component", "netlimit",
"user", user,
"current", tracker.connections.Load(),
"max", l.config.MaxConnectionsPerUser)
return false
}
}
// Check per-token connection limit (0 = disabled)
if l.config.MaxConnectionsPerToken > 0 && token != "" {
tracker, exists := l.tokenConnections[token]
if !exists {
tracker = &connTracker{lastSeen: time.Now()}
l.tokenConnections[token] = tracker
}
if tracker.connections.Load() >= l.config.MaxConnectionsPerToken {
l.blockedByConnLimit.Add(1)
l.logger.Debug("msg", "TCP connection blocked by token limit",
"component", "netlimit",
"token", token,
"current", tracker.connections.Load(),
"max", l.config.MaxConnectionsPerToken)
return false
}
}
// All checks passed, increment counters
l.totalConnections.Add(1)
if ip != "" && l.config.MaxConnectionsPerIP > 0 {
if tracker, exists := l.ipConnections[ip]; exists {
tracker.connections.Add(1)
tracker.mu.Lock()
tracker.lastSeen = time.Now()
tracker.mu.Unlock()
}
}
if user != "" && l.config.MaxConnectionsPerUser > 0 {
if tracker, exists := l.userConnections[user]; exists {
tracker.connections.Add(1)
tracker.mu.Lock()
tracker.lastSeen = time.Now()
tracker.mu.Unlock()
}
}
if token != "" && l.config.MaxConnectionsPerToken > 0 {
if tracker, exists := l.tokenConnections[token]; exists {
tracker.connections.Add(1)
tracker.mu.Lock()
tracker.lastSeen = time.Now()
tracker.mu.Unlock()
}
}
return true
}
// Releases a tracked connection
func (l *NetLimiter) ReleaseConnection(ip string, user string, token string) {
if l == nil {
return
}
l.connMu.Lock()
defer l.connMu.Unlock()
// Decrement total
if l.totalConnections.Load() > 0 {
l.totalConnections.Add(-1)
}
// Decrement IP counter
if ip != "" {
if tracker, exists := l.ipConnections[ip]; exists {
if tracker.connections.Load() > 0 {
tracker.connections.Add(-1)
}
tracker.mu.Lock()
tracker.lastSeen = time.Now()
tracker.mu.Unlock()
}
}
// Decrement user counter
if user != "" {
if tracker, exists := l.userConnections[user]; exists {
if tracker.connections.Load() > 0 {
tracker.connections.Add(-1)
}
tracker.mu.Lock()
tracker.lastSeen = time.Now()
tracker.mu.Unlock()
}
}
// Decrement token counter
if token != "" {
if tracker, exists := l.tokenConnections[token]; exists {
if tracker.connections.Load() > 0 {
tracker.connections.Add(-1)
}
tracker.mu.Lock()
tracker.lastSeen = time.Now()
tracker.mu.Unlock()
}
}
}

View File

@ -129,6 +129,11 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
}
}
// Configure authentication for sources that support it
for _, sourceInst := range pipeline.Sources {
sourceInst.SetAuth(cfg.Auth)
}
// Start all sinks
for i, sinkInst := range pipeline.Sinks {
if err := sinkInst.Start(pipelineCtx); err != nil {
@ -139,9 +144,7 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
// Configure authentication for sinks that support it
for _, sinkInst := range pipeline.Sinks {
if setter, ok := sinkInst.(sink.AuthSetter); ok {
setter.SetAuthConfig(cfg.Auth)
}
sinkInst.SetAuth(cfg.Auth)
}
// Wire sources to sinks through filters

View File

@ -9,6 +9,7 @@ import (
"sync/atomic"
"time"
"logwisp/src/internal/config"
"logwisp/src/internal/core"
"logwisp/src/internal/format"
@ -121,7 +122,9 @@ func (s *StdoutSink) processLoop(ctx context.Context) {
// Format and write
formatted, err := s.formatter.Format(entry)
if err != nil {
s.logger.Error("msg", "Failed to format log entry for stdout", "error", err)
s.logger.Error("msg", "Failed to format log entry for stdout",
"component", "stdout_sink",
"error", err)
continue
}
s.output.Write(formatted)
@ -234,7 +237,9 @@ func (s *StderrSink) processLoop(ctx context.Context) {
// Format and write
formatted, err := s.formatter.Format(entry)
if err != nil {
s.logger.Error("msg", "Failed to format log entry for stderr", "error", err)
s.logger.Error("msg", "Failed to format log entry for stderr",
"component", "stderr_sink",
"error", err)
continue
}
s.output.Write(formatted)
@ -245,4 +250,12 @@ func (s *StderrSink) processLoop(ctx context.Context) {
return
}
}
}
func (s *StdoutSink) SetAuth(auth *config.AuthConfig) {
// Authentication does not apply to stdout sink
}
func (s *StderrSink) SetAuth(auth *config.AuthConfig) {
// Authentication does not apply to stderr sink
}

View File

@ -4,6 +4,7 @@ package sink
import (
"context"
"fmt"
"logwisp/src/internal/config"
"sync/atomic"
"time"
@ -43,7 +44,7 @@ func NewFileSink(options map[string]any, logger *log.Logger, formatter format.Fo
writerConfig := log.DefaultConfig()
writerConfig.Directory = directory
writerConfig.Name = name
writerConfig.EnableStdout = false // File only
writerConfig.EnableConsole = false // File only
writerConfig.ShowTimestamp = false // We already have timestamps in entries
writerConfig.ShowLevel = false // We already have levels in entries
@ -164,4 +165,8 @@ func (fs *FileSink) processLoop(ctx context.Context) {
return
}
}
}
func (fs *FileSink) SetAuth(auth *config.AuthConfig) {
// Authentication does not apply to file sink
}

View File

@ -142,31 +142,28 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
}
// Extract net limit config
if rl, ok := options["net_limit"].(map[string]any); ok {
if nl, ok := options["net_limit"].(map[string]any); ok {
cfg.NetLimit = &config.NetLimitConfig{}
cfg.NetLimit.Enabled, _ = rl["enabled"].(bool)
if rps, ok := rl["requests_per_second"].(float64); ok {
cfg.NetLimit.Enabled, _ = nl["enabled"].(bool)
if rps, ok := nl["requests_per_second"].(float64); ok {
cfg.NetLimit.RequestsPerSecond = rps
}
if burst, ok := rl["burst_size"].(int64); ok {
if burst, ok := nl["burst_size"].(int64); ok {
cfg.NetLimit.BurstSize = burst
}
if limitBy, ok := rl["limit_by"].(string); ok {
cfg.NetLimit.LimitBy = limitBy
}
if respCode, ok := rl["response_code"].(int64); ok {
if respCode, ok := nl["response_code"].(int64); ok {
cfg.NetLimit.ResponseCode = respCode
}
if msg, ok := rl["response_message"].(string); ok {
if msg, ok := nl["response_message"].(string); ok {
cfg.NetLimit.ResponseMessage = msg
}
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
}
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
cfg.NetLimit.MaxTotalConnections = maxTotal
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
cfg.NetLimit.MaxConnectionsTotal = maxTotal
}
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
for _, entry := range ipWhitelist {
if str, ok := entry.(string); ok {
@ -174,7 +171,7 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
}
}
}
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
for _, entry := range ipBlacklist {
if str, ok := entry.(string); ok {
@ -806,8 +803,8 @@ func (h *HTTPSink) GetHost() string {
return h.config.Host
}
// Configures http sink authentication
func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) {
// Configures http sink auth
func (h *HTTPSink) SetAuth(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" {
return
}

View File

@ -6,7 +6,10 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"logwisp/src/internal/auth"
"logwisp/src/internal/config"
"net/url"
"os"
"strings"
@ -23,16 +26,17 @@ import (
// Forwards log entries to a remote HTTP endpoint
type HTTPClientSink struct {
input chan core.LogEntry
config HTTPClientConfig
client *fasthttp.Client
batch []core.LogEntry
batchMu sync.Mutex
done chan struct{}
wg sync.WaitGroup
startTime time.Time
logger *log.Logger
formatter format.Formatter
input chan core.LogEntry
config HTTPClientConfig
client *fasthttp.Client
batch []core.LogEntry
batchMu sync.Mutex
done chan struct{}
wg sync.WaitGroup
startTime time.Time
logger *log.Logger
formatter format.Formatter
authenticator *auth.Authenticator
// Statistics
totalProcessed atomic.Uint64
@ -44,7 +48,9 @@ type HTTPClientSink struct {
}
// Holds HTTP client sink configuration
// TODO: missing toml tags
type HTTPClientConfig struct {
// Config
URL string
BufferSize int64
BatchSize int64
@ -57,6 +63,10 @@ type HTTPClientConfig struct {
RetryDelay time.Duration
RetryBackoff float64 // Multiplier for exponential backoff
// Security
Username string
Password string
// TLS configuration
InsecureSkipVerify bool
CAFile string
@ -118,6 +128,12 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
if insecure, ok := options["insecure_skip_verify"].(bool); ok {
cfg.InsecureSkipVerify = insecure
}
if username, ok := options["username"].(string); ok {
cfg.Username = username
}
if password, ok := options["password"].(string); ok {
cfg.Password = password // TODO: change to Argon2 hashed password
}
// Extract headers
if headers, ok := options["headers"].(map[string]any); ok {
@ -422,8 +438,16 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
req.SetRequestURI(h.config.URL)
req.Header.SetMethod("POST")
req.Header.SetContentType("application/json")
req.SetBody(body)
// Add Basic Auth header if credentials configured
if h.config.Username != "" && h.config.Password != "" {
creds := h.config.Username + ":" + h.config.Password
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
req.Header.Set("Authorization", "Basic "+encodedCreds)
}
// Set headers
for k, v := range h.config.Headers {
req.Header.Set(k, v)
@ -494,4 +518,10 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
"retries", h.config.MaxRetries,
"last_error", lastErr)
h.failedBatches.Add(1)
}
// Not applicable, Clients authenticate to remote servers using Username/Password in config
func (h *HTTPClientSink) SetAuth(authCfg *config.AuthConfig) {
// No-op: client sinks don't validate incoming connections
// They authenticate to remote servers using Username/Password fields
}

View File

@ -9,19 +9,22 @@ import (
"logwisp/src/internal/core"
)
// Represents an output destination for log entries
// Represents an output data stream
type Sink interface {
// Input returns the channel for sending log entries to this sink
// Returns the channel for sending log entries to this sink
Input() chan<- core.LogEntry
// Start begins processing log entries
// Begins processing log entries
Start(ctx context.Context) error
// Stop gracefully shuts down the sink
// Gracefully shuts down the sink
Stop()
// GetStats returns sink statistics
// Returns sink statistics
GetStats() SinkStats
// Configure authentication
SetAuth(auth *config.AuthConfig)
}
// Contains statistics about a sink
@ -32,9 +35,4 @@ type SinkStats struct {
StartTime time.Time
LastProcessed time.Time
Details map[string]any
}
// Interface for sinks that can accept an AuthConfig
type AuthSetter interface {
SetAuthConfig(auth *config.AuthConfig)
}

View File

@ -17,7 +17,6 @@ import (
"logwisp/src/internal/core"
"logwisp/src/internal/format"
"logwisp/src/internal/limit"
"logwisp/src/internal/tls"
"github.com/lixenwraith/log"
"github.com/lixenwraith/log/compat"
@ -26,23 +25,20 @@ import (
// Streams log entries via TCP
type TCPSink struct {
input chan core.LogEntry
config TCPConfig
server *tcpServer
done chan struct{}
activeConns atomic.Int64
startTime time.Time
engine *gnet.Engine
engineMu sync.Mutex
wg sync.WaitGroup
netLimiter *limit.NetLimiter
logger *log.Logger
formatter format.Formatter
// Security components
// C
input chan core.LogEntry
config TCPConfig
server *tcpServer
done chan struct{}
activeConns atomic.Int64
startTime time.Time
engine *gnet.Engine
engineMu sync.Mutex
wg sync.WaitGroup
netLimiter *limit.NetLimiter
logger *log.Logger
formatter format.Formatter
authenticator *auth.Authenticator
tlsManager *tls.Manager
authConfig *config.AuthConfig
// Statistics
totalProcessed atomic.Uint64
@ -62,7 +58,6 @@ type TCPConfig struct {
Port int64
BufferSize int64
Heartbeat *config.HeartbeatConfig
TLS *config.TLSConfig
NetLimit *config.NetLimitConfig
}
@ -99,58 +94,35 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
}
}
// Extract TLS config
if tc, ok := options["tls"].(map[string]any); ok {
cfg.TLS = &config.TLSConfig{}
cfg.TLS.Enabled, _ = tc["enabled"].(bool)
if certFile, ok := tc["cert_file"].(string); ok {
cfg.TLS.CertFile = certFile
}
if keyFile, ok := tc["key_file"].(string); ok {
cfg.TLS.KeyFile = keyFile
}
cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool)
if caFile, ok := tc["client_ca_file"].(string); ok {
cfg.TLS.ClientCAFile = caFile
}
cfg.TLS.VerifyClientCert, _ = tc["verify_client_cert"].(bool)
if minVer, ok := tc["min_version"].(string); ok {
cfg.TLS.MinVersion = minVer
}
if maxVer, ok := tc["max_version"].(string); ok {
cfg.TLS.MaxVersion = maxVer
}
if ciphers, ok := tc["cipher_suites"].(string); ok {
cfg.TLS.CipherSuites = ciphers
}
}
// Extract net limit config
if rl, ok := options["net_limit"].(map[string]any); ok {
if nl, ok := options["net_limit"].(map[string]any); ok {
cfg.NetLimit = &config.NetLimitConfig{}
cfg.NetLimit.Enabled, _ = rl["enabled"].(bool)
if rps, ok := rl["requests_per_second"].(float64); ok {
cfg.NetLimit.Enabled, _ = nl["enabled"].(bool)
if rps, ok := nl["requests_per_second"].(float64); ok {
cfg.NetLimit.RequestsPerSecond = rps
}
if burst, ok := rl["burst_size"].(int64); ok {
if burst, ok := nl["burst_size"].(int64); ok {
cfg.NetLimit.BurstSize = burst
}
if limitBy, ok := rl["limit_by"].(string); ok {
cfg.NetLimit.LimitBy = limitBy
}
if respCode, ok := rl["response_code"].(int64); ok {
if respCode, ok := nl["response_code"].(int64); ok {
cfg.NetLimit.ResponseCode = respCode
}
if msg, ok := rl["response_message"].(string); ok {
if msg, ok := nl["response_message"].(string); ok {
cfg.NetLimit.ResponseMessage = msg
}
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
}
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
cfg.NetLimit.MaxTotalConnections = maxTotal
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
cfg.NetLimit.MaxConnectionsPerUser = maxPerUser
}
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
cfg.NetLimit.MaxConnectionsPerToken = maxPerToken
}
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
cfg.NetLimit.MaxConnectionsTotal = maxTotal
}
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
for _, entry := range ipWhitelist {
if str, ok := entry.(string); ok {
@ -158,7 +130,7 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
}
}
}
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
for _, entry := range ipBlacklist {
if str, ok := entry.(string); ok {
@ -290,18 +262,6 @@ func (t *TCPSink) GetStats() SinkStats {
netLimitStats = t.netLimiter.GetStats()
}
var authStats map[string]any
if t.authenticator != nil {
authStats = t.authenticator.GetStats()
authStats["failures"] = t.authFailures.Load()
authStats["successes"] = t.authSuccesses.Load()
}
var tlsStats map[string]any
if t.tlsManager != nil {
tlsStats = t.tlsManager.GetStats()
}
return SinkStats{
Type: "tcp",
TotalProcessed: t.totalProcessed.Load(),
@ -312,8 +272,7 @@ func (t *TCPSink) GetStats() SinkStats {
"port": t.config.Port,
"buffer_size": t.config.BufferSize,
"net_limit": netLimitStats,
"auth": authStats,
"tls": tlsStats,
"auth": map[string]any{"enabled": false},
},
}
}
@ -347,37 +306,7 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
"entry_source", entry.Source)
continue
}
// Broadcast only to authenticated clients
t.server.mu.RLock()
for conn, client := range t.server.clients {
if client.authenticated {
// Send through TLS bridge if present
if client.tlsBridge != nil {
if _, err := client.tlsBridge.Write(data); err != nil {
// TLS write failed, connection likely dead
t.logger.Debug("msg", "TLS write failed",
"component", "tcp_sink",
"error", err)
conn.Close()
}
} else {
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
if err != nil {
t.writeErrors.Add(1)
t.handleWriteError(c, err)
} else {
// Reset consecutive error count on success
t.errorMu.Lock()
delete(t.consecutiveWriteErrors, c)
t.errorMu.Unlock()
}
return nil
})
}
}
}
t.server.mu.RUnlock()
t.broadcastData(data)
case <-tickerChan:
heartbeatEntry := t.createHeartbeatEntry()
@ -388,37 +317,7 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
"error", err)
continue
}
t.server.mu.RLock()
for conn, client := range t.server.clients {
if client.authenticated {
// Validate session is still active
if t.authenticator != nil && client.session != nil {
if !t.authenticator.ValidateSession(client.session.ID) {
// Session expired, close connection
conn.Close()
continue
}
}
if client.tlsBridge != nil {
if _, err := client.tlsBridge.Write(data); err != nil {
t.logger.Debug("msg", "TLS heartbeat write failed",
"component", "tcp_sink",
"error", err)
conn.Close()
}
} else {
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
if err != nil {
t.writeErrors.Add(1)
t.handleWriteError(c, err)
}
return nil
})
}
}
}
t.server.mu.RUnlock()
t.broadcastData(data)
case <-t.done:
return
@ -426,6 +325,28 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
}
}
func (t *TCPSink) broadcastData(data []byte) {
t.server.mu.RLock()
defer t.server.mu.RUnlock()
for conn, client := range t.server.clients {
if client.authenticated {
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
if err != nil {
t.writeErrors.Add(1)
t.handleWriteError(c, err)
} else {
// Reset consecutive error count on success
t.errorMu.Lock()
delete(t.consecutiveWriteErrors, c)
t.errorMu.Unlock()
}
return nil
})
}
}
}
// Handle write errors with threshold-based connection termination
func (t *TCPSink) handleWriteError(c gnet.Conn, err error) {
t.errorMu.Lock()
@ -487,13 +408,11 @@ func (t *TCPSink) GetActiveConnections() int64 {
// Represents a connected TCP client with auth state
type tcpClient struct {
conn gnet.Conn
buffer bytes.Buffer
authenticated bool
session *auth.Session
authTimeout time.Time
tlsBridge *tls.GNetTLSConn
authTimeoutSet bool
conn gnet.Conn
buffer bytes.Buffer
authenticated bool
authTimeout time.Time
session *auth.Session
}
// Handles gnet events with authentication
@ -550,24 +469,12 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
// Create client state without auth timeout initially
client := &tcpClient{
conn: c,
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
authTimeoutSet: false, // Auth timeout not started yet
conn: c,
authenticated: s.sink.authenticator == nil,
}
// Initialize TLS bridge if enabled
if s.sink.tlsManager != nil {
tlsConfig := s.sink.tlsManager.GetTCPConfig()
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
client.tlsBridge.Handshake() // Start async handshake
s.sink.logger.Debug("msg", "TLS handshake initiated",
"component", "tcp_sink",
"remote_addr", remoteAddr)
} else if s.sink.authenticator != nil {
// Only set auth timeout if no TLS (plain connection)
client.authTimeout = time.Now().Add(30 * time.Second) // TODO: configurable or non-hardcoded timer
client.authTimeoutSet = true
if s.sink.authenticator != nil {
client.authTimeout = time.Now().Add(30 * time.Second)
}
s.mu.Lock()
@ -578,12 +485,11 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
s.sink.logger.Debug("msg", "TCP connection opened",
"remote_addr", remoteAddr,
"active_connections", newCount,
"requires_auth", s.sink.authenticator != nil)
"auth_enabled", s.sink.authenticator != nil)
// Send auth prompt if authentication is required
if s.sink.authenticator != nil && s.sink.tlsManager == nil {
authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH <method> <credentials>\nMethods: basic, token\n")
return authPrompt, gnet.None
if s.sink.authenticator != nil {
return []byte("AUTH_REQUIRED\n"), gnet.None
}
return nil, gnet.None
@ -594,17 +500,9 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
// Remove client state
s.mu.Lock()
client := s.clients[c]
delete(s.clients, c)
s.mu.Unlock()
// Clean up TLS bridge if present
if client != nil && client.tlsBridge != nil {
client.tlsBridge.Close()
s.sink.logger.Debug("msg", "TLS connection closed",
"remote_addr", remoteAddr)
}
// Clean up write error tracking
s.sink.errorMu.Lock()
delete(s.sink.consecutiveWriteErrors, c)
@ -632,98 +530,34 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
return gnet.Close
}
// Read all available data
data, err := c.Next(-1)
if err != nil {
s.sink.logger.Error("msg", "Error reading from connection",
"component", "tcp_sink",
"error", err)
return gnet.Close
}
// Process through TLS bridge if present
if client.tlsBridge != nil {
// Feed encrypted data into TLS engine
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
s.sink.logger.Error("msg", "TLS processing error",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String(),
"error", err)
return gnet.Close
}
// Check if handshake is complete
if !client.tlsBridge.IsHandshakeDone() {
// Still handshaking, wait for more data
return gnet.None
}
// Check handshake result
_, hsErr := client.tlsBridge.HandshakeComplete()
if hsErr != nil {
s.sink.logger.Error("msg", "TLS handshake failed",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String(),
"error", hsErr)
return gnet.Close
}
// Set auth timeout only after TLS handshake completes
if !client.authTimeoutSet && s.sink.authenticator != nil && !client.authenticated {
client.authTimeout = time.Now().Add(30 * time.Second)
client.authTimeoutSet = true
s.sink.logger.Debug("msg", "Auth timeout started after TLS handshake",
// Authentication phase
if !client.authenticated {
// Check auth timeout
if time.Now().After(client.authTimeout) {
s.sink.logger.Warn("msg", "Authentication timeout",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String())
return gnet.Close
}
// Read decrypted plaintext
data = client.tlsBridge.Read()
if data == nil || len(data) == 0 {
// No plaintext available yet
// Read auth data
data, _ := c.Next(-1)
if len(data) == 0 {
return gnet.None
}
// First data after TLS handshake - send auth prompt if needed
if s.sink.authenticator != nil && !client.authenticated &&
len(client.buffer.Bytes()) == 0 {
authPrompt := []byte("AUTH REQUIRED\n")
client.tlsBridge.Write(authPrompt)
}
}
// Only check auth timeout if it has been set
if !client.authenticated && client.authTimeoutSet && time.Now().After(client.authTimeout) {
s.sink.logger.Warn("msg", "Authentication timeout",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String())
if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() {
client.tlsBridge.Write([]byte("AUTH TIMEOUT\n"))
} else if client.tlsBridge == nil {
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
}
return gnet.Close
}
// If not authenticated, expect auth command
if !client.authenticated {
client.buffer.Write(data)
// Look for complete auth line
if line, err := client.buffer.ReadBytes('\n'); err == nil {
line = bytes.TrimSpace(line)
if idx := bytes.IndexByte(client.buffer.Bytes(), '\n'); idx >= 0 {
line := client.buffer.Bytes()[:idx]
client.buffer.Next(idx + 1)
// Parse AUTH command: AUTH <method> <credentials>
parts := strings.SplitN(string(line), " ", 3)
if len(parts) != 3 || parts[0] != "AUTH" {
// Send error through TLS if enabled
errMsg := []byte("AUTH FAILED\n")
if client.tlsBridge != nil {
client.tlsBridge.Write(errMsg)
} else {
c.AsyncWrite(errMsg, nil)
}
return gnet.None
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
return gnet.Close
}
// Authenticate
@ -734,13 +568,7 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
"remote_addr", c.RemoteAddr().String(),
"method", parts[1],
"error", err)
// Send error through TLS if enabled
errMsg := []byte("AUTH FAILED\n")
if client.tlsBridge != nil {
client.tlsBridge.Write(errMsg)
} else {
c.AsyncWrite(errMsg, nil)
}
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
return gnet.Close
}
@ -755,35 +583,25 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String(),
"username", session.Username,
"method", session.Method,
"tls", client.tlsBridge != nil)
"method", session.Method)
// Send success through TLS if enabled
successMsg := []byte("AUTH OK\n")
if client.tlsBridge != nil {
client.tlsBridge.Write(successMsg)
} else {
c.AsyncWrite(successMsg, nil)
}
// Clear buffer after auth
c.AsyncWrite([]byte("AUTH_OK\n"), nil)
client.buffer.Reset()
}
return gnet.None
}
// Authenticated clients shouldn't send data, just discard
// Clients shouldn't send data, just discard
c.Discard(-1)
return gnet.None
}
// Configures tcp sink authentication
func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
// Configures tcp sink auth
func (t *TCPSink) SetAuth(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" {
return
}
t.authConfig = authCfg
authenticator, err := auth.New(authCfg, t.logger)
if err != nil {
t.logger.Error("msg", "Failed to initialize authenticator for TCP sink",
@ -793,22 +611,7 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
}
t.authenticator = authenticator
// Initialize TLS manager if TLS is configured
if t.config.TLS != nil && t.config.TLS.Enabled {
tlsManager, err := tls.NewManager(t.config.TLS, t.logger)
if err != nil {
t.logger.Error("msg", "Failed to create TLS manager",
"component", "tcp_sink",
"error", err)
// Continue without TLS
return
}
t.tlsManager = tlsManager
}
t.logger.Info("msg", "Authentication configured for TCP sink",
"component", "tcp_sink",
"auth_type", authCfg.Type,
"tls_enabled", t.tlsManager != nil,
"tls_bridge", t.tlsManager != nil)
"auth_type", authCfg.Type)
}

View File

@ -2,41 +2,37 @@
package sink
import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"logwisp/src/internal/auth"
"logwisp/src/internal/config"
"logwisp/src/internal/core"
"logwisp/src/internal/format"
tlspkg "logwisp/src/internal/tls"
"github.com/lixenwraith/log"
)
// Forwards log entries to a remote TCP endpoint
type TCPClientSink struct {
input chan core.LogEntry
config TCPClientConfig
conn net.Conn
connMu sync.RWMutex
done chan struct{}
wg sync.WaitGroup
startTime time.Time
logger *log.Logger
formatter format.Formatter
// TLS support
tlsManager *tlspkg.Manager
tlsConfig *tls.Config
input chan core.LogEntry
config TCPClientConfig
conn net.Conn
connMu sync.RWMutex
done chan struct{}
wg sync.WaitGroup
startTime time.Time
logger *log.Logger
formatter format.Formatter
authenticator *auth.Authenticator
// Reconnection state
reconnecting atomic.Bool
@ -60,6 +56,10 @@ type TCPClientConfig struct {
ReadTimeout time.Duration
KeepAlive time.Duration
// Security
Username string
Password string
// Reconnection settings
ReconnectDelay time.Duration
MaxReconnectDelay time.Duration
@ -120,27 +120,11 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 {
cfg.ReconnectBackoff = backoff
}
// Extract TLS config
if tc, ok := options["tls"].(map[string]any); ok {
cfg.TLS = &config.TLSConfig{}
cfg.TLS.Enabled, _ = tc["enabled"].(bool)
if certFile, ok := tc["cert_file"].(string); ok {
cfg.TLS.CertFile = certFile
}
if keyFile, ok := tc["key_file"].(string); ok {
cfg.TLS.KeyFile = keyFile
}
cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool)
if caFile, ok := tc["client_ca_file"].(string); ok {
cfg.TLS.ClientCAFile = caFile
}
if insecure, ok := tc["insecure_skip_verify"].(bool); ok {
cfg.TLS.InsecureSkipVerify = insecure
}
if caFile, ok := tc["ca_file"].(string); ok {
cfg.TLS.CAFile = caFile
}
if username, ok := options["username"].(string); ok {
cfg.Username = username
}
if password, ok := options["password"].(string); ok {
cfg.Password = password
}
t := &TCPClientSink{
@ -154,62 +138,6 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
t.lastProcessed.Store(time.Time{})
t.connectionUptime.Store(time.Duration(0))
// Initialize TLS manager if TLS is configured
if cfg.TLS != nil && cfg.TLS.Enabled {
// Build custom TLS config for client
t.tlsConfig = &tls.Config{
InsecureSkipVerify: cfg.TLS.InsecureSkipVerify,
}
// Extract server name from address for SNI
host, _, err := net.SplitHostPort(cfg.Address)
if err != nil {
return nil, fmt.Errorf("failed to parse address for SNI: %w", err)
}
t.tlsConfig.ServerName = host
// Load custom CA for server verification
if cfg.TLS.CAFile != "" {
caCert, err := os.ReadFile(cfg.TLS.CAFile)
if err != nil {
return nil, fmt.Errorf("failed to read CA file '%s': %w", cfg.TLS.CAFile, err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("failed to parse CA certificate from '%s'", cfg.TLS.CAFile)
}
t.tlsConfig.RootCAs = caCertPool
logger.Debug("msg", "Custom CA loaded for server verification",
"component", "tcp_client_sink",
"ca_file", cfg.TLS.CAFile)
}
// Load client certificate for mTLS
if cfg.TLS.CertFile != "" && cfg.TLS.KeyFile != "" {
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertFile, cfg.TLS.KeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate: %w", err)
}
t.tlsConfig.Certificates = []tls.Certificate{cert}
logger.Info("msg", "Client certificate loaded for mTLS",
"component", "tcp_client_sink",
"cert_file", cfg.TLS.CertFile)
}
// Set minimum TLS version if configured
if cfg.TLS.MinVersion != "" {
t.tlsConfig.MinVersion = parseTLSVersion(cfg.TLS.MinVersion, tls.VersionTLS12)
} else {
t.tlsConfig.MinVersion = tls.VersionTLS12 // Default minimum
}
logger.Info("msg", "TLS enabled for TCP client",
"component", "tcp_client_sink",
"address", cfg.Address,
"server_name", host,
"insecure", cfg.TLS.InsecureSkipVerify,
"mtls", cfg.TLS.CertFile != "")
}
return t, nil
}
@ -376,33 +304,44 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
}
// Wrap with TLS if configured
if t.tlsConfig != nil {
t.logger.Debug("msg", "Initiating TLS handshake",
"component", "tcp_client_sink",
"address", t.config.Address)
tlsConn := tls.Client(conn, t.tlsConfig)
// Perform handshake with timeout
handshakeCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := tlsConn.HandshakeContext(handshakeCtx); err != nil {
// Handle authentication if credentials configured
if t.config.Username != "" && t.config.Password != "" {
// Read auth challenge
reader := bufio.NewReader(conn)
challenge, err := reader.ReadString('\n')
if err != nil {
conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
return nil, fmt.Errorf("failed to read auth challenge: %w", err)
}
// Log connection details
state := tlsConn.ConnectionState()
t.logger.Info("msg", "TLS connection established",
"component", "tcp_client_sink",
"address", t.config.Address,
"tls_version", tlsVersionString(state.Version),
"cipher_suite", tls.CipherSuiteName(state.CipherSuite),
"server_name", state.ServerName)
if strings.TrimSpace(challenge) == "AUTH_REQUIRED" {
// Send credentials
creds := t.config.Username + ":" + t.config.Password
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
authCmd := fmt.Sprintf("AUTH basic %s\n", encodedCreds)
return tlsConn, nil
if _, err := conn.Write([]byte(authCmd)); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to send auth: %w", err)
}
// Read response
response, err := reader.ReadString('\n')
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to read auth response: %w", err)
}
if strings.TrimSpace(response) != "AUTH_OK" {
conn.Close()
return nil, fmt.Errorf("authentication failed: %s", response)
}
t.logger.Debug("msg", "TCP authentication successful",
"component", "tcp_client_sink",
"address", t.config.Address,
"username", t.config.Username)
}
}
return conn, nil
@ -504,34 +443,8 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
return nil
}
// Returns human-readable TLS version
func tlsVersionString(version uint16) string {
switch version {
case tls.VersionTLS10:
return "TLS1.0"
case tls.VersionTLS11:
return "TLS1.1"
case tls.VersionTLS12:
return "TLS1.2"
case tls.VersionTLS13:
return "TLS1.3"
default:
return fmt.Sprintf("0x%04x", version)
}
}
// Converts string to TLS version constant
func parseTLSVersion(version string, defaultVersion uint16) uint16 {
switch strings.ToUpper(version) {
case "TLS1.0", "TLS10":
return tls.VersionTLS10
case "TLS1.1", "TLS11":
return tls.VersionTLS11
case "TLS1.2", "TLS12":
return tls.VersionTLS12
case "TLS1.3", "TLS13":
return tls.VersionTLS13
default:
return defaultVersion
}
// Not applicable, Clients authenticate to remote servers using Username/Password in config
func (h *TCPClientSink) SetAuth(authCfg *config.AuthConfig) {
// No-op: client sinks don't validate incoming connections
// They authenticate to remote servers using Username/Password fields
}

View File

@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"
"logwisp/src/internal/config"
"logwisp/src/internal/core"
"github.com/lixenwraith/log"
@ -286,4 +287,8 @@ func globToRegex(glob string) string {
regex = strings.ReplaceAll(regex, `\*`, `.*`)
regex = strings.ReplaceAll(regex, `\?`, `.`)
return "^" + regex + "$"
}
func (ds *DirectorySource) SetAuth(auth *config.AuthConfig) {
// Authentication does not apply to directory source
}

View File

@ -4,6 +4,7 @@ package source
import (
"encoding/json"
"fmt"
"logwisp/src/internal/auth"
"net"
"sync"
"sync/atomic"
@ -20,21 +21,31 @@ import (
// Receives log entries via HTTP POST requests
type HTTPSource struct {
host string
port int64
path string
bufferSize int64
// Config
host string
port int64
path string
bufferSize int64
maxRequestBodySize int64
// Application
server *fasthttp.Server
subscribers []chan core.LogEntry
mu sync.RWMutex
done chan struct{}
wg sync.WaitGroup
netLimiter *limit.NetLimiter
logger *log.Logger
// Add TLS support
tlsManager *tls.Manager
tlsConfig *config.TLSConfig
// Runtime
mu sync.RWMutex
done chan struct{}
wg sync.WaitGroup
// Security
authenticator *auth.Authenticator
authConfig *config.AuthConfig
authFailures atomic.Uint64
authSuccesses atomic.Uint64
tlsManager *tls.Manager
tlsConfig *config.TLSConfig
// Statistics
totalEntries atomic.Uint64
@ -66,42 +77,54 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err
bufferSize = bufSize
}
maxRequestBodySize := int64(10 * 1024 * 1024) // fasthttp default 10MB
if maxBodySize, ok := options["max_body_size"].(int64); ok && maxBodySize > 0 && maxBodySize < maxRequestBodySize {
maxRequestBodySize = maxBodySize
}
h := &HTTPSource{
host: host,
port: port,
path: ingestPath,
bufferSize: bufferSize,
done: make(chan struct{}),
startTime: time.Now(),
logger: logger,
host: host,
port: port,
path: ingestPath,
bufferSize: bufferSize,
maxRequestBodySize: maxRequestBodySize,
done: make(chan struct{}),
startTime: time.Now(),
logger: logger,
}
h.lastEntryTime.Store(time.Time{})
// Initialize net limiter if configured
if rl, ok := options["net_limit"].(map[string]any); ok {
if enabled, _ := rl["enabled"].(bool); enabled {
if nl, ok := options["net_limit"].(map[string]any); ok {
if enabled, _ := nl["enabled"].(bool); enabled {
cfg := config.NetLimitConfig{
Enabled: true,
}
if rps, ok := rl["requests_per_second"].(float64); ok {
if rps, ok := nl["requests_per_second"].(float64); ok {
cfg.RequestsPerSecond = rps
}
if burst, ok := rl["burst_size"].(int64); ok {
if burst, ok := nl["burst_size"].(int64); ok {
cfg.BurstSize = burst
}
if limitBy, ok := rl["limit_by"].(string); ok {
cfg.LimitBy = limitBy
}
if respCode, ok := rl["response_code"].(int64); ok {
if respCode, ok := nl["response_code"].(int64); ok {
cfg.ResponseCode = respCode
}
if msg, ok := rl["response_message"].(string); ok {
if msg, ok := nl["response_message"].(string); ok {
cfg.ResponseMessage = msg
}
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
cfg.MaxConnectionsPerIP = maxPerIP
}
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
cfg.MaxConnectionsPerUser = maxPerUser
}
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
cfg.MaxConnectionsPerToken = maxPerToken
}
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
cfg.MaxConnectionsTotal = maxTotal
}
h.netLimiter = limit.NewNetLimiter(cfg, logger)
}
@ -157,10 +180,11 @@ func (h *HTTPSource) Subscribe() <-chan core.LogEntry {
func (h *HTTPSource) Start() error {
h.server = &fasthttp.Server{
Handler: h.requestHandler,
DisableKeepalive: false,
StreamRequestBody: true,
CloseOnShutdown: true,
Handler: h.requestHandler,
DisableKeepalive: false,
StreamRequestBody: true,
CloseOnShutdown: true,
MaxRequestBodySize: int(h.maxRequestBodySize),
}
// Use configured host and port
@ -259,19 +283,9 @@ func (h *HTTPSource) GetStats() SourceStats {
}
func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
// Only handle POST to the configured ingest path
if string(ctx.Method()) != "POST" || string(ctx.Path()) != h.path {
ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{
"error": "Not Found",
"hint": fmt.Sprintf("POST logs to %s", h.path),
})
return
}
// Extract and validate IP
remoteAddr := ctx.RemoteAddr().String()
// 1. IPv6 check (early reject)
ipStr, _, err := net.SplitHostPort(remoteAddr)
if err == nil {
if ip := net.ParseIP(ipStr); ip != nil && ip.To4() == nil {
@ -284,7 +298,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
}
}
// Check net limit
// 2. Net limit check (early reject)
if h.netLimiter != nil {
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
ctx.SetStatusCode(int(statusCode))
@ -297,7 +311,65 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
}
}
// Process the request body
// 3. Path check (only process ingest path)
path := string(ctx.Path())
if path != h.path {
ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{
"error": "Not Found",
"hint": fmt.Sprintf("POST logs to %s", h.path),
})
return
}
// 4. Method check (only accept POST)
if string(ctx.Method()) != "POST" {
ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
ctx.SetContentType("application/json")
ctx.Response.Header.Set("Allow", "POST")
json.NewEncoder(ctx).Encode(map[string]string{
"error": "Method not allowed",
"hint": "Use POST to submit logs",
})
return
}
// 5. Authentication check (if configured)
if h.authenticator != nil {
authHeader := string(ctx.Request.Header.Peek("Authorization"))
session, err := h.authenticator.AuthenticateHTTP(authHeader, remoteAddr)
if err != nil {
h.authFailures.Add(1)
h.logger.Warn("msg", "Authentication failed",
"component", "http_source",
"remote_addr", remoteAddr,
"error", err)
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
realm := h.authConfig.BasicAuth.Realm
if realm == "" {
realm = "Restricted"
}
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm))
} else if h.authConfig.Type == "bearer" {
ctx.Response.Header.Set("WWW-Authenticate", "Bearer")
}
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{
"error": "Unauthorized",
})
return
}
h.authSuccesses.Add(1)
h.logger.Debug("msg", "Request authenticated",
"component", "http_source",
"remote_addr", remoteAddr,
"username", session.Username)
}
// 6. Process request body
body := ctx.PostBody()
if len(body) == 0 {
ctx.SetStatusCode(fasthttp.StatusBadRequest)
@ -308,7 +380,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
return
}
// Parse the log entries
// 7. Parse log entries
entries, err := h.parseEntries(body)
if err != nil {
h.invalidEntries.Add(1)
@ -320,7 +392,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
return
}
// Publish entries
// 8. Publish entries to subscribers
accepted := 0
for _, entry := range entries {
if h.publish(entry) {
@ -328,7 +400,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
}
}
// Return success response
// 9. Return success response
ctx.SetStatusCode(fasthttp.StatusAccepted)
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]any{
@ -460,4 +532,25 @@ func splitLines(data []byte) [][]byte {
}
return lines
}
// Configure HTTP source auth
func (h *HTTPSource) SetAuth(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" {
return
}
h.authConfig = authCfg
authenticator, err := auth.New(authCfg, h.logger)
if err != nil {
h.logger.Error("msg", "Failed to initialize authenticator for HTTP source",
"component", "http_source",
"error", err)
return
}
h.authenticator = authenticator
h.logger.Info("msg", "Authentication configured for HTTP source",
"component", "http_source",
"auth_type", authCfg.Type)
}

View File

@ -4,22 +4,26 @@ package source
import (
"time"
"logwisp/src/internal/config"
"logwisp/src/internal/core"
)
// Represents an input data stream
type Source interface {
// Subscribe returns a channel that receives log entries
// Returns a channel that receives log entries
Subscribe() <-chan core.LogEntry
// Start begins reading from the source
// Begins reading from the source
Start() error
// Stop gracefully shuts down the source
// Gracefully shuts down the source
Stop()
// GetStats returns source statistics
// Returns source statistics
GetStats() SourceStats
// Configure authentication
SetAuth(auth *config.AuthConfig)
}
// Contains statistics about a source

View File

@ -7,6 +7,7 @@ import (
"sync/atomic"
"time"
"logwisp/src/internal/config"
"logwisp/src/internal/core"
"github.com/lixenwraith/log"
@ -118,4 +119,8 @@ func (s *StdinSource) publish(entry core.LogEntry) {
"component", "stdin_source")
}
}
}
func (s *StdinSource) SetAuth(auth *config.AuthConfig) {
// Authentication does not apply to stdin source
}

View File

@ -5,9 +5,9 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@ -16,7 +16,6 @@ import (
"logwisp/src/internal/config"
"logwisp/src/internal/core"
"logwisp/src/internal/limit"
"logwisp/src/internal/tls"
"github.com/lixenwraith/log"
"github.com/lixenwraith/log/compat"
@ -24,28 +23,25 @@ import (
)
const (
maxClientBufferSize = 10 * 1024 * 1024 // 10MB max per client
maxLineLength = 1 * 1024 * 1024 // 1MB max per log line
maxEncryptedDataPerRead = 1 * 1024 * 1024 // 1MB max encrypted data per read
maxCumulativeEncrypted = 20 * 1024 * 1024 // 20MB total encrypted before processing
maxClientBufferSize = 10 * 1024 * 1024 // 10MB max per client
maxLineLength = 1 * 1024 * 1024 // 1MB max per log line
)
// Receives log entries via TCP connections
type TCPSource struct {
host string
port int64
bufferSize int64
server *tcpSourceServer
subscribers []chan core.LogEntry
mu sync.RWMutex
done chan struct{}
engine *gnet.Engine
engineMu sync.Mutex
wg sync.WaitGroup
netLimiter *limit.NetLimiter
tlsManager *tls.Manager
tlsConfig *config.TLSConfig
logger *log.Logger
host string
port int64
bufferSize int64
server *tcpSourceServer
subscribers []chan core.LogEntry
mu sync.RWMutex
done chan struct{}
engine *gnet.Engine
engineMu sync.Mutex
wg sync.WaitGroup
netLimiter *limit.NetLimiter
logger *log.Logger
authenticator *auth.Authenticator
// Statistics
totalEntries atomic.Uint64
@ -54,6 +50,8 @@ type TCPSource struct {
activeConns atomic.Int64
startTime time.Time
lastEntryTime atomic.Value // time.Time
authFailures atomic.Uint64
authSuccesses atomic.Uint64
}
// Creates a new TCP server source
@ -84,58 +82,35 @@ func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error
t.lastEntryTime.Store(time.Time{})
// Initialize net limiter if configured
if rl, ok := options["net_limit"].(map[string]any); ok {
if enabled, _ := rl["enabled"].(bool); enabled {
if nl, ok := options["net_limit"].(map[string]any); ok {
if enabled, _ := nl["enabled"].(bool); enabled {
cfg := config.NetLimitConfig{
Enabled: true,
}
if rps, ok := rl["requests_per_second"].(float64); ok {
if rps, ok := nl["requests_per_second"].(float64); ok {
cfg.RequestsPerSecond = rps
}
if burst, ok := rl["burst_size"].(int64); ok {
if burst, ok := nl["burst_size"].(int64); ok {
cfg.BurstSize = burst
}
if limitBy, ok := rl["limit_by"].(string); ok {
cfg.LimitBy = limitBy
}
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
cfg.MaxConnectionsPerIP = maxPerIP
}
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
cfg.MaxTotalConnections = maxTotal
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
cfg.MaxConnectionsPerUser = maxPerUser
}
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
cfg.MaxConnectionsPerToken = maxPerToken
}
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
cfg.MaxConnectionsTotal = maxTotal
}
t.netLimiter = limit.NewNetLimiter(cfg, logger)
}
}
// Extract TLS config and initialize TLS manager
if tc, ok := options["tls"].(map[string]any); ok {
t.tlsConfig = &config.TLSConfig{}
t.tlsConfig.Enabled, _ = tc["enabled"].(bool)
if certFile, ok := tc["cert_file"].(string); ok {
t.tlsConfig.CertFile = certFile
}
if keyFile, ok := tc["key_file"].(string); ok {
t.tlsConfig.KeyFile = keyFile
}
t.tlsConfig.ClientAuth, _ = tc["client_auth"].(bool)
if caFile, ok := tc["client_ca_file"].(string); ok {
t.tlsConfig.ClientCAFile = caFile
}
t.tlsConfig.VerifyClientCert, _ = tc["verify_client_cert"].(bool)
// Create TLS manager if enabled
if t.tlsConfig.Enabled {
tlsManager, err := tls.NewManager(t.tlsConfig, logger)
if err != nil {
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
}
t.tlsManager = tlsManager
}
}
return t, nil
}
@ -167,8 +142,7 @@ func (t *TCPSource) Start() error {
defer t.wg.Done()
t.logger.Info("msg", "TCP source server starting",
"component", "tcp_source",
"port", t.port,
"tls_enabled", t.tlsManager != nil)
"port", t.port)
err := gnet.Run(t.server, addr,
gnet.WithLogger(gnetLogger),
@ -283,9 +257,8 @@ type tcpClient struct {
conn gnet.Conn
buffer bytes.Buffer
authenticated bool
session *auth.Session
authTimeout time.Time
tlsBridge *tls.GNetTLSConn
session *auth.Session
maxBufferSeen int
cumulativeEncrypted int64
}
@ -339,22 +312,17 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
}
// Create client state
client := &tcpClient{conn: c}
// Initialize TLS bridge if enabled
if s.source.tlsManager != nil {
tlsConfig := s.source.tlsManager.GetTCPConfig()
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
client.tlsBridge.Handshake() // Start async handshake
s.source.logger.Debug("msg", "TLS handshake initiated",
"component", "tcp_source",
"remote_addr", remoteAddr)
client := &tcpClient{
conn: c,
authenticated: s.source.authenticator == nil,
}
if s.source.authenticator != nil {
client.authTimeout = time.Now().Add(30 * time.Second)
}
// Create client state
s.mu.Lock()
s.clients[c] = &tcpClient{conn: c}
s.clients[c] = client
s.mu.Unlock()
newCount := s.source.activeConns.Add(1)
@ -362,7 +330,12 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
"component", "tcp_source",
"remote_addr", remoteAddr,
"active_connections", newCount,
"tls_enabled", s.source.tlsManager != nil)
"requires_auth", s.source.authenticator != nil)
// Send auth challenge if required
if s.source.authenticator != nil {
return []byte("AUTH_REQUIRED\n"), gnet.None
}
return nil, gnet.None
}
@ -372,18 +345,9 @@ func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action {
// Remove client state
s.mu.Lock()
client := s.clients[c]
delete(s.clients, c)
s.mu.Unlock()
// Clean up TLS bridge if present
if client != nil && client.tlsBridge != nil {
client.tlsBridge.Close()
s.source.logger.Debug("msg", "TLS connection closed",
"component", "tcp_source",
"remote_addr", remoteAddr)
}
// Remove connection tracking
if s.source.netLimiter != nil {
s.source.netLimiter.RemoveConnection(remoteAddr)
@ -416,79 +380,64 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
return gnet.Close
}
// Check encrypted data size BEFORE processing through TLS
if len(data) > maxEncryptedDataPerRead {
s.source.logger.Warn("msg", "Encrypted data per read limit exceeded",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"data_size", len(data),
"limit", maxEncryptedDataPerRead)
s.source.invalidEntries.Add(1)
return gnet.Close
}
// Authentication phase
if !client.authenticated {
if time.Now().After(client.authTimeout) {
s.source.logger.Warn("msg", "Authentication timeout",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String())
return gnet.Close
}
// Track cumulative encrypted data to prevent slow accumulation
client.cumulativeEncrypted += int64(len(data))
if client.cumulativeEncrypted > maxCumulativeEncrypted {
s.source.logger.Warn("msg", "Cumulative encrypted data limit exceeded",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"total_encrypted", client.cumulativeEncrypted,
"limit", maxCumulativeEncrypted)
s.source.invalidEntries.Add(1)
return gnet.Close
}
client.buffer.Write(data)
// Process through TLS bridge if present
if client.tlsBridge != nil {
// Feed encrypted data into TLS engine
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
if errors.Is(err, tls.ErrTLSBackpressure) {
s.source.logger.Warn("msg", "TLS backpressure, closing slow client",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String())
} else {
s.source.logger.Error("msg", "TLS processing error",
// Look for auth line
if idx := bytes.IndexByte(client.buffer.Bytes(), '\n'); idx >= 0 {
line := client.buffer.Bytes()[:idx]
client.buffer.Next(idx + 1)
parts := strings.SplitN(string(line), " ", 3)
if len(parts) != 3 || parts[0] != "AUTH" {
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
return gnet.Close
}
session, err := s.source.authenticator.AuthenticateTCP(parts[1], parts[2], c.RemoteAddr().String())
if err != nil {
s.source.authFailures.Add(1)
s.source.logger.Warn("msg", "Authentication failed",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"error", err)
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
return gnet.Close
}
return gnet.Close
}
// Check if handshake is complete
if !client.tlsBridge.IsHandshakeDone() {
// Still handshaking, wait for more data
return gnet.None
}
s.source.authSuccesses.Add(1)
s.mu.Lock()
client.authenticated = true
client.session = session
s.mu.Unlock()
// Check handshake result
_, hsErr := client.tlsBridge.HandshakeComplete()
if hsErr != nil {
s.source.logger.Error("msg", "TLS handshake failed",
s.source.logger.Info("msg", "TCP client authenticated",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"error", hsErr)
return gnet.Close
}
"username", session.Username)
// Read decrypted plaintext
data = client.tlsBridge.Read()
if data == nil || len(data) == 0 {
// No plaintext available yet
return gnet.None
c.AsyncWrite([]byte("AUTH_OK\n"), nil)
client.buffer.Reset()
}
// Reset cumulative counter after successful decryption and processing
client.cumulativeEncrypted = 0
return gnet.None
}
// Check buffer size before appending
// Check if appending the new data would exceed the client buffer limit.
if client.buffer.Len()+len(data) > maxClientBufferSize {
s.source.logger.Warn("msg", "Client buffer limit exceeded",
s.source.logger.Warn("msg", "Client buffer limit exceeded, closing connection.",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"buffer_size", client.buffer.Len(),
"incoming_size", len(data))
"incoming_size", len(data),
"limit", maxClientBufferSize)
s.source.invalidEntries.Add(1)
return gnet.Close
}
@ -573,12 +522,22 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
return gnet.None
}
// noopLogger implements gnet's Logger interface but discards everything
// type noopLogger struct{}
// func (n noopLogger) Debugf(format string, args ...any) {}
// func (n noopLogger) Infof(format string, args ...any) {}
// func (n noopLogger) Warnf(format string, args ...any) {}
// func (n noopLogger) Errorf(format string, args ...any) {}
// func (n noopLogger) Fatalf(format string, args ...any) {}
// Configure TCP source auth
func (t *TCPSource) SetAuth(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" {
return
}
// Usage: gnet.Run(..., gnet.WithLogger(noopLogger{}), ...)
authenticator, err := auth.New(authCfg, t.logger)
if err != nil {
t.logger.Error("msg", "Failed to initialize authenticator for TCP source",
"component", "tcp_source",
"error", err)
return
}
t.authenticator = authenticator
t.logger.Info("msg", "Authentication configured for TCP source",
"component", "tcp_source",
"auth_type", authCfg.Type)
}

View File

@ -1,341 +0,0 @@
// FILE: src/internal/tls/gnet_bridge.go
package tls
import (
"crypto/tls"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/panjf2000/gnet/v2"
)
var (
ErrTLSBackpressure = errors.New("TLS processing backpressure")
ErrConnectionClosed = errors.New("connection closed")
ErrPlaintextBufferExceeded = errors.New("plaintext buffer size exceeded")
)
// Maximum plaintext buffer size to prevent memory exhaustion
const maxPlaintextBufferSize = 32 * 1024 * 1024 // 32MB
// Bridges gnet.Conn with crypto/tls via io.Pipe
type GNetTLSConn struct {
gnetConn gnet.Conn
tlsConn *tls.Conn
config *tls.Config
// Buffered channels for non-blocking operation
incomingCipher chan []byte // Network → TLS (encrypted)
outgoingCipher chan []byte // TLS → Network (encrypted)
// Handshake state
handshakeOnce sync.Once
handshakeDone chan struct{}
handshakeErr error
// Decrypted data buffer
plainBuf []byte
plainMu sync.Mutex
// Lifecycle
closed atomic.Bool
closeOnce sync.Once
wg sync.WaitGroup
// Error tracking
lastErr atomic.Value // error
logger interface{ Warn(args ...any) } // Minimal logger interface
}
// Creates a server-side TLS bridge
func NewServerConn(gnetConn gnet.Conn, config *tls.Config) *GNetTLSConn {
tc := &GNetTLSConn{
gnetConn: gnetConn,
config: config,
handshakeDone: make(chan struct{}),
// Buffered channels sized for throughput without blocking
incomingCipher: make(chan []byte, 128), // 128 packets buffer
outgoingCipher: make(chan []byte, 128),
plainBuf: make([]byte, 0, 65536), // 64KB initial capacity
}
// Create TLS conn with channel-based transport
rawConn := &channelConn{
incoming: tc.incomingCipher,
outgoing: tc.outgoingCipher,
localAddr: gnetConn.LocalAddr(),
remoteAddr: gnetConn.RemoteAddr(),
tc: tc,
}
tc.tlsConn = tls.Server(rawConn, config)
// Start pump goroutines
tc.wg.Add(2)
go tc.pumpCipherToNetwork()
go tc.pumpPlaintextFromTLS()
return tc
}
// Creates a client-side TLS bridge (similar changes)
func NewClientConn(gnetConn gnet.Conn, config *tls.Config, serverName string) *GNetTLSConn {
tc := &GNetTLSConn{
gnetConn: gnetConn,
config: config,
handshakeDone: make(chan struct{}),
incomingCipher: make(chan []byte, 128),
outgoingCipher: make(chan []byte, 128),
plainBuf: make([]byte, 0, 65536),
}
if config.ServerName == "" {
config = config.Clone()
config.ServerName = serverName
}
rawConn := &channelConn{
incoming: tc.incomingCipher,
outgoing: tc.outgoingCipher,
localAddr: gnetConn.LocalAddr(),
remoteAddr: gnetConn.RemoteAddr(),
tc: tc,
}
tc.tlsConn = tls.Client(rawConn, config)
tc.wg.Add(2)
go tc.pumpCipherToNetwork()
go tc.pumpPlaintextFromTLS()
return tc
}
// Feeds encrypted data from network into TLS engine (non-blocking)
func (tc *GNetTLSConn) ProcessIncoming(encryptedData []byte) error {
if tc.closed.Load() {
return ErrConnectionClosed
}
// Non-blocking send with backpressure detection
select {
case tc.incomingCipher <- encryptedData:
return nil
default:
// Channel full - TLS processing can't keep up
// Drop connection under backpressure vs blocking event loop
if tc.logger != nil {
tc.logger.Warn("msg", "TLS backpressure, dropping data",
"remote_addr", tc.gnetConn.RemoteAddr())
}
return ErrTLSBackpressure
}
}
// Sends TLS-encrypted data to network
func (tc *GNetTLSConn) pumpCipherToNetwork() {
defer tc.wg.Done()
for {
select {
case data, ok := <-tc.outgoingCipher:
if !ok {
return
}
// Send to network
if err := tc.gnetConn.AsyncWrite(data, nil); err != nil {
tc.lastErr.Store(err)
tc.Close()
return
}
case <-time.After(30 * time.Second):
// Keepalive/timeout check
if tc.closed.Load() {
return
}
}
}
}
// Reads decrypted data from TLS
func (tc *GNetTLSConn) pumpPlaintextFromTLS() {
defer tc.wg.Done()
buf := make([]byte, 32768) // 32KB read buffer
for {
n, err := tc.tlsConn.Read(buf)
if n > 0 {
tc.plainMu.Lock()
// Check buffer size limit before appending to prevent memory exhaustion
if len(tc.plainBuf)+n > maxPlaintextBufferSize {
tc.plainMu.Unlock()
// Log warning about buffer limit
if tc.logger != nil {
tc.logger.Warn("msg", "Plaintext buffer limit exceeded, closing connection",
"remote_addr", tc.gnetConn.RemoteAddr(),
"buffer_size", len(tc.plainBuf),
"incoming_size", n,
"limit", maxPlaintextBufferSize)
}
// Store error and close connection
tc.lastErr.Store(ErrPlaintextBufferExceeded)
tc.Close()
return
}
tc.plainBuf = append(tc.plainBuf, buf[:n]...)
tc.plainMu.Unlock()
}
if err != nil {
if err != io.EOF {
tc.lastErr.Store(err)
}
tc.Close()
return
}
}
}
// Returns available decrypted plaintext (non-blocking)
func (tc *GNetTLSConn) Read() []byte {
tc.plainMu.Lock()
defer tc.plainMu.Unlock()
if len(tc.plainBuf) == 0 {
return nil
}
// Atomic buffer swap under mutex protection to prevent race condition
data := tc.plainBuf
tc.plainBuf = make([]byte, 0, cap(tc.plainBuf))
return data
}
// Encrypts plaintext and queues for network transmission
func (tc *GNetTLSConn) Write(plaintext []byte) (int, error) {
if tc.closed.Load() {
return 0, ErrConnectionClosed
}
if !tc.IsHandshakeDone() {
return 0, errors.New("handshake not complete")
}
return tc.tlsConn.Write(plaintext)
}
// Initiates TLS handshake asynchronously
func (tc *GNetTLSConn) Handshake() {
tc.handshakeOnce.Do(func() {
go func() {
tc.handshakeErr = tc.tlsConn.Handshake()
close(tc.handshakeDone)
}()
})
}
// Checks if handshake is complete
func (tc *GNetTLSConn) IsHandshakeDone() bool {
select {
case <-tc.handshakeDone:
return true
default:
return false
}
}
// Waits for handshake completion
func (tc *GNetTLSConn) HandshakeComplete() (<-chan struct{}, error) {
<-tc.handshakeDone
return tc.handshakeDone, tc.handshakeErr
}
// Shuts down the bridge
func (tc *GNetTLSConn) Close() error {
tc.closeOnce.Do(func() {
tc.closed.Store(true)
// Close TLS connection
tc.tlsConn.Close()
// Close channels to stop pumps
close(tc.incomingCipher)
close(tc.outgoingCipher)
})
// Wait for pumps to finish
tc.wg.Wait()
return nil
}
// Returns TLS connection state
func (tc *GNetTLSConn) GetConnectionState() tls.ConnectionState {
return tc.tlsConn.ConnectionState()
}
// Returns last error
func (tc *GNetTLSConn) GetError() error {
if err, ok := tc.lastErr.Load().(error); ok {
return err
}
return nil
}
// Implements net.Conn over channels
type channelConn struct {
incoming <-chan []byte
outgoing chan<- []byte
localAddr net.Addr
remoteAddr net.Addr
tc *GNetTLSConn
readBuf []byte
}
func (c *channelConn) Read(b []byte) (int, error) {
// Use buffered read for efficiency
if len(c.readBuf) > 0 {
n := copy(b, c.readBuf)
c.readBuf = c.readBuf[n:]
return n, nil
}
// Wait for new data
select {
case data, ok := <-c.incoming:
if !ok {
return 0, io.EOF
}
n := copy(b, data)
if n < len(data) {
c.readBuf = data[n:] // Buffer remainder
}
return n, nil
case <-time.After(30 * time.Second):
return 0, errors.New("read timeout")
}
}
func (c *channelConn) Write(b []byte) (int, error) {
if c.tc.closed.Load() {
return 0, ErrConnectionClosed
}
// Make a copy since TLS may hold reference
data := make([]byte, len(b))
copy(data, b)
select {
case c.outgoing <- data:
return len(b), nil
case <-time.After(5 * time.Second):
return 0, errors.New("write timeout")
}
}
func (c *channelConn) Close() error { return nil }
func (c *channelConn) LocalAddr() net.Addr { return c.localAddr }
func (c *channelConn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *channelConn) SetDeadline(t time.Time) error { return nil }
func (c *channelConn) SetReadDeadline(t time.Time) error { return nil }
func (c *channelConn) SetWriteDeadline(t time.Time) error { return nil }

View File

@ -117,18 +117,6 @@ func (m *Manager) GetHTTPConfig() *tls.Config {
return cfg
}
// Returns TLS config for raw TCP connections
func (m *Manager) GetTCPConfig() *tls.Config {
if m == nil {
return nil
}
cfg := m.tlsConfig.Clone()
// No ALPN for raw TCP
cfg.NextProtos = nil
return cfg
}
// Validates a client certificate for mTLS
func (m *Manager) ValidateClientCert(rawCerts [][]byte) error {
if m == nil || !m.config.ClientAuth {
@ -174,6 +162,21 @@ func (m *Manager) ValidateClientCert(rawCerts [][]byte) error {
return nil
}
// Returns TLS statistics
func (m *Manager) GetStats() map[string]any {
if m == nil {
return map[string]any{"enabled": false}
}
return map[string]any{
"enabled": true,
"min_version": tlsVersionString(m.tlsConfig.MinVersion),
"max_version": tlsVersionString(m.tlsConfig.MaxVersion),
"client_auth": m.config.ClientAuth,
"cipher_suites": len(m.tlsConfig.CipherSuites),
}
}
func parseTLSVersion(version string, defaultVersion uint16) uint16 {
switch strings.ToUpper(version) {
case "TLS1.0", "TLS10":
@ -217,21 +220,6 @@ func parseCipherSuites(suites string) []uint16 {
return result
}
// Returns TLS statistics
func (m *Manager) GetStats() map[string]any {
if m == nil {
return map[string]any{"enabled": false}
}
return map[string]any{
"enabled": true,
"min_version": tlsVersionString(m.tlsConfig.MinVersion),
"max_version": tlsVersionString(m.tlsConfig.MaxVersion),
"client_auth": m.config.ClientAuth,
"cipher_suites": len(m.tlsConfig.CipherSuites),
}
}
func tlsVersionString(version uint16) string {
switch version {
case tls.VersionTLS10: