v0.5.0 removed tcp tls, basic auth hash changed to argon2, refactor
This commit is contained in:
@ -37,7 +37,6 @@ type = "directory"
|
||||
path = "./" # Directory to monitor
|
||||
pattern = "*.log" # Glob pattern
|
||||
check_interval_ms = 100 # Scan interval
|
||||
read_from_beginning = false # Start position
|
||||
|
||||
### Console Sources
|
||||
# [[pipelines.sources]]
|
||||
@ -74,11 +73,10 @@ read_from_beginning = false # Start position
|
||||
# ip_blacklist = [] # Blocked IPs/CIDRs
|
||||
# requests_per_second = 100.0 # Rate limit per client
|
||||
# burst_size = 100 # Burst capacity
|
||||
# limit_by = "ip" # ip|user|token|global
|
||||
# response_code = 429 # HTTP status when limited
|
||||
# response_message = "Rate limit exceeded"
|
||||
# max_connections_per_ip = 10 # Max concurrent per IP
|
||||
# max_total_connections = 1000 # Max total connections
|
||||
# max_connections_total = 1000 # Max total connections
|
||||
|
||||
### TCP Sources
|
||||
# [[pipelines.sources]]
|
||||
@ -88,30 +86,18 @@ read_from_beginning = false # Start position
|
||||
# host = "0.0.0.0" # Listen address
|
||||
# port = 9091 # Listen port
|
||||
|
||||
# [pipelines.sources.options.tls]
|
||||
# enabled = false # Enable TLS
|
||||
# cert_file = "" # TLS certificate
|
||||
# key_file = "" # TLS key
|
||||
# client_auth = false # Require client certs
|
||||
# client_ca_file = "" # Client CA cert
|
||||
# verify_client_cert = false # Verify client certs
|
||||
# insecure_skip_verify = false # Skip verification
|
||||
# ca_file = "" # Custom CA file
|
||||
# min_version = "TLS1.2" # Min TLS version
|
||||
# max_version = "TLS1.3" # Max TLS version
|
||||
# cipher_suites = "" # Comma-separated list
|
||||
|
||||
# [pipelines.sources.options.net_limit]
|
||||
# enabled = false # Enable rate limiting
|
||||
# ip_whitelist = [] # Allowed IPs/CIDRs
|
||||
# ip_blacklist = [] # Blocked IPs/CIDRs
|
||||
# requests_per_second = 100.0 # Rate limit per client
|
||||
# burst_size = 100 # Burst capacity
|
||||
# limit_by = "ip" # ip|user|token|global
|
||||
# response_code = 429 # Response code when limited
|
||||
# response_message = "Rate limit exceeded"
|
||||
# max_connections_per_ip = 10 # Max concurrent per IP
|
||||
# max_total_connections = 1000 # Max total connections
|
||||
# max_connections_per_user = 10 # Max concurrent per user
|
||||
# max_connections_per_token = 10 # Max concurrent per token
|
||||
# max_connections_total = 1000 # Max total connections
|
||||
|
||||
### Rate limiting
|
||||
# [pipelines.rate_limit]
|
||||
@ -126,6 +112,21 @@ read_from_beginning = false # Start position
|
||||
# logic = "or" # or|and
|
||||
# patterns = [] # Regex patterns
|
||||
|
||||
## Examples of filter patterns:
|
||||
## Include only error or fatal logs containing "database":
|
||||
## type = "include"
|
||||
## logic = "and"
|
||||
## patterns = ["(?i)(error|fatal)", "database"]
|
||||
##
|
||||
## Exclude debug logs from test environment:
|
||||
## type = "exclude"
|
||||
## logic = "or"
|
||||
## patterns = ["(?i)debug", "test-env"]
|
||||
##
|
||||
## Include only JSON formatted logs:
|
||||
## type = "include"
|
||||
## patterns = ["^\\{.*\\}$"]
|
||||
|
||||
### Format
|
||||
|
||||
### Raw formatter (default)
|
||||
@ -174,7 +175,6 @@ format = "comment" # comment|message
|
||||
# verify_client_cert = false # Verify client certs
|
||||
# insecure_skip_verify = false # Skip verification
|
||||
# ca_file = "" # Custom CA file
|
||||
# server_name = "" # Expected server name
|
||||
# min_version = "TLS1.2" # Min TLS version
|
||||
# max_version = "TLS1.3" # Max TLS version
|
||||
# cipher_suites = "" # Comma-separated list
|
||||
@ -185,11 +185,10 @@ format = "comment" # comment|message
|
||||
# ip_blacklist = [] # Blocked IPs/CIDRs
|
||||
# requests_per_second = 100.0 # Rate limit per client
|
||||
# burst_size = 100 # Burst capacity
|
||||
# limit_by = "ip" # ip|user|token|global
|
||||
# response_code = 429 # HTTP status when limited
|
||||
# response_message = "Rate limit exceeded"
|
||||
# max_connections_per_ip = 10 # Max concurrent per IP
|
||||
# max_total_connections = 1000 # Max total connections
|
||||
# max_connections_total = 1000 # Max total connections
|
||||
|
||||
### TCP Sinks
|
||||
# [[pipelines.sinks]]
|
||||
@ -207,31 +206,18 @@ format = "comment" # comment|message
|
||||
# include_stats = false # Include statistics
|
||||
# format = "comment" # comment|message
|
||||
|
||||
# [pipelines.sinks.options.tls]
|
||||
# enabled = false # Enable TLS
|
||||
# cert_file = "" # TLS certificate
|
||||
# key_file = "" # TLS key
|
||||
# client_auth = false # Require client certs
|
||||
# client_ca_file = "" # Client CA cert
|
||||
# verify_client_cert = false # Verify client certs
|
||||
# insecure_skip_verify = false # Skip verification
|
||||
# ca_file = "" # Custom CA file
|
||||
# server_name = "" # Expected server name
|
||||
# min_version = "TLS1.2" # Min TLS version
|
||||
# max_version = "TLS1.3" # Max TLS version
|
||||
# cipher_suites = "" # Comma-separated list
|
||||
|
||||
# [pipelines.sinks.options.net_limit]
|
||||
# enabled = false # Enable rate limiting
|
||||
# ip_whitelist = [] # Allowed IPs/CIDRs
|
||||
# ip_blacklist = [] # Blocked IPs/CIDRs
|
||||
# requests_per_second = 100.0 # Rate limit per client
|
||||
# burst_size = 100 # Burst capacity
|
||||
# limit_by = "ip" # ip|user|token|global
|
||||
# response_code = 429 # HTTP status when limited
|
||||
# response_message = "Rate limit exceeded"
|
||||
# max_connections_per_ip = 10 # Max concurrent per IP
|
||||
# max_total_connections = 1000 # Max total connections
|
||||
# max_connections_per_user = 10 # Max concurrent per user
|
||||
# max_connections_per_token = 10 # Max concurrent per token
|
||||
# max_connections_total = 1000 # Max total connections
|
||||
|
||||
### HTTP Client Sinks
|
||||
# [[pipelines.sinks]]
|
||||
@ -283,6 +269,7 @@ format = "comment" # comment|message
|
||||
# [pipelines.sinks.options]
|
||||
# directory = "" # Output dir (required)
|
||||
# name = "" # Base name (required)
|
||||
# buffer_size = 1000 # Input channel buffer size
|
||||
# max_size_mb = 100 # Rotation size
|
||||
# max_total_size_mb = 0 # Total limit (0=unlimited)
|
||||
# retention_hours = 0.0 # Retention (0=disabled)
|
||||
|
||||
6
go.mod
6
go.mod
@ -5,9 +5,9 @@ go 1.25.1
|
||||
require (
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3
|
||||
github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208
|
||||
github.com/panjf2000/gnet/v2 v2.9.3
|
||||
github.com/valyala/fasthttp v1.65.0
|
||||
github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e
|
||||
github.com/panjf2000/gnet/v2 v2.9.4
|
||||
github.com/valyala/fasthttp v1.66.0
|
||||
golang.org/x/crypto v0.42.0
|
||||
golang.org/x/term v0.35.0
|
||||
golang.org/x/time v0.13.0
|
||||
|
||||
12
go.sum
12
go.sum
@ -12,20 +12,20 @@ github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zt
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0=
|
||||
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0=
|
||||
github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208 h1:IB1O/HLv9VR/4mL1Tkjlr91lk+r8anP6bab7rYdS/oE=
|
||||
github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0=
|
||||
github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e h1:/XWCqFdSOiUf0/a5a63GHsvEdpglsYfn3qieNxTeyDc=
|
||||
github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0=
|
||||
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
|
||||
github.com/panjf2000/ants/v2 v2.11.3/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek=
|
||||
github.com/panjf2000/gnet/v2 v2.9.3 h1:auV3/A9Na3jiBDmYAAU00rPhFKnsAI+TnI1F7YUJMHQ=
|
||||
github.com/panjf2000/gnet/v2 v2.9.3/go.mod h1:WQTxDWYuQ/hz3eccH0FN32IVuvZ19HewEWx0l62fx7E=
|
||||
github.com/panjf2000/gnet/v2 v2.9.4 h1:XvPCcaFwO4XWg4IgSfZnNV4dfDy5g++HIEx7sH0ldHc=
|
||||
github.com/panjf2000/gnet/v2 v2.9.4/go.mod h1:WQTxDWYuQ/hz3eccH0FN32IVuvZ19HewEWx0l62fx7E=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8=
|
||||
github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4=
|
||||
github.com/valyala/fasthttp v1.66.0 h1:M87A0Z7EayeyNaV6pfO3tUTUiYO0dZfEJnRGXTVNuyU=
|
||||
github.com/valyala/fasthttp v1.66.0/go.mod h1:Y4eC+zwoocmXSVCB1JmhNbYtS7tZPRI2ztPB72EVObs=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
|
||||
@ -60,7 +60,7 @@ func initializeLogger(cfg *config.Config) error {
|
||||
// In quiet mode, disable ALL logging output
|
||||
logCfg.Level = 255 // A level that disables all output
|
||||
logCfg.DisableFile = true
|
||||
logCfg.EnableStdout = false
|
||||
logCfg.EnableConsole = false
|
||||
return logger.ApplyConfig(logCfg)
|
||||
}
|
||||
|
||||
@ -74,29 +74,24 @@ func initializeLogger(cfg *config.Config) error {
|
||||
// Configure based on output mode
|
||||
switch cfg.Logging.Output {
|
||||
case "none":
|
||||
logCfg.EnableStdout = false
|
||||
logCfg.EnableConsole = false
|
||||
case "stdout":
|
||||
logCfg.EnableStdout = true
|
||||
logCfg.StdoutTarget = "stdout"
|
||||
logCfg.EnableConsole = true
|
||||
logCfg.ConsoleTarget = "stdout"
|
||||
case "stderr":
|
||||
logCfg.EnableStdout = true
|
||||
logCfg.StdoutTarget = "stderr"
|
||||
logCfg.EnableConsole = true
|
||||
logCfg.ConsoleTarget = "stderr"
|
||||
case "split":
|
||||
logCfg.EnableStdout = true
|
||||
logCfg.StdoutTarget = "split"
|
||||
logCfg.EnableConsole = true
|
||||
logCfg.ConsoleTarget = "split"
|
||||
case "file":
|
||||
logCfg.DisableFile = false
|
||||
logCfg.EnableStdout = false
|
||||
configureFileLogging(logCfg, cfg)
|
||||
case "both":
|
||||
logCfg.DisableFile = false
|
||||
logCfg.EnableStdout = true
|
||||
logCfg.StdoutTarget = "stdout"
|
||||
logCfg.EnableConsole = false
|
||||
configureFileLogging(logCfg, cfg)
|
||||
case "all":
|
||||
logCfg.DisableFile = false
|
||||
logCfg.EnableStdout = true
|
||||
logCfg.StdoutTarget = "split"
|
||||
logCfg.EnableConsole = true
|
||||
logCfg.ConsoleTarget = "split"
|
||||
configureFileLogging(logCfg, cfg)
|
||||
default:
|
||||
return fmt.Errorf("invalid log output mode: %s", cfg.Logging.Output)
|
||||
|
||||
@ -127,13 +127,13 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
|
||||
"listen", fmt.Sprintf("%s:%d", host, port))
|
||||
|
||||
// Display net limit info if configured
|
||||
if rl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
|
||||
if enabled, ok := rl["enabled"].(bool); ok && enabled {
|
||||
if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
|
||||
if enabled, ok := nl["enabled"].(bool); ok && enabled {
|
||||
logger.Info("msg", "TCP net limiting enabled",
|
||||
"pipeline", cfg.Name,
|
||||
"sink_index", i,
|
||||
"requests_per_second", rl["requests_per_second"],
|
||||
"burst_size", rl["burst_size"])
|
||||
"requests_per_second", nl["requests_per_second"],
|
||||
"burst_size", nl["burst_size"])
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -162,14 +162,13 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
|
||||
"status_url", fmt.Sprintf("http://%s:%d%s", host, port, statusPath))
|
||||
|
||||
// Display net limit info if configured
|
||||
if rl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
|
||||
if enabled, ok := rl["enabled"].(bool); ok && enabled {
|
||||
if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
|
||||
if enabled, ok := nl["enabled"].(bool); ok && enabled {
|
||||
logger.Info("msg", "HTTP net limiting enabled",
|
||||
"pipeline", cfg.Name,
|
||||
"sink_index", i,
|
||||
"requests_per_second", rl["requests_per_second"],
|
||||
"burst_size", rl["burst_size"],
|
||||
"limit_by", rl["limit_by"])
|
||||
"requests_per_second", nl["requests_per_second"],
|
||||
"burst_size", nl["burst_size"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ package auth
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -16,7 +17,7 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/lixenwraith/log"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
@ -378,12 +379,14 @@ func (a *Authenticator) validateBasicAuth(username, password, remoteAddr string)
|
||||
a.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// Perform bcrypt anyway to prevent timing attacks
|
||||
bcrypt.CompareHashAndPassword([]byte("$2a$10$dummy.hash.to.prevent.timing.attacks"), []byte(password))
|
||||
// Perform argon2 anyway to prevent timing attacks
|
||||
dummySalt := make([]byte, 16)
|
||||
argon2.IDKey([]byte(password), dummySalt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(expectedHash), []byte(password)); err != nil {
|
||||
// Parse PHC format hash
|
||||
if !verifyArgon2idHash(password, expectedHash) {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
@ -400,6 +403,43 @@ func (a *Authenticator) validateBasicAuth(username, password, remoteAddr string)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Verify Argon2id hashes
|
||||
func verifyArgon2idHash(password, hash string) bool {
|
||||
// Parse PHC format: $argon2id$v=19$m=65536,t=3,p=4$salt$hash
|
||||
parts := strings.Split(hash, "$")
|
||||
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse version
|
||||
var version int
|
||||
fmt.Sscanf(parts[2], "v=%d", &version)
|
||||
if version != argon2.Version {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse parameters
|
||||
var memory, time, threads uint32
|
||||
fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
|
||||
|
||||
// Decode salt and hash
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Compute hash
|
||||
computedHash := argon2.IDKey([]byte(password), salt, time, memory, uint8(threads), uint32(len(expectedHash)))
|
||||
|
||||
// Constant time comparison
|
||||
return subtle.ConstantTimeCompare(computedHash, expectedHash) == 1
|
||||
}
|
||||
|
||||
func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) {
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return nil, fmt.Errorf("invalid bearer auth header")
|
||||
|
||||
@ -10,17 +10,24 @@ import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Handles auth credential generation
|
||||
// Argon2id parameters
|
||||
const (
|
||||
argon2Time = 3
|
||||
argon2Memory = 64 * 1024 // 64 MB
|
||||
argon2Threads = 4
|
||||
argon2SaltLen = 16
|
||||
argon2KeyLen = 32
|
||||
)
|
||||
|
||||
type GeneratorCommand struct {
|
||||
output io.Writer
|
||||
errOut io.Writer
|
||||
}
|
||||
|
||||
// Creates a new auth generator command handler
|
||||
func NewGeneratorCommand() *GeneratorCommand {
|
||||
return &GeneratorCommand{
|
||||
output: os.Stdout,
|
||||
@ -28,7 +35,6 @@ func NewGeneratorCommand() *GeneratorCommand {
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the auth generation command with provided arguments
|
||||
func (g *GeneratorCommand) Execute(args []string) error {
|
||||
cmd := flag.NewFlagSet("auth", flag.ContinueOnError)
|
||||
cmd.SetOutput(g.errOut)
|
||||
@ -36,7 +42,6 @@ func (g *GeneratorCommand) Execute(args []string) error {
|
||||
var (
|
||||
username = cmd.String("u", "", "Username for basic auth")
|
||||
password = cmd.String("p", "", "Password to hash (will prompt if not provided)")
|
||||
cost = cmd.Int("c", 10, "Bcrypt cost (10-31)")
|
||||
genToken = cmd.Bool("t", false, "Generate random bearer token")
|
||||
tokenLen = cmd.Int("l", 32, "Token length in bytes")
|
||||
)
|
||||
@ -45,7 +50,7 @@ func (g *GeneratorCommand) Execute(args []string) error {
|
||||
fmt.Fprintln(g.errOut, "Generate authentication credentials for LogWisp")
|
||||
fmt.Fprintln(g.errOut, "\nUsage: logwisp auth [options]")
|
||||
fmt.Fprintln(g.errOut, "\nExamples:")
|
||||
fmt.Fprintln(g.errOut, " # Generate bcrypt hash for user")
|
||||
fmt.Fprintln(g.errOut, " # Generate Argon2id hash for user")
|
||||
fmt.Fprintln(g.errOut, " logwisp auth -u admin")
|
||||
fmt.Fprintln(g.errOut, " ")
|
||||
fmt.Fprintln(g.errOut, " # Generate 64-byte bearer token")
|
||||
@ -58,26 +63,19 @@ func (g *GeneratorCommand) Execute(args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Token generation mode
|
||||
if *genToken {
|
||||
return g.generateToken(*tokenLen)
|
||||
}
|
||||
|
||||
// Password hash generation mode
|
||||
if *username == "" {
|
||||
cmd.Usage()
|
||||
return fmt.Errorf("username required for password hash generation")
|
||||
}
|
||||
|
||||
return g.generatePasswordHash(*username, *password, *cost)
|
||||
return g.generatePasswordHash(*username, *password)
|
||||
}
|
||||
|
||||
func (g *GeneratorCommand) generatePasswordHash(username, password string, cost int) error {
|
||||
// Validate cost
|
||||
if cost < 10 || cost > 31 {
|
||||
return fmt.Errorf("bcrypt cost must be between 10 and 31")
|
||||
}
|
||||
|
||||
func (g *GeneratorCommand) generatePasswordHash(username, password string) error {
|
||||
// Get password if not provided
|
||||
if password == "" {
|
||||
pass1 := g.promptPassword("Enter password: ")
|
||||
@ -88,20 +86,29 @@ func (g *GeneratorCommand) generatePasswordHash(username, password string, cost
|
||||
password = pass1
|
||||
}
|
||||
|
||||
// Generate hash
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), cost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate hash: %w", err)
|
||||
// Generate salt
|
||||
salt := make([]byte, argon2SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Generate Argon2id hash
|
||||
hash := argon2.IDKey([]byte(password), salt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)
|
||||
|
||||
// Encode in PHC format: $argon2id$v=19$m=65536,t=3,p=4$salt$hash
|
||||
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
||||
hashB64 := base64.RawStdEncoding.EncodeToString(hash)
|
||||
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, argon2Memory, argon2Time, argon2Threads, saltB64, hashB64)
|
||||
|
||||
// Output configuration snippets
|
||||
fmt.Fprintln(g.output, "\n# TOML Configuration (add to logwisp.toml):")
|
||||
fmt.Fprintln(g.output, "[[pipelines.auth.basic_auth.users]]")
|
||||
fmt.Fprintf(g.output, "username = %q\n", username)
|
||||
fmt.Fprintf(g.output, "password_hash = %q\n\n", string(hash))
|
||||
fmt.Fprintf(g.output, "password_hash = %q\n\n", phcHash)
|
||||
|
||||
fmt.Fprintln(g.output, "# Users File Format (for external auth file):")
|
||||
fmt.Fprintf(g.output, "%s:%s\n", username, hash)
|
||||
fmt.Fprintf(g.output, "%s:%s\n", username, phcHash)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -119,11 +126,9 @@ func (g *GeneratorCommand) generateToken(length int) error {
|
||||
return fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// Generate both encodings
|
||||
b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token)
|
||||
hex := fmt.Sprintf("%x", token)
|
||||
|
||||
// Output configuration
|
||||
fmt.Fprintln(g.output, "\n# TOML Configuration (add to logwisp.toml):")
|
||||
fmt.Fprintf(g.output, "tokens = [%q]\n\n", b64)
|
||||
|
||||
@ -139,7 +144,6 @@ func (g *GeneratorCommand) promptPassword(prompt string) string {
|
||||
password, err := term.ReadPassword(int(syscall.Stdin))
|
||||
fmt.Fprintln(g.errOut)
|
||||
if err != nil {
|
||||
// Fatal error - can't continue without password
|
||||
fmt.Fprintf(g.errOut, "Failed to read password: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@ -29,7 +29,7 @@ type BasicAuthConfig struct {
|
||||
|
||||
type BasicAuthUser struct {
|
||||
Username string `toml:"username"`
|
||||
// Password hash (bcrypt)
|
||||
// Password hash (Argon2id)
|
||||
PasswordHash string `toml:"password_hash"`
|
||||
}
|
||||
|
||||
|
||||
@ -5,13 +5,13 @@ import "fmt"
|
||||
|
||||
// Represents logging configuration for LogWisp
|
||||
type LogConfig struct {
|
||||
// Output mode: "file", "stdout", "stderr", "both", "none"
|
||||
// Output mode: "file", "stdout", "stderr", "split", "all", "none"
|
||||
Output string `toml:"output"`
|
||||
|
||||
// Log level: "debug", "info", "warn", "error"
|
||||
Level string `toml:"level"`
|
||||
|
||||
// File output settings (when Output includes "file" or "both")
|
||||
// File output settings (when Output includes "file" or "all")
|
||||
File *LogFileConfig `toml:"file"`
|
||||
|
||||
// Console output settings
|
||||
@ -66,7 +66,7 @@ func DefaultLogConfig() *LogConfig {
|
||||
func validateLogConfig(cfg *LogConfig) error {
|
||||
validOutputs := map[string]bool{
|
||||
"file": true, "stdout": true, "stderr": true,
|
||||
"both": true, "all": true, "none": true,
|
||||
"split": true, "all": true, "none": true,
|
||||
}
|
||||
if !validOutputs[cfg.Output] {
|
||||
return fmt.Errorf("invalid log output mode: %s", cfg.Output)
|
||||
|
||||
@ -131,8 +131,8 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err
|
||||
}
|
||||
|
||||
// Validate net_limit
|
||||
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("HTTP source", pipelineName, sourceIndex, rl); err != nil {
|
||||
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("HTTP source", pipelineName, sourceIndex, nl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -161,15 +161,8 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err
|
||||
}
|
||||
|
||||
// Validate net_limit
|
||||
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("TCP source", pipelineName, sourceIndex, rl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate TLS
|
||||
if tls, ok := cfg.Options["tls"].(map[string]any); ok {
|
||||
if err := validateTLSOptions("TCP source", pipelineName, sourceIndex, tls); err != nil {
|
||||
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("TCP source", pipelineName, sourceIndex, nl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -196,7 +189,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
||||
pipelineName, sinkIndex)
|
||||
}
|
||||
|
||||
// Validate host if provided
|
||||
// Validate host
|
||||
if host, ok := cfg.Options["host"].(string); ok && host != "" {
|
||||
if net.ParseIP(host) == nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
|
||||
@ -219,7 +212,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
||||
}
|
||||
}
|
||||
|
||||
// Validate paths if provided
|
||||
// Validate paths
|
||||
if streamPath, ok := cfg.Options["stream_path"].(string); ok {
|
||||
if !strings.HasPrefix(streamPath, "/") {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: stream path must start with /: %s",
|
||||
@ -234,7 +227,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
||||
}
|
||||
}
|
||||
|
||||
// Validate heartbeat if present
|
||||
// Validate heartbeat
|
||||
if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
|
||||
if err := validateHeartbeatOptions("HTTP", pipelineName, sinkIndex, hb); err != nil {
|
||||
return err
|
||||
@ -248,9 +241,9 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
||||
}
|
||||
}
|
||||
|
||||
// Validate net limit if present
|
||||
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("HTTP", pipelineName, sinkIndex, rl); err != nil {
|
||||
// Validate net limit
|
||||
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("HTTP", pipelineName, sinkIndex, nl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -263,7 +256,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
||||
pipelineName, sinkIndex)
|
||||
}
|
||||
|
||||
// Validate host if provided
|
||||
// Validate host
|
||||
if host, ok := cfg.Options["host"].(string); ok && host != "" {
|
||||
if net.ParseIP(host) == nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
|
||||
@ -286,23 +279,16 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
||||
}
|
||||
}
|
||||
|
||||
// Validate heartbeat if present
|
||||
// Validate heartbeat
|
||||
if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
|
||||
if err := validateHeartbeatOptions("TCP", pipelineName, sinkIndex, hb); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate TLS if present
|
||||
if tls, ok := cfg.Options["tls"].(map[string]any); ok {
|
||||
if err := validateTLSOptions("TCP", pipelineName, sinkIndex, tls); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate net limit if present
|
||||
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("TCP", pipelineName, sinkIndex, rl); err != nil {
|
||||
// Validate net limit
|
||||
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("TCP", pipelineName, sinkIndex, nl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,9 +12,6 @@ type TCPConfig struct {
|
||||
Port int64 `toml:"port"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
|
||||
// TLS Configuration
|
||||
TLS *TLSConfig `toml:"tls"`
|
||||
|
||||
// Net limiting
|
||||
NetLimit *NetLimitConfig `toml:"net_limit"`
|
||||
|
||||
@ -63,16 +60,15 @@ type NetLimitConfig struct {
|
||||
// Burst size (token bucket)
|
||||
BurstSize int64 `toml:"burst_size"`
|
||||
|
||||
// Net limit by: "ip", "user", "token", "global"
|
||||
LimitBy string `toml:"limit_by"`
|
||||
|
||||
// Response when net limited
|
||||
ResponseCode int64 `toml:"response_code"` // Default: 429
|
||||
ResponseMessage string `toml:"response_message"` // Default: "Net limit exceeded"
|
||||
|
||||
// Connection limits
|
||||
MaxConnectionsPerIP int64 `toml:"max_connections_per_ip"`
|
||||
MaxTotalConnections int64 `toml:"max_total_connections"`
|
||||
MaxConnectionsPerUser int64 `toml:"max_connections_per_user"`
|
||||
MaxConnectionsPerToken int64 `toml:"max_connections_per_token"`
|
||||
MaxConnectionsTotal int64 `toml:"max_connections_total"`
|
||||
}
|
||||
|
||||
func validateHeartbeatOptions(serverType, pipelineName string, sinkIndex int, hb map[string]any) error {
|
||||
@ -93,13 +89,13 @@ func validateHeartbeatOptions(serverType, pipelineName string, sinkIndex int, hb
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl map[string]any) error {
|
||||
if enabled, ok := rl["enabled"].(bool); !ok || !enabled {
|
||||
func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, nl map[string]any) error {
|
||||
if enabled, ok := nl["enabled"].(bool); !ok || !enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate IP lists if present
|
||||
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
||||
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
|
||||
for i, entry := range ipWhitelist {
|
||||
entryStr, ok := entry.(string)
|
||||
if !ok {
|
||||
@ -112,7 +108,7 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
|
||||
}
|
||||
}
|
||||
|
||||
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
||||
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
|
||||
for i, entry := range ipBlacklist {
|
||||
entryStr, ok := entry.(string)
|
||||
if !ok {
|
||||
@ -126,30 +122,21 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
|
||||
}
|
||||
|
||||
// Validate requests per second
|
||||
rps, ok := rl["requests_per_second"].(float64)
|
||||
rps, ok := nl["requests_per_second"].(float64)
|
||||
if !ok || rps <= 0 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: requests_per_second must be positive",
|
||||
pipelineName, sinkIndex, serverType)
|
||||
}
|
||||
|
||||
// Validate burst size
|
||||
burst, ok := rl["burst_size"].(int64)
|
||||
burst, ok := nl["burst_size"].(int64)
|
||||
if !ok || burst < 1 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: burst_size must be at least 1",
|
||||
pipelineName, sinkIndex, serverType)
|
||||
}
|
||||
|
||||
// Validate limit_by
|
||||
if limitBy, ok := rl["limit_by"].(string); ok && limitBy != "" {
|
||||
validLimitBy := map[string]bool{"ip": true, "global": true}
|
||||
if !validLimitBy[limitBy] {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: invalid limit_by value: %s (must be 'ip' or 'global')",
|
||||
pipelineName, sinkIndex, serverType, limitBy)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate response code
|
||||
if respCode, ok := rl["response_code"].(int64); ok {
|
||||
if respCode, ok := nl["response_code"].(int64); ok {
|
||||
if respCode > 0 && (respCode < 400 || respCode >= 600) {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: response_code must be 4xx or 5xx: %d",
|
||||
pipelineName, sinkIndex, serverType, respCode)
|
||||
@ -157,14 +144,25 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
|
||||
}
|
||||
|
||||
// Validate connection limits
|
||||
maxPerIP, perIPOk := rl["max_connections_per_ip"].(int64)
|
||||
maxTotal, totalOk := rl["max_total_connections"].(int64)
|
||||
maxPerIP, perIPOk := nl["max_connections_per_ip"].(int64)
|
||||
maxPerUser, perUserOk := nl["max_connections_per_user"].(int64)
|
||||
maxPerToken, perTokenOk := nl["max_connections_per_token"].(int64)
|
||||
maxTotal, totalOk := nl["max_connections_total"].(int64)
|
||||
|
||||
if perIPOk && totalOk && maxPerIP > 0 && maxTotal > 0 {
|
||||
if perIPOk && perUserOk && perTokenOk && totalOk &&
|
||||
maxPerIP > 0 && maxPerUser > 0 && maxPerToken > 0 && maxTotal > 0 {
|
||||
if maxPerIP > maxTotal {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_ip (%d) cannot exceed max_total_connections (%d)",
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_ip (%d) cannot exceed max_connections_total (%d)",
|
||||
pipelineName, sinkIndex, serverType, maxPerIP, maxTotal)
|
||||
}
|
||||
if maxPerUser > maxTotal {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_user (%d) cannot exceed max_connections_total (%d)",
|
||||
pipelineName, sinkIndex, serverType, maxPerUser, maxTotal)
|
||||
}
|
||||
if maxPerToken > maxTotal {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_token (%d) cannot exceed max_connections_total (%d)",
|
||||
pipelineName, sinkIndex, serverType, maxPerToken, maxTotal)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@ -50,6 +50,9 @@ type NetLimiter struct {
|
||||
|
||||
// Connection tracking
|
||||
ipConnections map[string]*connTracker
|
||||
userConnections map[string]*connTracker
|
||||
tokenConnections map[string]*connTracker
|
||||
totalConnections atomic.Int64
|
||||
connMu sync.RWMutex
|
||||
|
||||
// Statistics
|
||||
@ -108,6 +111,8 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
||||
ipBlacklist: make([]*net.IPNet, 0),
|
||||
ipLimiters: make(map[string]*ipLimiter),
|
||||
ipConnections: make(map[string]*connTracker),
|
||||
userConnections: make(map[string]*connTracker),
|
||||
tokenConnections: make(map[string]*connTracker),
|
||||
lastCleanup: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
@ -117,14 +122,6 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
||||
// Parse IP lists
|
||||
l.parseIPLists(cfg)
|
||||
|
||||
// Create global limiter if configured
|
||||
if cfg.Enabled && cfg.LimitBy == "global" {
|
||||
l.globalLimiter = NewTokenBucket(
|
||||
float64(cfg.BurstSize),
|
||||
cfg.RequestsPerSecond,
|
||||
)
|
||||
}
|
||||
|
||||
// Start cleanup goroutine only if rate limiting is enabled
|
||||
if cfg.Enabled {
|
||||
go l.cleanupLoop()
|
||||
@ -138,7 +135,10 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
||||
"blacklist_rules", len(l.ipBlacklist),
|
||||
"requests_per_second", cfg.RequestsPerSecond,
|
||||
"burst_size", cfg.BurstSize,
|
||||
"limit_by", cfg.LimitBy)
|
||||
"max_connections_per_ip", cfg.MaxConnectionsPerIP,
|
||||
"max_connections_per_user", cfg.MaxConnectionsPerUser,
|
||||
"max_connections_per_token", cfg.MaxConnectionsPerToken,
|
||||
"max_connections_total", cfg.MaxConnectionsTotal)
|
||||
|
||||
return l
|
||||
}
|
||||
@ -276,7 +276,7 @@ func (l *NetLimiter) Shutdown() {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks if an HTTP request should be allowed
|
||||
// Checks if an HTTP request should be allowed: IP access control + connection limits (IP only) + calls
|
||||
func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int64, message string) {
|
||||
if l == nil {
|
||||
return true, 0, ""
|
||||
@ -343,7 +343,7 @@ func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int6
|
||||
}
|
||||
|
||||
// Check rate limit
|
||||
if !l.checkLimit(ipStr) {
|
||||
if !l.checkIPLimit(ipStr) {
|
||||
l.blockedByRateLimit.Add(1)
|
||||
statusCode = l.config.ResponseCode
|
||||
if statusCode == 0 {
|
||||
@ -372,7 +372,7 @@ func (l *NetLimiter) updateConnectionActivity(ip string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks if a TCP connection should be allowed
|
||||
// Checks if a TCP connection should be allowed: IP access control + calls checkIPLimit()
|
||||
func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool {
|
||||
if l == nil {
|
||||
return true
|
||||
@ -412,7 +412,7 @@ func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool {
|
||||
|
||||
// Check rate limit
|
||||
ipStr := tcpAddr.IP.String()
|
||||
if !l.checkLimit(ipStr) {
|
||||
if !l.checkIPLimit(ipStr) {
|
||||
l.blockedByRateLimit.Add(1)
|
||||
return false
|
||||
}
|
||||
@ -531,17 +531,40 @@ func (l *NetLimiter) GetStats() map[string]any {
|
||||
return map[string]any{"enabled": false}
|
||||
}
|
||||
|
||||
// Get active rate limiters count
|
||||
l.ipMu.RLock()
|
||||
activeIPs := len(l.ipLimiters)
|
||||
l.ipMu.RUnlock()
|
||||
|
||||
// Get connection tracker counts and calculate total active connections
|
||||
l.connMu.RLock()
|
||||
totalConnections := 0
|
||||
ipConnTrackers := len(l.ipConnections)
|
||||
userConnTrackers := len(l.userConnections)
|
||||
tokenConnTrackers := len(l.tokenConnections)
|
||||
|
||||
// Calculate actual connection count by summing all IP connections
|
||||
// Potentially more accurate than totalConnections counter which might drift
|
||||
// TODO: test and refactor if they match
|
||||
actualIPConnections := 0
|
||||
for _, tracker := range l.ipConnections {
|
||||
totalConnections += int(tracker.connections.Load())
|
||||
actualIPConnections += int(tracker.connections.Load())
|
||||
}
|
||||
|
||||
actualUserConnections := 0
|
||||
for _, tracker := range l.userConnections {
|
||||
actualUserConnections += int(tracker.connections.Load())
|
||||
}
|
||||
|
||||
actualTokenConnections := 0
|
||||
for _, tracker := range l.tokenConnections {
|
||||
actualTokenConnections += int(tracker.connections.Load())
|
||||
}
|
||||
|
||||
// Use the counter for total (should match actualIPConnections in most cases)
|
||||
totalConns := l.totalConnections.Load()
|
||||
l.connMu.RUnlock()
|
||||
|
||||
// Calculate total blocked
|
||||
totalBlocked := l.blockedByBlacklist.Load() +
|
||||
l.blockedByWhitelist.Load() +
|
||||
l.blockedByRateLimit.Load() +
|
||||
@ -559,23 +582,39 @@ func (l *NetLimiter) GetStats() map[string]any {
|
||||
"conn_limit": l.blockedByConnLimit.Load(),
|
||||
"invalid_ip": l.blockedByInvalidIP.Load(),
|
||||
},
|
||||
"active_ips": activeIPs,
|
||||
"total_connections": totalConnections,
|
||||
"acl": map[string]int{
|
||||
"whitelist_rules": len(l.ipWhitelist),
|
||||
"blacklist_rules": len(l.ipBlacklist),
|
||||
},
|
||||
"rate_limit": map[string]any{
|
||||
"rate_limiting": map[string]any{
|
||||
"enabled": l.config.Enabled,
|
||||
"requests_per_second": l.config.RequestsPerSecond,
|
||||
"burst_size": l.config.BurstSize,
|
||||
"limit_by": l.config.LimitBy,
|
||||
"active_ip_limiters": activeIPs, // IPs being rate-limited
|
||||
},
|
||||
"access_control": map[string]any{
|
||||
"whitelist_rules": len(l.ipWhitelist),
|
||||
"blacklist_rules": len(l.ipBlacklist),
|
||||
},
|
||||
"connections": map[string]any{
|
||||
// Actual counts
|
||||
"total_active": totalConns, // Counter-based total
|
||||
"active_ip_connections": actualIPConnections, // Sum of all IP connections
|
||||
"active_user_connections": actualUserConnections, // Sum of all user connections
|
||||
"active_token_connections": actualTokenConnections, // Sum of all token connections
|
||||
|
||||
// Tracker counts (number of unique IPs/users/tokens being tracked)
|
||||
"tracked_ips": ipConnTrackers,
|
||||
"tracked_users": userConnTrackers,
|
||||
"tracked_tokens": tokenConnTrackers,
|
||||
|
||||
// Configuration limits (0 = disabled)
|
||||
"limit_per_ip": l.config.MaxConnectionsPerIP,
|
||||
"limit_per_user": l.config.MaxConnectionsPerUser,
|
||||
"limit_per_token": l.config.MaxConnectionsPerToken,
|
||||
"limit_total": l.config.MaxConnectionsTotal,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Performs the actual net limit check
|
||||
func (l *NetLimiter) checkLimit(ip string) bool {
|
||||
// Performs IP net limit check (req/sec)
|
||||
func (l *NetLimiter) checkIPLimit(ip string) bool {
|
||||
// Validate IP format
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil || !isIPv4(parsedIP) {
|
||||
@ -588,12 +627,7 @@ func (l *NetLimiter) checkLimit(ip string) bool {
|
||||
// Maybe run cleanup
|
||||
l.maybeCleanup()
|
||||
|
||||
switch l.config.LimitBy {
|
||||
case "global":
|
||||
return l.globalLimiter.Allow()
|
||||
|
||||
case "ip", "":
|
||||
// Default to per-IP limiting
|
||||
// IP limit
|
||||
l.ipMu.Lock()
|
||||
lim, exists := l.ipLimiters[ip]
|
||||
if !exists {
|
||||
@ -616,25 +650,13 @@ func (l *NetLimiter) checkLimit(ip string) bool {
|
||||
}
|
||||
l.ipMu.Unlock()
|
||||
|
||||
// Check connection limit if configured
|
||||
if l.config.MaxConnectionsPerIP > 0 {
|
||||
l.connMu.RLock()
|
||||
tracker, exists := l.ipConnections[ip]
|
||||
l.connMu.RUnlock()
|
||||
|
||||
if exists && tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
|
||||
return false
|
||||
}
|
||||
// Rate limit check
|
||||
allowed := lim.bucket.Allow()
|
||||
if !allowed {
|
||||
l.blockedByRateLimit.Add(1)
|
||||
}
|
||||
|
||||
return lim.bucket.Allow()
|
||||
|
||||
default:
|
||||
// Unknown limit_by value, allow by default
|
||||
l.logger.Warn("msg", "Unknown limit_by value",
|
||||
"limit_by", l.config.LimitBy)
|
||||
return true
|
||||
}
|
||||
return allowed
|
||||
}
|
||||
|
||||
// Runs cleanup if enough time has passed
|
||||
@ -691,25 +713,57 @@ func (l *NetLimiter) cleanup() {
|
||||
|
||||
// Clean up stale connection trackers
|
||||
l.connMu.Lock()
|
||||
connCleaned := 0
|
||||
|
||||
// Clean IP connections
|
||||
ipCleaned := 0
|
||||
for ip, tracker := range l.ipConnections {
|
||||
tracker.mu.Lock()
|
||||
lastSeen := tracker.lastSeen
|
||||
tracker.mu.Unlock()
|
||||
|
||||
// Remove if no activity for 5 minutes AND no active connections
|
||||
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
|
||||
delete(l.ipConnections, ip)
|
||||
connCleaned++
|
||||
ipCleaned++
|
||||
}
|
||||
}
|
||||
|
||||
// Clean user connections
|
||||
userCleaned := 0
|
||||
for user, tracker := range l.userConnections {
|
||||
tracker.mu.Lock()
|
||||
lastSeen := tracker.lastSeen
|
||||
tracker.mu.Unlock()
|
||||
|
||||
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
|
||||
delete(l.userConnections, user)
|
||||
userCleaned++
|
||||
}
|
||||
}
|
||||
|
||||
// Clean token connections
|
||||
tokenCleaned := 0
|
||||
for token, tracker := range l.tokenConnections {
|
||||
tracker.mu.Lock()
|
||||
lastSeen := tracker.lastSeen
|
||||
tracker.mu.Unlock()
|
||||
|
||||
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
|
||||
delete(l.tokenConnections, token)
|
||||
tokenCleaned++
|
||||
}
|
||||
}
|
||||
|
||||
l.connMu.Unlock()
|
||||
|
||||
if connCleaned > 0 {
|
||||
if ipCleaned > 0 || userCleaned > 0 || tokenCleaned > 0 {
|
||||
l.logger.Debug("msg", "Cleaned up stale connection trackers",
|
||||
"component", "netlimit",
|
||||
"cleaned", connCleaned,
|
||||
"remaining", len(l.ipConnections))
|
||||
"ip_cleaned", ipCleaned,
|
||||
"user_cleaned", userCleaned,
|
||||
"token_cleaned", tokenCleaned,
|
||||
"ip_remaining", len(l.ipConnections),
|
||||
"user_remaining", len(l.userConnections),
|
||||
"token_remaining", len(l.tokenConnections))
|
||||
}
|
||||
}
|
||||
|
||||
@ -731,3 +785,163 @@ func (l *NetLimiter) cleanupLoop() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tracks a new connection with optional user/token info: Connection limits (IP/user/token/total) for TCP only
|
||||
func (l *NetLimiter) TrackConnection(ip string, user string, token string) bool {
|
||||
if l == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
l.connMu.Lock()
|
||||
defer l.connMu.Unlock()
|
||||
|
||||
// Check total connections limit (0 = disabled)
|
||||
if l.config.MaxConnectionsTotal > 0 {
|
||||
currentTotal := l.totalConnections.Load()
|
||||
if currentTotal >= l.config.MaxConnectionsTotal {
|
||||
l.blockedByConnLimit.Add(1)
|
||||
l.logger.Debug("msg", "TCP connection blocked by total limit",
|
||||
"component", "netlimit",
|
||||
"current_total", currentTotal,
|
||||
"max_total", l.config.MaxConnectionsTotal)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check per-IP connection limit (0 = disabled)
|
||||
if l.config.MaxConnectionsPerIP > 0 && ip != "" {
|
||||
tracker, exists := l.ipConnections[ip]
|
||||
if !exists {
|
||||
tracker = &connTracker{lastSeen: time.Now()}
|
||||
l.ipConnections[ip] = tracker
|
||||
}
|
||||
if tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
|
||||
l.blockedByConnLimit.Add(1)
|
||||
l.logger.Debug("msg", "TCP connection blocked by IP limit",
|
||||
"component", "netlimit",
|
||||
"ip", ip,
|
||||
"current", tracker.connections.Load(),
|
||||
"max", l.config.MaxConnectionsPerIP)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check per-user connection limit (0 = disabled)
|
||||
if l.config.MaxConnectionsPerUser > 0 && user != "" {
|
||||
tracker, exists := l.userConnections[user]
|
||||
if !exists {
|
||||
tracker = &connTracker{lastSeen: time.Now()}
|
||||
l.userConnections[user] = tracker
|
||||
}
|
||||
if tracker.connections.Load() >= l.config.MaxConnectionsPerUser {
|
||||
l.blockedByConnLimit.Add(1)
|
||||
l.logger.Debug("msg", "TCP connection blocked by user limit",
|
||||
"component", "netlimit",
|
||||
"user", user,
|
||||
"current", tracker.connections.Load(),
|
||||
"max", l.config.MaxConnectionsPerUser)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check per-token connection limit (0 = disabled)
|
||||
if l.config.MaxConnectionsPerToken > 0 && token != "" {
|
||||
tracker, exists := l.tokenConnections[token]
|
||||
if !exists {
|
||||
tracker = &connTracker{lastSeen: time.Now()}
|
||||
l.tokenConnections[token] = tracker
|
||||
}
|
||||
if tracker.connections.Load() >= l.config.MaxConnectionsPerToken {
|
||||
l.blockedByConnLimit.Add(1)
|
||||
l.logger.Debug("msg", "TCP connection blocked by token limit",
|
||||
"component", "netlimit",
|
||||
"token", token,
|
||||
"current", tracker.connections.Load(),
|
||||
"max", l.config.MaxConnectionsPerToken)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// All checks passed, increment counters
|
||||
l.totalConnections.Add(1)
|
||||
|
||||
if ip != "" && l.config.MaxConnectionsPerIP > 0 {
|
||||
if tracker, exists := l.ipConnections[ip]; exists {
|
||||
tracker.connections.Add(1)
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if user != "" && l.config.MaxConnectionsPerUser > 0 {
|
||||
if tracker, exists := l.userConnections[user]; exists {
|
||||
tracker.connections.Add(1)
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if token != "" && l.config.MaxConnectionsPerToken > 0 {
|
||||
if tracker, exists := l.tokenConnections[token]; exists {
|
||||
tracker.connections.Add(1)
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Releases a tracked connection
|
||||
func (l *NetLimiter) ReleaseConnection(ip string, user string, token string) {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.connMu.Lock()
|
||||
defer l.connMu.Unlock()
|
||||
|
||||
// Decrement total
|
||||
if l.totalConnections.Load() > 0 {
|
||||
l.totalConnections.Add(-1)
|
||||
}
|
||||
|
||||
// Decrement IP counter
|
||||
if ip != "" {
|
||||
if tracker, exists := l.ipConnections[ip]; exists {
|
||||
if tracker.connections.Load() > 0 {
|
||||
tracker.connections.Add(-1)
|
||||
}
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement user counter
|
||||
if user != "" {
|
||||
if tracker, exists := l.userConnections[user]; exists {
|
||||
if tracker.connections.Load() > 0 {
|
||||
tracker.connections.Add(-1)
|
||||
}
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement token counter
|
||||
if token != "" {
|
||||
if tracker, exists := l.tokenConnections[token]; exists {
|
||||
if tracker.connections.Load() > 0 {
|
||||
tracker.connections.Add(-1)
|
||||
}
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -129,6 +129,11 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Configure authentication for sources that support it
|
||||
for _, sourceInst := range pipeline.Sources {
|
||||
sourceInst.SetAuth(cfg.Auth)
|
||||
}
|
||||
|
||||
// Start all sinks
|
||||
for i, sinkInst := range pipeline.Sinks {
|
||||
if err := sinkInst.Start(pipelineCtx); err != nil {
|
||||
@ -139,9 +144,7 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
|
||||
|
||||
// Configure authentication for sinks that support it
|
||||
for _, sinkInst := range pipeline.Sinks {
|
||||
if setter, ok := sinkInst.(sink.AuthSetter); ok {
|
||||
setter.SetAuthConfig(cfg.Auth)
|
||||
}
|
||||
sinkInst.SetAuth(cfg.Auth)
|
||||
}
|
||||
|
||||
// Wire sources to sinks through filters
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
|
||||
@ -121,7 +122,9 @@ func (s *StdoutSink) processLoop(ctx context.Context) {
|
||||
// Format and write
|
||||
formatted, err := s.formatter.Format(entry)
|
||||
if err != nil {
|
||||
s.logger.Error("msg", "Failed to format log entry for stdout", "error", err)
|
||||
s.logger.Error("msg", "Failed to format log entry for stdout",
|
||||
"component", "stdout_sink",
|
||||
"error", err)
|
||||
continue
|
||||
}
|
||||
s.output.Write(formatted)
|
||||
@ -234,7 +237,9 @@ func (s *StderrSink) processLoop(ctx context.Context) {
|
||||
// Format and write
|
||||
formatted, err := s.formatter.Format(entry)
|
||||
if err != nil {
|
||||
s.logger.Error("msg", "Failed to format log entry for stderr", "error", err)
|
||||
s.logger.Error("msg", "Failed to format log entry for stderr",
|
||||
"component", "stderr_sink",
|
||||
"error", err)
|
||||
continue
|
||||
}
|
||||
s.output.Write(formatted)
|
||||
@ -246,3 +251,11 @@ func (s *StderrSink) processLoop(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdoutSink) SetAuth(auth *config.AuthConfig) {
|
||||
// Authentication does not apply to stdout sink
|
||||
}
|
||||
|
||||
func (s *StderrSink) SetAuth(auth *config.AuthConfig) {
|
||||
// Authentication does not apply to stderr sink
|
||||
}
|
||||
@ -4,6 +4,7 @@ package sink
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"logwisp/src/internal/config"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@ -43,7 +44,7 @@ func NewFileSink(options map[string]any, logger *log.Logger, formatter format.Fo
|
||||
writerConfig := log.DefaultConfig()
|
||||
writerConfig.Directory = directory
|
||||
writerConfig.Name = name
|
||||
writerConfig.EnableStdout = false // File only
|
||||
writerConfig.EnableConsole = false // File only
|
||||
writerConfig.ShowTimestamp = false // We already have timestamps in entries
|
||||
writerConfig.ShowLevel = false // We already have levels in entries
|
||||
|
||||
@ -165,3 +166,7 @@ func (fs *FileSink) processLoop(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (fs *FileSink) SetAuth(auth *config.AuthConfig) {
|
||||
// Authentication does not apply to file sink
|
||||
}
|
||||
@ -142,31 +142,28 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
|
||||
}
|
||||
|
||||
// Extract net limit config
|
||||
if rl, ok := options["net_limit"].(map[string]any); ok {
|
||||
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||
cfg.NetLimit = &config.NetLimitConfig{}
|
||||
cfg.NetLimit.Enabled, _ = rl["enabled"].(bool)
|
||||
if rps, ok := rl["requests_per_second"].(float64); ok {
|
||||
cfg.NetLimit.Enabled, _ = nl["enabled"].(bool)
|
||||
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||
cfg.NetLimit.RequestsPerSecond = rps
|
||||
}
|
||||
if burst, ok := rl["burst_size"].(int64); ok {
|
||||
if burst, ok := nl["burst_size"].(int64); ok {
|
||||
cfg.NetLimit.BurstSize = burst
|
||||
}
|
||||
if limitBy, ok := rl["limit_by"].(string); ok {
|
||||
cfg.NetLimit.LimitBy = limitBy
|
||||
}
|
||||
if respCode, ok := rl["response_code"].(int64); ok {
|
||||
if respCode, ok := nl["response_code"].(int64); ok {
|
||||
cfg.NetLimit.ResponseCode = respCode
|
||||
}
|
||||
if msg, ok := rl["response_message"].(string); ok {
|
||||
if msg, ok := nl["response_message"].(string); ok {
|
||||
cfg.NetLimit.ResponseMessage = msg
|
||||
}
|
||||
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
|
||||
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
|
||||
}
|
||||
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
||||
cfg.NetLimit.MaxTotalConnections = maxTotal
|
||||
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsTotal = maxTotal
|
||||
}
|
||||
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
||||
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
|
||||
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
||||
for _, entry := range ipWhitelist {
|
||||
if str, ok := entry.(string); ok {
|
||||
@ -174,7 +171,7 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
|
||||
}
|
||||
}
|
||||
}
|
||||
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
||||
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
|
||||
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
||||
for _, entry := range ipBlacklist {
|
||||
if str, ok := entry.(string); ok {
|
||||
@ -806,8 +803,8 @@ func (h *HTTPSink) GetHost() string {
|
||||
return h.config.Host
|
||||
}
|
||||
|
||||
// Configures http sink authentication
|
||||
func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
||||
// Configures http sink auth
|
||||
func (h *HTTPSink) SetAuth(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type == "none" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -6,7 +6,10 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/config"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
@ -33,6 +36,7 @@ type HTTPClientSink struct {
|
||||
startTime time.Time
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
authenticator *auth.Authenticator
|
||||
|
||||
// Statistics
|
||||
totalProcessed atomic.Uint64
|
||||
@ -44,7 +48,9 @@ type HTTPClientSink struct {
|
||||
}
|
||||
|
||||
// Holds HTTP client sink configuration
|
||||
// TODO: missing toml tags
|
||||
type HTTPClientConfig struct {
|
||||
// Config
|
||||
URL string
|
||||
BufferSize int64
|
||||
BatchSize int64
|
||||
@ -57,6 +63,10 @@ type HTTPClientConfig struct {
|
||||
RetryDelay time.Duration
|
||||
RetryBackoff float64 // Multiplier for exponential backoff
|
||||
|
||||
// Security
|
||||
Username string
|
||||
Password string
|
||||
|
||||
// TLS configuration
|
||||
InsecureSkipVerify bool
|
||||
CAFile string
|
||||
@ -118,6 +128,12 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
|
||||
if insecure, ok := options["insecure_skip_verify"].(bool); ok {
|
||||
cfg.InsecureSkipVerify = insecure
|
||||
}
|
||||
if username, ok := options["username"].(string); ok {
|
||||
cfg.Username = username
|
||||
}
|
||||
if password, ok := options["password"].(string); ok {
|
||||
cfg.Password = password // TODO: change to Argon2 hashed password
|
||||
}
|
||||
|
||||
// Extract headers
|
||||
if headers, ok := options["headers"].(map[string]any); ok {
|
||||
@ -422,8 +438,16 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
|
||||
req.SetRequestURI(h.config.URL)
|
||||
req.Header.SetMethod("POST")
|
||||
req.Header.SetContentType("application/json")
|
||||
req.SetBody(body)
|
||||
|
||||
// Add Basic Auth header if credentials configured
|
||||
if h.config.Username != "" && h.config.Password != "" {
|
||||
creds := h.config.Username + ":" + h.config.Password
|
||||
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
|
||||
req.Header.Set("Authorization", "Basic "+encodedCreds)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
for k, v := range h.config.Headers {
|
||||
req.Header.Set(k, v)
|
||||
@ -495,3 +519,9 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
"last_error", lastErr)
|
||||
h.failedBatches.Add(1)
|
||||
}
|
||||
|
||||
// Not applicable, Clients authenticate to remote servers using Username/Password in config
|
||||
func (h *HTTPClientSink) SetAuth(authCfg *config.AuthConfig) {
|
||||
// No-op: client sinks don't validate incoming connections
|
||||
// They authenticate to remote servers using Username/Password fields
|
||||
}
|
||||
@ -9,19 +9,22 @@ import (
|
||||
"logwisp/src/internal/core"
|
||||
)
|
||||
|
||||
// Represents an output destination for log entries
|
||||
// Represents an output data stream
|
||||
type Sink interface {
|
||||
// Input returns the channel for sending log entries to this sink
|
||||
// Returns the channel for sending log entries to this sink
|
||||
Input() chan<- core.LogEntry
|
||||
|
||||
// Start begins processing log entries
|
||||
// Begins processing log entries
|
||||
Start(ctx context.Context) error
|
||||
|
||||
// Stop gracefully shuts down the sink
|
||||
// Gracefully shuts down the sink
|
||||
Stop()
|
||||
|
||||
// GetStats returns sink statistics
|
||||
// Returns sink statistics
|
||||
GetStats() SinkStats
|
||||
|
||||
// Configure authentication
|
||||
SetAuth(auth *config.AuthConfig)
|
||||
}
|
||||
|
||||
// Contains statistics about a sink
|
||||
@ -33,8 +36,3 @@ type SinkStats struct {
|
||||
LastProcessed time.Time
|
||||
Details map[string]any
|
||||
}
|
||||
|
||||
// Interface for sinks that can accept an AuthConfig
|
||||
type AuthSetter interface {
|
||||
SetAuthConfig(auth *config.AuthConfig)
|
||||
}
|
||||
@ -17,7 +17,6 @@ import (
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/tls"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/lixenwraith/log/compat"
|
||||
@ -26,6 +25,7 @@ import (
|
||||
|
||||
// Streams log entries via TCP
|
||||
type TCPSink struct {
|
||||
// C
|
||||
input chan core.LogEntry
|
||||
config TCPConfig
|
||||
server *tcpServer
|
||||
@ -38,11 +38,7 @@ type TCPSink struct {
|
||||
netLimiter *limit.NetLimiter
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
|
||||
// Security components
|
||||
authenticator *auth.Authenticator
|
||||
tlsManager *tls.Manager
|
||||
authConfig *config.AuthConfig
|
||||
|
||||
// Statistics
|
||||
totalProcessed atomic.Uint64
|
||||
@ -62,7 +58,6 @@ type TCPConfig struct {
|
||||
Port int64
|
||||
BufferSize int64
|
||||
Heartbeat *config.HeartbeatConfig
|
||||
TLS *config.TLSConfig
|
||||
NetLimit *config.NetLimitConfig
|
||||
}
|
||||
|
||||
@ -99,58 +94,35 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
|
||||
}
|
||||
}
|
||||
|
||||
// Extract TLS config
|
||||
if tc, ok := options["tls"].(map[string]any); ok {
|
||||
cfg.TLS = &config.TLSConfig{}
|
||||
cfg.TLS.Enabled, _ = tc["enabled"].(bool)
|
||||
if certFile, ok := tc["cert_file"].(string); ok {
|
||||
cfg.TLS.CertFile = certFile
|
||||
}
|
||||
if keyFile, ok := tc["key_file"].(string); ok {
|
||||
cfg.TLS.KeyFile = keyFile
|
||||
}
|
||||
cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool)
|
||||
if caFile, ok := tc["client_ca_file"].(string); ok {
|
||||
cfg.TLS.ClientCAFile = caFile
|
||||
}
|
||||
cfg.TLS.VerifyClientCert, _ = tc["verify_client_cert"].(bool)
|
||||
if minVer, ok := tc["min_version"].(string); ok {
|
||||
cfg.TLS.MinVersion = minVer
|
||||
}
|
||||
if maxVer, ok := tc["max_version"].(string); ok {
|
||||
cfg.TLS.MaxVersion = maxVer
|
||||
}
|
||||
if ciphers, ok := tc["cipher_suites"].(string); ok {
|
||||
cfg.TLS.CipherSuites = ciphers
|
||||
}
|
||||
}
|
||||
|
||||
// Extract net limit config
|
||||
if rl, ok := options["net_limit"].(map[string]any); ok {
|
||||
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||
cfg.NetLimit = &config.NetLimitConfig{}
|
||||
cfg.NetLimit.Enabled, _ = rl["enabled"].(bool)
|
||||
if rps, ok := rl["requests_per_second"].(float64); ok {
|
||||
cfg.NetLimit.Enabled, _ = nl["enabled"].(bool)
|
||||
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||
cfg.NetLimit.RequestsPerSecond = rps
|
||||
}
|
||||
if burst, ok := rl["burst_size"].(int64); ok {
|
||||
if burst, ok := nl["burst_size"].(int64); ok {
|
||||
cfg.NetLimit.BurstSize = burst
|
||||
}
|
||||
if limitBy, ok := rl["limit_by"].(string); ok {
|
||||
cfg.NetLimit.LimitBy = limitBy
|
||||
}
|
||||
if respCode, ok := rl["response_code"].(int64); ok {
|
||||
if respCode, ok := nl["response_code"].(int64); ok {
|
||||
cfg.NetLimit.ResponseCode = respCode
|
||||
}
|
||||
if msg, ok := rl["response_message"].(string); ok {
|
||||
if msg, ok := nl["response_message"].(string); ok {
|
||||
cfg.NetLimit.ResponseMessage = msg
|
||||
}
|
||||
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
|
||||
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
|
||||
}
|
||||
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
||||
cfg.NetLimit.MaxTotalConnections = maxTotal
|
||||
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsPerUser = maxPerUser
|
||||
}
|
||||
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
||||
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsPerToken = maxPerToken
|
||||
}
|
||||
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsTotal = maxTotal
|
||||
}
|
||||
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
|
||||
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
||||
for _, entry := range ipWhitelist {
|
||||
if str, ok := entry.(string); ok {
|
||||
@ -158,7 +130,7 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
|
||||
}
|
||||
}
|
||||
}
|
||||
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
||||
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
|
||||
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
||||
for _, entry := range ipBlacklist {
|
||||
if str, ok := entry.(string); ok {
|
||||
@ -290,18 +262,6 @@ func (t *TCPSink) GetStats() SinkStats {
|
||||
netLimitStats = t.netLimiter.GetStats()
|
||||
}
|
||||
|
||||
var authStats map[string]any
|
||||
if t.authenticator != nil {
|
||||
authStats = t.authenticator.GetStats()
|
||||
authStats["failures"] = t.authFailures.Load()
|
||||
authStats["successes"] = t.authSuccesses.Load()
|
||||
}
|
||||
|
||||
var tlsStats map[string]any
|
||||
if t.tlsManager != nil {
|
||||
tlsStats = t.tlsManager.GetStats()
|
||||
}
|
||||
|
||||
return SinkStats{
|
||||
Type: "tcp",
|
||||
TotalProcessed: t.totalProcessed.Load(),
|
||||
@ -312,8 +272,7 @@ func (t *TCPSink) GetStats() SinkStats {
|
||||
"port": t.config.Port,
|
||||
"buffer_size": t.config.BufferSize,
|
||||
"net_limit": netLimitStats,
|
||||
"auth": authStats,
|
||||
"tls": tlsStats,
|
||||
"auth": map[string]any{"enabled": false},
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -347,21 +306,31 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
"entry_source", entry.Source)
|
||||
continue
|
||||
}
|
||||
t.broadcastData(data)
|
||||
|
||||
// Broadcast only to authenticated clients
|
||||
t.server.mu.RLock()
|
||||
for conn, client := range t.server.clients {
|
||||
if client.authenticated {
|
||||
// Send through TLS bridge if present
|
||||
if client.tlsBridge != nil {
|
||||
if _, err := client.tlsBridge.Write(data); err != nil {
|
||||
// TLS write failed, connection likely dead
|
||||
t.logger.Debug("msg", "TLS write failed",
|
||||
case <-tickerChan:
|
||||
heartbeatEntry := t.createHeartbeatEntry()
|
||||
data, err := t.formatter.Format(heartbeatEntry)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "Failed to format heartbeat",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
t.broadcastData(data)
|
||||
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TCPSink) broadcastData(data []byte) {
|
||||
t.server.mu.RLock()
|
||||
defer t.server.mu.RUnlock()
|
||||
|
||||
for conn, client := range t.server.clients {
|
||||
if client.authenticated {
|
||||
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||
if err != nil {
|
||||
t.writeErrors.Add(1)
|
||||
@ -376,54 +345,6 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
t.server.mu.RUnlock()
|
||||
|
||||
case <-tickerChan:
|
||||
heartbeatEntry := t.createHeartbeatEntry()
|
||||
data, err := t.formatter.Format(heartbeatEntry)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "Failed to format heartbeat",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
t.server.mu.RLock()
|
||||
for conn, client := range t.server.clients {
|
||||
if client.authenticated {
|
||||
// Validate session is still active
|
||||
if t.authenticator != nil && client.session != nil {
|
||||
if !t.authenticator.ValidateSession(client.session.ID) {
|
||||
// Session expired, close connection
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
}
|
||||
if client.tlsBridge != nil {
|
||||
if _, err := client.tlsBridge.Write(data); err != nil {
|
||||
t.logger.Debug("msg", "TLS heartbeat write failed",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
conn.Close()
|
||||
}
|
||||
} else {
|
||||
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||
if err != nil {
|
||||
t.writeErrors.Add(1)
|
||||
t.handleWriteError(c, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
t.server.mu.RUnlock()
|
||||
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle write errors with threshold-based connection termination
|
||||
@ -490,10 +411,8 @@ type tcpClient struct {
|
||||
conn gnet.Conn
|
||||
buffer bytes.Buffer
|
||||
authenticated bool
|
||||
session *auth.Session
|
||||
authTimeout time.Time
|
||||
tlsBridge *tls.GNetTLSConn
|
||||
authTimeoutSet bool
|
||||
session *auth.Session
|
||||
}
|
||||
|
||||
// Handles gnet events with authentication
|
||||
@ -551,23 +470,11 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
// Create client state without auth timeout initially
|
||||
client := &tcpClient{
|
||||
conn: c,
|
||||
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
|
||||
authTimeoutSet: false, // Auth timeout not started yet
|
||||
authenticated: s.sink.authenticator == nil,
|
||||
}
|
||||
|
||||
// Initialize TLS bridge if enabled
|
||||
if s.sink.tlsManager != nil {
|
||||
tlsConfig := s.sink.tlsManager.GetTCPConfig()
|
||||
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
|
||||
client.tlsBridge.Handshake() // Start async handshake
|
||||
|
||||
s.sink.logger.Debug("msg", "TLS handshake initiated",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", remoteAddr)
|
||||
} else if s.sink.authenticator != nil {
|
||||
// Only set auth timeout if no TLS (plain connection)
|
||||
client.authTimeout = time.Now().Add(30 * time.Second) // TODO: configurable or non-hardcoded timer
|
||||
client.authTimeoutSet = true
|
||||
if s.sink.authenticator != nil {
|
||||
client.authTimeout = time.Now().Add(30 * time.Second)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
@ -578,12 +485,11 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
s.sink.logger.Debug("msg", "TCP connection opened",
|
||||
"remote_addr", remoteAddr,
|
||||
"active_connections", newCount,
|
||||
"requires_auth", s.sink.authenticator != nil)
|
||||
"auth_enabled", s.sink.authenticator != nil)
|
||||
|
||||
// Send auth prompt if authentication is required
|
||||
if s.sink.authenticator != nil && s.sink.tlsManager == nil {
|
||||
authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH <method> <credentials>\nMethods: basic, token\n")
|
||||
return authPrompt, gnet.None
|
||||
if s.sink.authenticator != nil {
|
||||
return []byte("AUTH_REQUIRED\n"), gnet.None
|
||||
}
|
||||
|
||||
return nil, gnet.None
|
||||
@ -594,17 +500,9 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
|
||||
// Remove client state
|
||||
s.mu.Lock()
|
||||
client := s.clients[c]
|
||||
delete(s.clients, c)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Clean up TLS bridge if present
|
||||
if client != nil && client.tlsBridge != nil {
|
||||
client.tlsBridge.Close()
|
||||
s.sink.logger.Debug("msg", "TLS connection closed",
|
||||
"remote_addr", remoteAddr)
|
||||
}
|
||||
|
||||
// Clean up write error tracking
|
||||
s.sink.errorMu.Lock()
|
||||
delete(s.sink.consecutiveWriteErrors, c)
|
||||
@ -632,98 +530,34 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Read all available data
|
||||
data, err := c.Next(-1)
|
||||
if err != nil {
|
||||
s.sink.logger.Error("msg", "Error reading from connection",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Process through TLS bridge if present
|
||||
if client.tlsBridge != nil {
|
||||
// Feed encrypted data into TLS engine
|
||||
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
|
||||
s.sink.logger.Error("msg", "TLS processing error",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"error", err)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Check if handshake is complete
|
||||
if !client.tlsBridge.IsHandshakeDone() {
|
||||
// Still handshaking, wait for more data
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// Check handshake result
|
||||
_, hsErr := client.tlsBridge.HandshakeComplete()
|
||||
if hsErr != nil {
|
||||
s.sink.logger.Error("msg", "TLS handshake failed",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"error", hsErr)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Set auth timeout only after TLS handshake completes
|
||||
if !client.authTimeoutSet && s.sink.authenticator != nil && !client.authenticated {
|
||||
client.authTimeout = time.Now().Add(30 * time.Second)
|
||||
client.authTimeoutSet = true
|
||||
s.sink.logger.Debug("msg", "Auth timeout started after TLS handshake",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
}
|
||||
|
||||
// Read decrypted plaintext
|
||||
data = client.tlsBridge.Read()
|
||||
if data == nil || len(data) == 0 {
|
||||
// No plaintext available yet
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// First data after TLS handshake - send auth prompt if needed
|
||||
if s.sink.authenticator != nil && !client.authenticated &&
|
||||
len(client.buffer.Bytes()) == 0 {
|
||||
authPrompt := []byte("AUTH REQUIRED\n")
|
||||
client.tlsBridge.Write(authPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// Only check auth timeout if it has been set
|
||||
if !client.authenticated && client.authTimeoutSet && time.Now().After(client.authTimeout) {
|
||||
// Authentication phase
|
||||
if !client.authenticated {
|
||||
// Check auth timeout
|
||||
if time.Now().After(client.authTimeout) {
|
||||
s.sink.logger.Warn("msg", "Authentication timeout",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() {
|
||||
client.tlsBridge.Write([]byte("AUTH TIMEOUT\n"))
|
||||
} else if client.tlsBridge == nil {
|
||||
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
||||
}
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// If not authenticated, expect auth command
|
||||
if !client.authenticated {
|
||||
// Read auth data
|
||||
data, _ := c.Next(-1)
|
||||
if len(data) == 0 {
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
client.buffer.Write(data)
|
||||
|
||||
// Look for complete auth line
|
||||
if line, err := client.buffer.ReadBytes('\n'); err == nil {
|
||||
line = bytes.TrimSpace(line)
|
||||
if idx := bytes.IndexByte(client.buffer.Bytes(), '\n'); idx >= 0 {
|
||||
line := client.buffer.Bytes()[:idx]
|
||||
client.buffer.Next(idx + 1)
|
||||
|
||||
// Parse AUTH command: AUTH <method> <credentials>
|
||||
parts := strings.SplitN(string(line), " ", 3)
|
||||
if len(parts) != 3 || parts[0] != "AUTH" {
|
||||
// Send error through TLS if enabled
|
||||
errMsg := []byte("AUTH FAILED\n")
|
||||
if client.tlsBridge != nil {
|
||||
client.tlsBridge.Write(errMsg)
|
||||
} else {
|
||||
c.AsyncWrite(errMsg, nil)
|
||||
}
|
||||
return gnet.None
|
||||
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Authenticate
|
||||
@ -734,13 +568,7 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"method", parts[1],
|
||||
"error", err)
|
||||
// Send error through TLS if enabled
|
||||
errMsg := []byte("AUTH FAILED\n")
|
||||
if client.tlsBridge != nil {
|
||||
client.tlsBridge.Write(errMsg)
|
||||
} else {
|
||||
c.AsyncWrite(errMsg, nil)
|
||||
}
|
||||
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
@ -755,35 +583,25 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"username", session.Username,
|
||||
"method", session.Method,
|
||||
"tls", client.tlsBridge != nil)
|
||||
"method", session.Method)
|
||||
|
||||
// Send success through TLS if enabled
|
||||
successMsg := []byte("AUTH OK\n")
|
||||
if client.tlsBridge != nil {
|
||||
client.tlsBridge.Write(successMsg)
|
||||
} else {
|
||||
c.AsyncWrite(successMsg, nil)
|
||||
}
|
||||
|
||||
// Clear buffer after auth
|
||||
c.AsyncWrite([]byte("AUTH_OK\n"), nil)
|
||||
client.buffer.Reset()
|
||||
}
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// Authenticated clients shouldn't send data, just discard
|
||||
// Clients shouldn't send data, just discard
|
||||
c.Discard(-1)
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// Configures tcp sink authentication
|
||||
func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
||||
// Configures tcp sink auth
|
||||
func (t *TCPSink) SetAuth(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type == "none" {
|
||||
return
|
||||
}
|
||||
|
||||
t.authConfig = authCfg
|
||||
authenticator, err := auth.New(authCfg, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "Failed to initialize authenticator for TCP sink",
|
||||
@ -793,22 +611,7 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
||||
}
|
||||
t.authenticator = authenticator
|
||||
|
||||
// Initialize TLS manager if TLS is configured
|
||||
if t.config.TLS != nil && t.config.TLS.Enabled {
|
||||
tlsManager, err := tls.NewManager(t.config.TLS, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "Failed to create TLS manager",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
// Continue without TLS
|
||||
return
|
||||
}
|
||||
t.tlsManager = tlsManager
|
||||
}
|
||||
|
||||
t.logger.Info("msg", "Authentication configured for TCP sink",
|
||||
"component", "tcp_sink",
|
||||
"auth_type", authCfg.Type,
|
||||
"tls_enabled", t.tlsManager != nil,
|
||||
"tls_bridge", t.tlsManager != nil)
|
||||
"auth_type", authCfg.Type)
|
||||
}
|
||||
@ -2,22 +2,21 @@
|
||||
package sink
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
tlspkg "logwisp/src/internal/tls"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
@ -33,10 +32,7 @@ type TCPClientSink struct {
|
||||
startTime time.Time
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
|
||||
// TLS support
|
||||
tlsManager *tlspkg.Manager
|
||||
tlsConfig *tls.Config
|
||||
authenticator *auth.Authenticator
|
||||
|
||||
// Reconnection state
|
||||
reconnecting atomic.Bool
|
||||
@ -60,6 +56,10 @@ type TCPClientConfig struct {
|
||||
ReadTimeout time.Duration
|
||||
KeepAlive time.Duration
|
||||
|
||||
// Security
|
||||
Username string
|
||||
Password string
|
||||
|
||||
// Reconnection settings
|
||||
ReconnectDelay time.Duration
|
||||
MaxReconnectDelay time.Duration
|
||||
@ -120,27 +120,11 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
|
||||
if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 {
|
||||
cfg.ReconnectBackoff = backoff
|
||||
}
|
||||
|
||||
// Extract TLS config
|
||||
if tc, ok := options["tls"].(map[string]any); ok {
|
||||
cfg.TLS = &config.TLSConfig{}
|
||||
cfg.TLS.Enabled, _ = tc["enabled"].(bool)
|
||||
if certFile, ok := tc["cert_file"].(string); ok {
|
||||
cfg.TLS.CertFile = certFile
|
||||
}
|
||||
if keyFile, ok := tc["key_file"].(string); ok {
|
||||
cfg.TLS.KeyFile = keyFile
|
||||
}
|
||||
cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool)
|
||||
if caFile, ok := tc["client_ca_file"].(string); ok {
|
||||
cfg.TLS.ClientCAFile = caFile
|
||||
}
|
||||
if insecure, ok := tc["insecure_skip_verify"].(bool); ok {
|
||||
cfg.TLS.InsecureSkipVerify = insecure
|
||||
}
|
||||
if caFile, ok := tc["ca_file"].(string); ok {
|
||||
cfg.TLS.CAFile = caFile
|
||||
if username, ok := options["username"].(string); ok {
|
||||
cfg.Username = username
|
||||
}
|
||||
if password, ok := options["password"].(string); ok {
|
||||
cfg.Password = password
|
||||
}
|
||||
|
||||
t := &TCPClientSink{
|
||||
@ -154,62 +138,6 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
|
||||
t.lastProcessed.Store(time.Time{})
|
||||
t.connectionUptime.Store(time.Duration(0))
|
||||
|
||||
// Initialize TLS manager if TLS is configured
|
||||
if cfg.TLS != nil && cfg.TLS.Enabled {
|
||||
// Build custom TLS config for client
|
||||
t.tlsConfig = &tls.Config{
|
||||
InsecureSkipVerify: cfg.TLS.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
// Extract server name from address for SNI
|
||||
host, _, err := net.SplitHostPort(cfg.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse address for SNI: %w", err)
|
||||
}
|
||||
t.tlsConfig.ServerName = host
|
||||
|
||||
// Load custom CA for server verification
|
||||
if cfg.TLS.CAFile != "" {
|
||||
caCert, err := os.ReadFile(cfg.TLS.CAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA file '%s': %w", cfg.TLS.CAFile, err)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate from '%s'", cfg.TLS.CAFile)
|
||||
}
|
||||
t.tlsConfig.RootCAs = caCertPool
|
||||
logger.Debug("msg", "Custom CA loaded for server verification",
|
||||
"component", "tcp_client_sink",
|
||||
"ca_file", cfg.TLS.CAFile)
|
||||
}
|
||||
|
||||
// Load client certificate for mTLS
|
||||
if cfg.TLS.CertFile != "" && cfg.TLS.KeyFile != "" {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertFile, cfg.TLS.KeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load client certificate: %w", err)
|
||||
}
|
||||
t.tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
logger.Info("msg", "Client certificate loaded for mTLS",
|
||||
"component", "tcp_client_sink",
|
||||
"cert_file", cfg.TLS.CertFile)
|
||||
}
|
||||
|
||||
// Set minimum TLS version if configured
|
||||
if cfg.TLS.MinVersion != "" {
|
||||
t.tlsConfig.MinVersion = parseTLSVersion(cfg.TLS.MinVersion, tls.VersionTLS12)
|
||||
} else {
|
||||
t.tlsConfig.MinVersion = tls.VersionTLS12 // Default minimum
|
||||
}
|
||||
|
||||
logger.Info("msg", "TLS enabled for TCP client",
|
||||
"component", "tcp_client_sink",
|
||||
"address", cfg.Address,
|
||||
"server_name", host,
|
||||
"insecure", cfg.TLS.InsecureSkipVerify,
|
||||
"mtls", cfg.TLS.CertFile != "")
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
@ -376,33 +304,44 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
|
||||
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
|
||||
}
|
||||
|
||||
// Wrap with TLS if configured
|
||||
if t.tlsConfig != nil {
|
||||
t.logger.Debug("msg", "Initiating TLS handshake",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address)
|
||||
|
||||
tlsConn := tls.Client(conn, t.tlsConfig)
|
||||
|
||||
// Perform handshake with timeout
|
||||
handshakeCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := tlsConn.HandshakeContext(handshakeCtx); err != nil {
|
||||
// Handle authentication if credentials configured
|
||||
if t.config.Username != "" && t.config.Password != "" {
|
||||
// Read auth challenge
|
||||
reader := bufio.NewReader(conn)
|
||||
challenge, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
return nil, fmt.Errorf("failed to read auth challenge: %w", err)
|
||||
}
|
||||
|
||||
// Log connection details
|
||||
state := tlsConn.ConnectionState()
|
||||
t.logger.Info("msg", "TLS connection established",
|
||||
if strings.TrimSpace(challenge) == "AUTH_REQUIRED" {
|
||||
// Send credentials
|
||||
creds := t.config.Username + ":" + t.config.Password
|
||||
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
|
||||
authCmd := fmt.Sprintf("AUTH basic %s\n", encodedCreds)
|
||||
|
||||
if _, err := conn.Write([]byte(authCmd)); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("failed to send auth: %w", err)
|
||||
}
|
||||
|
||||
// Read response
|
||||
response, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("failed to read auth response: %w", err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(response) != "AUTH_OK" {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("authentication failed: %s", response)
|
||||
}
|
||||
|
||||
t.logger.Debug("msg", "TCP authentication successful",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address,
|
||||
"tls_version", tlsVersionString(state.Version),
|
||||
"cipher_suite", tls.CipherSuiteName(state.CipherSuite),
|
||||
"server_name", state.ServerName)
|
||||
|
||||
return tlsConn, nil
|
||||
"username", t.config.Username)
|
||||
}
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
@ -504,34 +443,8 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns human-readable TLS version
|
||||
func tlsVersionString(version uint16) string {
|
||||
switch version {
|
||||
case tls.VersionTLS10:
|
||||
return "TLS1.0"
|
||||
case tls.VersionTLS11:
|
||||
return "TLS1.1"
|
||||
case tls.VersionTLS12:
|
||||
return "TLS1.2"
|
||||
case tls.VersionTLS13:
|
||||
return "TLS1.3"
|
||||
default:
|
||||
return fmt.Sprintf("0x%04x", version)
|
||||
}
|
||||
}
|
||||
|
||||
// Converts string to TLS version constant
|
||||
func parseTLSVersion(version string, defaultVersion uint16) uint16 {
|
||||
switch strings.ToUpper(version) {
|
||||
case "TLS1.0", "TLS10":
|
||||
return tls.VersionTLS10
|
||||
case "TLS1.1", "TLS11":
|
||||
return tls.VersionTLS11
|
||||
case "TLS1.2", "TLS12":
|
||||
return tls.VersionTLS12
|
||||
case "TLS1.3", "TLS13":
|
||||
return tls.VersionTLS13
|
||||
default:
|
||||
return defaultVersion
|
||||
}
|
||||
// Not applicable, Clients authenticate to remote servers using Username/Password in config
|
||||
func (h *TCPClientSink) SetAuth(authCfg *config.AuthConfig) {
|
||||
// No-op: client sinks don't validate incoming connections
|
||||
// They authenticate to remote servers using Username/Password fields
|
||||
}
|
||||
@ -13,6 +13,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
@ -287,3 +288,7 @@ func globToRegex(glob string) string {
|
||||
regex = strings.ReplaceAll(regex, `\?`, `.`)
|
||||
return "^" + regex + "$"
|
||||
}
|
||||
|
||||
func (ds *DirectorySource) SetAuth(auth *config.AuthConfig) {
|
||||
// Authentication does not apply to directory source
|
||||
}
|
||||
@ -4,6 +4,7 @@ package source
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"logwisp/src/internal/auth"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -20,19 +21,29 @@ import (
|
||||
|
||||
// Receives log entries via HTTP POST requests
|
||||
type HTTPSource struct {
|
||||
// Config
|
||||
host string
|
||||
port int64
|
||||
path string
|
||||
bufferSize int64
|
||||
maxRequestBodySize int64
|
||||
|
||||
// Application
|
||||
server *fasthttp.Server
|
||||
subscribers []chan core.LogEntry
|
||||
mu sync.RWMutex
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
netLimiter *limit.NetLimiter
|
||||
logger *log.Logger
|
||||
|
||||
// Add TLS support
|
||||
// Runtime
|
||||
mu sync.RWMutex
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Security
|
||||
authenticator *auth.Authenticator
|
||||
authConfig *config.AuthConfig
|
||||
authFailures atomic.Uint64
|
||||
authSuccesses atomic.Uint64
|
||||
tlsManager *tls.Manager
|
||||
tlsConfig *config.TLSConfig
|
||||
|
||||
@ -66,11 +77,17 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err
|
||||
bufferSize = bufSize
|
||||
}
|
||||
|
||||
maxRequestBodySize := int64(10 * 1024 * 1024) // fasthttp default 10MB
|
||||
if maxBodySize, ok := options["max_body_size"].(int64); ok && maxBodySize > 0 && maxBodySize < maxRequestBodySize {
|
||||
maxRequestBodySize = maxBodySize
|
||||
}
|
||||
|
||||
h := &HTTPSource{
|
||||
host: host,
|
||||
port: port,
|
||||
path: ingestPath,
|
||||
bufferSize: bufferSize,
|
||||
maxRequestBodySize: maxRequestBodySize,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
@ -78,30 +95,36 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err
|
||||
h.lastEntryTime.Store(time.Time{})
|
||||
|
||||
// Initialize net limiter if configured
|
||||
if rl, ok := options["net_limit"].(map[string]any); ok {
|
||||
if enabled, _ := rl["enabled"].(bool); enabled {
|
||||
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||
if enabled, _ := nl["enabled"].(bool); enabled {
|
||||
cfg := config.NetLimitConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
if rps, ok := rl["requests_per_second"].(float64); ok {
|
||||
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||
cfg.RequestsPerSecond = rps
|
||||
}
|
||||
if burst, ok := rl["burst_size"].(int64); ok {
|
||||
if burst, ok := nl["burst_size"].(int64); ok {
|
||||
cfg.BurstSize = burst
|
||||
}
|
||||
if limitBy, ok := rl["limit_by"].(string); ok {
|
||||
cfg.LimitBy = limitBy
|
||||
}
|
||||
if respCode, ok := rl["response_code"].(int64); ok {
|
||||
if respCode, ok := nl["response_code"].(int64); ok {
|
||||
cfg.ResponseCode = respCode
|
||||
}
|
||||
if msg, ok := rl["response_message"].(string); ok {
|
||||
if msg, ok := nl["response_message"].(string); ok {
|
||||
cfg.ResponseMessage = msg
|
||||
}
|
||||
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
|
||||
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||
cfg.MaxConnectionsPerIP = maxPerIP
|
||||
}
|
||||
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
|
||||
cfg.MaxConnectionsPerUser = maxPerUser
|
||||
}
|
||||
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
|
||||
cfg.MaxConnectionsPerToken = maxPerToken
|
||||
}
|
||||
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||
cfg.MaxConnectionsTotal = maxTotal
|
||||
}
|
||||
|
||||
h.netLimiter = limit.NewNetLimiter(cfg, logger)
|
||||
}
|
||||
@ -161,6 +184,7 @@ func (h *HTTPSource) Start() error {
|
||||
DisableKeepalive: false,
|
||||
StreamRequestBody: true,
|
||||
CloseOnShutdown: true,
|
||||
MaxRequestBodySize: int(h.maxRequestBodySize),
|
||||
}
|
||||
|
||||
// Use configured host and port
|
||||
@ -259,19 +283,9 @@ func (h *HTTPSource) GetStats() SourceStats {
|
||||
}
|
||||
|
||||
func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
// Only handle POST to the configured ingest path
|
||||
if string(ctx.Method()) != "POST" || string(ctx.Path()) != h.path {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Not Found",
|
||||
"hint": fmt.Sprintf("POST logs to %s", h.path),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract and validate IP
|
||||
remoteAddr := ctx.RemoteAddr().String()
|
||||
|
||||
// 1. IPv6 check (early reject)
|
||||
ipStr, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err == nil {
|
||||
if ip := net.ParseIP(ipStr); ip != nil && ip.To4() == nil {
|
||||
@ -284,7 +298,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check net limit
|
||||
// 2. Net limit check (early reject)
|
||||
if h.netLimiter != nil {
|
||||
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
|
||||
ctx.SetStatusCode(int(statusCode))
|
||||
@ -297,7 +311,65 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
// Process the request body
|
||||
// 3. Path check (only process ingest path)
|
||||
path := string(ctx.Path())
|
||||
if path != h.path {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Not Found",
|
||||
"hint": fmt.Sprintf("POST logs to %s", h.path),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Method check (only accept POST)
|
||||
if string(ctx.Method()) != "POST" {
|
||||
ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
|
||||
ctx.SetContentType("application/json")
|
||||
ctx.Response.Header.Set("Allow", "POST")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Method not allowed",
|
||||
"hint": "Use POST to submit logs",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 5. Authentication check (if configured)
|
||||
if h.authenticator != nil {
|
||||
authHeader := string(ctx.Request.Header.Peek("Authorization"))
|
||||
session, err := h.authenticator.AuthenticateHTTP(authHeader, remoteAddr)
|
||||
if err != nil {
|
||||
h.authFailures.Add(1)
|
||||
h.logger.Warn("msg", "Authentication failed",
|
||||
"component", "http_source",
|
||||
"remote_addr", remoteAddr,
|
||||
"error", err)
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
|
||||
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
|
||||
realm := h.authConfig.BasicAuth.Realm
|
||||
if realm == "" {
|
||||
realm = "Restricted"
|
||||
}
|
||||
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm))
|
||||
} else if h.authConfig.Type == "bearer" {
|
||||
ctx.Response.Header.Set("WWW-Authenticate", "Bearer")
|
||||
}
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
h.authSuccesses.Add(1)
|
||||
h.logger.Debug("msg", "Request authenticated",
|
||||
"component", "http_source",
|
||||
"remote_addr", remoteAddr,
|
||||
"username", session.Username)
|
||||
}
|
||||
|
||||
// 6. Process request body
|
||||
body := ctx.PostBody()
|
||||
if len(body) == 0 {
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
@ -308,7 +380,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the log entries
|
||||
// 7. Parse log entries
|
||||
entries, err := h.parseEntries(body)
|
||||
if err != nil {
|
||||
h.invalidEntries.Add(1)
|
||||
@ -320,7 +392,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
return
|
||||
}
|
||||
|
||||
// Publish entries
|
||||
// 8. Publish entries to subscribers
|
||||
accepted := 0
|
||||
for _, entry := range entries {
|
||||
if h.publish(entry) {
|
||||
@ -328,7 +400,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
// Return success response
|
||||
// 9. Return success response
|
||||
ctx.SetStatusCode(fasthttp.StatusAccepted)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]any{
|
||||
@ -461,3 +533,24 @@ func splitLines(data []byte) [][]byte {
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
// Configure HTTP source auth
|
||||
func (h *HTTPSource) SetAuth(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type == "none" {
|
||||
return
|
||||
}
|
||||
|
||||
h.authConfig = authCfg
|
||||
authenticator, err := auth.New(authCfg, h.logger)
|
||||
if err != nil {
|
||||
h.logger.Error("msg", "Failed to initialize authenticator for HTTP source",
|
||||
"component", "http_source",
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
h.authenticator = authenticator
|
||||
|
||||
h.logger.Info("msg", "Authentication configured for HTTP source",
|
||||
"component", "http_source",
|
||||
"auth_type", authCfg.Type)
|
||||
}
|
||||
@ -4,22 +4,26 @@ package source
|
||||
import (
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
)
|
||||
|
||||
// Represents an input data stream
|
||||
type Source interface {
|
||||
// Subscribe returns a channel that receives log entries
|
||||
// Returns a channel that receives log entries
|
||||
Subscribe() <-chan core.LogEntry
|
||||
|
||||
// Start begins reading from the source
|
||||
// Begins reading from the source
|
||||
Start() error
|
||||
|
||||
// Stop gracefully shuts down the source
|
||||
// Gracefully shuts down the source
|
||||
Stop()
|
||||
|
||||
// GetStats returns source statistics
|
||||
// Returns source statistics
|
||||
GetStats() SourceStats
|
||||
|
||||
// Configure authentication
|
||||
SetAuth(auth *config.AuthConfig)
|
||||
}
|
||||
|
||||
// Contains statistics about a source
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
@ -119,3 +120,7 @@ func (s *StdinSource) publish(entry core.LogEntry) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdinSource) SetAuth(auth *config.AuthConfig) {
|
||||
// Authentication does not apply to stdin source
|
||||
}
|
||||
@ -5,9 +5,9 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -16,7 +16,6 @@ import (
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/tls"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/lixenwraith/log/compat"
|
||||
@ -26,8 +25,6 @@ import (
|
||||
const (
|
||||
maxClientBufferSize = 10 * 1024 * 1024 // 10MB max per client
|
||||
maxLineLength = 1 * 1024 * 1024 // 1MB max per log line
|
||||
maxEncryptedDataPerRead = 1 * 1024 * 1024 // 1MB max encrypted data per read
|
||||
maxCumulativeEncrypted = 20 * 1024 * 1024 // 20MB total encrypted before processing
|
||||
)
|
||||
|
||||
// Receives log entries via TCP connections
|
||||
@ -43,9 +40,8 @@ type TCPSource struct {
|
||||
engineMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
netLimiter *limit.NetLimiter
|
||||
tlsManager *tls.Manager
|
||||
tlsConfig *config.TLSConfig
|
||||
logger *log.Logger
|
||||
authenticator *auth.Authenticator
|
||||
|
||||
// Statistics
|
||||
totalEntries atomic.Uint64
|
||||
@ -54,6 +50,8 @@ type TCPSource struct {
|
||||
activeConns atomic.Int64
|
||||
startTime time.Time
|
||||
lastEntryTime atomic.Value // time.Time
|
||||
authFailures atomic.Uint64
|
||||
authSuccesses atomic.Uint64
|
||||
}
|
||||
|
||||
// Creates a new TCP server source
|
||||
@ -84,58 +82,35 @@ func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error
|
||||
t.lastEntryTime.Store(time.Time{})
|
||||
|
||||
// Initialize net limiter if configured
|
||||
if rl, ok := options["net_limit"].(map[string]any); ok {
|
||||
if enabled, _ := rl["enabled"].(bool); enabled {
|
||||
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||
if enabled, _ := nl["enabled"].(bool); enabled {
|
||||
cfg := config.NetLimitConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
if rps, ok := rl["requests_per_second"].(float64); ok {
|
||||
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||
cfg.RequestsPerSecond = rps
|
||||
}
|
||||
if burst, ok := rl["burst_size"].(int64); ok {
|
||||
if burst, ok := nl["burst_size"].(int64); ok {
|
||||
cfg.BurstSize = burst
|
||||
}
|
||||
if limitBy, ok := rl["limit_by"].(string); ok {
|
||||
cfg.LimitBy = limitBy
|
||||
}
|
||||
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
|
||||
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||
cfg.MaxConnectionsPerIP = maxPerIP
|
||||
}
|
||||
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
||||
cfg.MaxTotalConnections = maxTotal
|
||||
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
|
||||
cfg.MaxConnectionsPerUser = maxPerUser
|
||||
}
|
||||
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
|
||||
cfg.MaxConnectionsPerToken = maxPerToken
|
||||
}
|
||||
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||
cfg.MaxConnectionsTotal = maxTotal
|
||||
}
|
||||
|
||||
t.netLimiter = limit.NewNetLimiter(cfg, logger)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract TLS config and initialize TLS manager
|
||||
if tc, ok := options["tls"].(map[string]any); ok {
|
||||
t.tlsConfig = &config.TLSConfig{}
|
||||
t.tlsConfig.Enabled, _ = tc["enabled"].(bool)
|
||||
if certFile, ok := tc["cert_file"].(string); ok {
|
||||
t.tlsConfig.CertFile = certFile
|
||||
}
|
||||
if keyFile, ok := tc["key_file"].(string); ok {
|
||||
t.tlsConfig.KeyFile = keyFile
|
||||
}
|
||||
t.tlsConfig.ClientAuth, _ = tc["client_auth"].(bool)
|
||||
if caFile, ok := tc["client_ca_file"].(string); ok {
|
||||
t.tlsConfig.ClientCAFile = caFile
|
||||
}
|
||||
t.tlsConfig.VerifyClientCert, _ = tc["verify_client_cert"].(bool)
|
||||
|
||||
// Create TLS manager if enabled
|
||||
if t.tlsConfig.Enabled {
|
||||
tlsManager, err := tls.NewManager(t.tlsConfig, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
|
||||
}
|
||||
t.tlsManager = tlsManager
|
||||
}
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
@ -167,8 +142,7 @@ func (t *TCPSource) Start() error {
|
||||
defer t.wg.Done()
|
||||
t.logger.Info("msg", "TCP source server starting",
|
||||
"component", "tcp_source",
|
||||
"port", t.port,
|
||||
"tls_enabled", t.tlsManager != nil)
|
||||
"port", t.port)
|
||||
|
||||
err := gnet.Run(t.server, addr,
|
||||
gnet.WithLogger(gnetLogger),
|
||||
@ -283,9 +257,8 @@ type tcpClient struct {
|
||||
conn gnet.Conn
|
||||
buffer bytes.Buffer
|
||||
authenticated bool
|
||||
session *auth.Session
|
||||
authTimeout time.Time
|
||||
tlsBridge *tls.GNetTLSConn
|
||||
session *auth.Session
|
||||
maxBufferSeen int
|
||||
cumulativeEncrypted int64
|
||||
}
|
||||
@ -339,22 +312,17 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
}
|
||||
|
||||
// Create client state
|
||||
client := &tcpClient{conn: c}
|
||||
|
||||
// Initialize TLS bridge if enabled
|
||||
if s.source.tlsManager != nil {
|
||||
tlsConfig := s.source.tlsManager.GetTCPConfig()
|
||||
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
|
||||
client.tlsBridge.Handshake() // Start async handshake
|
||||
|
||||
s.source.logger.Debug("msg", "TLS handshake initiated",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", remoteAddr)
|
||||
client := &tcpClient{
|
||||
conn: c,
|
||||
authenticated: s.source.authenticator == nil,
|
||||
}
|
||||
|
||||
if s.source.authenticator != nil {
|
||||
client.authTimeout = time.Now().Add(30 * time.Second)
|
||||
}
|
||||
|
||||
// Create client state
|
||||
s.mu.Lock()
|
||||
s.clients[c] = &tcpClient{conn: c}
|
||||
s.clients[c] = client
|
||||
s.mu.Unlock()
|
||||
|
||||
newCount := s.source.activeConns.Add(1)
|
||||
@ -362,7 +330,12 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
"component", "tcp_source",
|
||||
"remote_addr", remoteAddr,
|
||||
"active_connections", newCount,
|
||||
"tls_enabled", s.source.tlsManager != nil)
|
||||
"requires_auth", s.source.authenticator != nil)
|
||||
|
||||
// Send auth challenge if required
|
||||
if s.source.authenticator != nil {
|
||||
return []byte("AUTH_REQUIRED\n"), gnet.None
|
||||
}
|
||||
|
||||
return nil, gnet.None
|
||||
}
|
||||
@ -372,18 +345,9 @@ func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
|
||||
// Remove client state
|
||||
s.mu.Lock()
|
||||
client := s.clients[c]
|
||||
delete(s.clients, c)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Clean up TLS bridge if present
|
||||
if client != nil && client.tlsBridge != nil {
|
||||
client.tlsBridge.Close()
|
||||
s.source.logger.Debug("msg", "TLS connection closed",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", remoteAddr)
|
||||
}
|
||||
|
||||
// Remove connection tracking
|
||||
if s.source.netLimiter != nil {
|
||||
s.source.netLimiter.RemoveConnection(remoteAddr)
|
||||
@ -416,79 +380,64 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Check encrypted data size BEFORE processing through TLS
|
||||
if len(data) > maxEncryptedDataPerRead {
|
||||
s.source.logger.Warn("msg", "Encrypted data per read limit exceeded",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"data_size", len(data),
|
||||
"limit", maxEncryptedDataPerRead)
|
||||
s.source.invalidEntries.Add(1)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Track cumulative encrypted data to prevent slow accumulation
|
||||
client.cumulativeEncrypted += int64(len(data))
|
||||
if client.cumulativeEncrypted > maxCumulativeEncrypted {
|
||||
s.source.logger.Warn("msg", "Cumulative encrypted data limit exceeded",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"total_encrypted", client.cumulativeEncrypted,
|
||||
"limit", maxCumulativeEncrypted)
|
||||
s.source.invalidEntries.Add(1)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Process through TLS bridge if present
|
||||
if client.tlsBridge != nil {
|
||||
// Feed encrypted data into TLS engine
|
||||
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
|
||||
if errors.Is(err, tls.ErrTLSBackpressure) {
|
||||
s.source.logger.Warn("msg", "TLS backpressure, closing slow client",
|
||||
// Authentication phase
|
||||
if !client.authenticated {
|
||||
if time.Now().After(client.authTimeout) {
|
||||
s.source.logger.Warn("msg", "Authentication timeout",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
} else {
|
||||
s.source.logger.Error("msg", "TLS processing error",
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
client.buffer.Write(data)
|
||||
|
||||
// Look for auth line
|
||||
if idx := bytes.IndexByte(client.buffer.Bytes(), '\n'); idx >= 0 {
|
||||
line := client.buffer.Bytes()[:idx]
|
||||
client.buffer.Next(idx + 1)
|
||||
|
||||
parts := strings.SplitN(string(line), " ", 3)
|
||||
if len(parts) != 3 || parts[0] != "AUTH" {
|
||||
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
session, err := s.source.authenticator.AuthenticateTCP(parts[1], parts[2], c.RemoteAddr().String())
|
||||
if err != nil {
|
||||
s.source.authFailures.Add(1)
|
||||
s.source.logger.Warn("msg", "Authentication failed",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"error", err)
|
||||
}
|
||||
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Check if handshake is complete
|
||||
if !client.tlsBridge.IsHandshakeDone() {
|
||||
// Still handshaking, wait for more data
|
||||
return gnet.None
|
||||
}
|
||||
s.source.authSuccesses.Add(1)
|
||||
s.mu.Lock()
|
||||
client.authenticated = true
|
||||
client.session = session
|
||||
s.mu.Unlock()
|
||||
|
||||
// Check handshake result
|
||||
_, hsErr := client.tlsBridge.HandshakeComplete()
|
||||
if hsErr != nil {
|
||||
s.source.logger.Error("msg", "TLS handshake failed",
|
||||
s.source.logger.Info("msg", "TCP client authenticated",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"error", hsErr)
|
||||
return gnet.Close
|
||||
}
|
||||
"username", session.Username)
|
||||
|
||||
// Read decrypted plaintext
|
||||
data = client.tlsBridge.Read()
|
||||
if data == nil || len(data) == 0 {
|
||||
// No plaintext available yet
|
||||
c.AsyncWrite([]byte("AUTH_OK\n"), nil)
|
||||
client.buffer.Reset()
|
||||
}
|
||||
return gnet.None
|
||||
}
|
||||
// Reset cumulative counter after successful decryption and processing
|
||||
client.cumulativeEncrypted = 0
|
||||
}
|
||||
|
||||
// Check buffer size before appending
|
||||
// Check if appending the new data would exceed the client buffer limit.
|
||||
if client.buffer.Len()+len(data) > maxClientBufferSize {
|
||||
s.source.logger.Warn("msg", "Client buffer limit exceeded",
|
||||
s.source.logger.Warn("msg", "Client buffer limit exceeded, closing connection.",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"buffer_size", client.buffer.Len(),
|
||||
"incoming_size", len(data))
|
||||
"incoming_size", len(data),
|
||||
"limit", maxClientBufferSize)
|
||||
s.source.invalidEntries.Add(1)
|
||||
return gnet.Close
|
||||
}
|
||||
@ -573,12 +522,22 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// noopLogger implements gnet's Logger interface but discards everything
|
||||
// type noopLogger struct{}
|
||||
// func (n noopLogger) Debugf(format string, args ...any) {}
|
||||
// func (n noopLogger) Infof(format string, args ...any) {}
|
||||
// func (n noopLogger) Warnf(format string, args ...any) {}
|
||||
// func (n noopLogger) Errorf(format string, args ...any) {}
|
||||
// func (n noopLogger) Fatalf(format string, args ...any) {}
|
||||
// Configure TCP source auth
|
||||
func (t *TCPSource) SetAuth(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type == "none" {
|
||||
return
|
||||
}
|
||||
|
||||
// Usage: gnet.Run(..., gnet.WithLogger(noopLogger{}), ...)
|
||||
authenticator, err := auth.New(authCfg, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "Failed to initialize authenticator for TCP source",
|
||||
"component", "tcp_source",
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
t.authenticator = authenticator
|
||||
|
||||
t.logger.Info("msg", "Authentication configured for TCP source",
|
||||
"component", "tcp_source",
|
||||
"auth_type", authCfg.Type)
|
||||
}
|
||||
@ -1,341 +0,0 @@
|
||||
// FILE: src/internal/tls/gnet_bridge.go
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/panjf2000/gnet/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTLSBackpressure = errors.New("TLS processing backpressure")
|
||||
ErrConnectionClosed = errors.New("connection closed")
|
||||
ErrPlaintextBufferExceeded = errors.New("plaintext buffer size exceeded")
|
||||
)
|
||||
|
||||
// Maximum plaintext buffer size to prevent memory exhaustion
|
||||
const maxPlaintextBufferSize = 32 * 1024 * 1024 // 32MB
|
||||
|
||||
// Bridges gnet.Conn with crypto/tls via io.Pipe
|
||||
type GNetTLSConn struct {
|
||||
gnetConn gnet.Conn
|
||||
tlsConn *tls.Conn
|
||||
config *tls.Config
|
||||
|
||||
// Buffered channels for non-blocking operation
|
||||
incomingCipher chan []byte // Network → TLS (encrypted)
|
||||
outgoingCipher chan []byte // TLS → Network (encrypted)
|
||||
|
||||
// Handshake state
|
||||
handshakeOnce sync.Once
|
||||
handshakeDone chan struct{}
|
||||
handshakeErr error
|
||||
|
||||
// Decrypted data buffer
|
||||
plainBuf []byte
|
||||
plainMu sync.Mutex
|
||||
|
||||
// Lifecycle
|
||||
closed atomic.Bool
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Error tracking
|
||||
lastErr atomic.Value // error
|
||||
logger interface{ Warn(args ...any) } // Minimal logger interface
|
||||
}
|
||||
|
||||
// Creates a server-side TLS bridge
|
||||
func NewServerConn(gnetConn gnet.Conn, config *tls.Config) *GNetTLSConn {
|
||||
tc := &GNetTLSConn{
|
||||
gnetConn: gnetConn,
|
||||
config: config,
|
||||
handshakeDone: make(chan struct{}),
|
||||
// Buffered channels sized for throughput without blocking
|
||||
incomingCipher: make(chan []byte, 128), // 128 packets buffer
|
||||
outgoingCipher: make(chan []byte, 128),
|
||||
plainBuf: make([]byte, 0, 65536), // 64KB initial capacity
|
||||
}
|
||||
|
||||
// Create TLS conn with channel-based transport
|
||||
rawConn := &channelConn{
|
||||
incoming: tc.incomingCipher,
|
||||
outgoing: tc.outgoingCipher,
|
||||
localAddr: gnetConn.LocalAddr(),
|
||||
remoteAddr: gnetConn.RemoteAddr(),
|
||||
tc: tc,
|
||||
}
|
||||
tc.tlsConn = tls.Server(rawConn, config)
|
||||
|
||||
// Start pump goroutines
|
||||
tc.wg.Add(2)
|
||||
go tc.pumpCipherToNetwork()
|
||||
go tc.pumpPlaintextFromTLS()
|
||||
|
||||
return tc
|
||||
}
|
||||
|
||||
// Creates a client-side TLS bridge (similar changes)
|
||||
func NewClientConn(gnetConn gnet.Conn, config *tls.Config, serverName string) *GNetTLSConn {
|
||||
tc := &GNetTLSConn{
|
||||
gnetConn: gnetConn,
|
||||
config: config,
|
||||
handshakeDone: make(chan struct{}),
|
||||
incomingCipher: make(chan []byte, 128),
|
||||
outgoingCipher: make(chan []byte, 128),
|
||||
plainBuf: make([]byte, 0, 65536),
|
||||
}
|
||||
|
||||
if config.ServerName == "" {
|
||||
config = config.Clone()
|
||||
config.ServerName = serverName
|
||||
}
|
||||
|
||||
rawConn := &channelConn{
|
||||
incoming: tc.incomingCipher,
|
||||
outgoing: tc.outgoingCipher,
|
||||
localAddr: gnetConn.LocalAddr(),
|
||||
remoteAddr: gnetConn.RemoteAddr(),
|
||||
tc: tc,
|
||||
}
|
||||
tc.tlsConn = tls.Client(rawConn, config)
|
||||
|
||||
tc.wg.Add(2)
|
||||
go tc.pumpCipherToNetwork()
|
||||
go tc.pumpPlaintextFromTLS()
|
||||
|
||||
return tc
|
||||
}
|
||||
|
||||
// Feeds encrypted data from network into TLS engine (non-blocking)
|
||||
func (tc *GNetTLSConn) ProcessIncoming(encryptedData []byte) error {
|
||||
if tc.closed.Load() {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
// Non-blocking send with backpressure detection
|
||||
select {
|
||||
case tc.incomingCipher <- encryptedData:
|
||||
return nil
|
||||
default:
|
||||
// Channel full - TLS processing can't keep up
|
||||
// Drop connection under backpressure vs blocking event loop
|
||||
if tc.logger != nil {
|
||||
tc.logger.Warn("msg", "TLS backpressure, dropping data",
|
||||
"remote_addr", tc.gnetConn.RemoteAddr())
|
||||
}
|
||||
return ErrTLSBackpressure
|
||||
}
|
||||
}
|
||||
|
||||
// Sends TLS-encrypted data to network
|
||||
func (tc *GNetTLSConn) pumpCipherToNetwork() {
|
||||
defer tc.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case data, ok := <-tc.outgoingCipher:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Send to network
|
||||
if err := tc.gnetConn.AsyncWrite(data, nil); err != nil {
|
||||
tc.lastErr.Store(err)
|
||||
tc.Close()
|
||||
return
|
||||
}
|
||||
case <-time.After(30 * time.Second):
|
||||
// Keepalive/timeout check
|
||||
if tc.closed.Load() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reads decrypted data from TLS
|
||||
func (tc *GNetTLSConn) pumpPlaintextFromTLS() {
|
||||
defer tc.wg.Done()
|
||||
buf := make([]byte, 32768) // 32KB read buffer
|
||||
|
||||
for {
|
||||
n, err := tc.tlsConn.Read(buf)
|
||||
if n > 0 {
|
||||
tc.plainMu.Lock()
|
||||
// Check buffer size limit before appending to prevent memory exhaustion
|
||||
if len(tc.plainBuf)+n > maxPlaintextBufferSize {
|
||||
tc.plainMu.Unlock()
|
||||
// Log warning about buffer limit
|
||||
if tc.logger != nil {
|
||||
tc.logger.Warn("msg", "Plaintext buffer limit exceeded, closing connection",
|
||||
"remote_addr", tc.gnetConn.RemoteAddr(),
|
||||
"buffer_size", len(tc.plainBuf),
|
||||
"incoming_size", n,
|
||||
"limit", maxPlaintextBufferSize)
|
||||
}
|
||||
// Store error and close connection
|
||||
tc.lastErr.Store(ErrPlaintextBufferExceeded)
|
||||
tc.Close()
|
||||
return
|
||||
}
|
||||
tc.plainBuf = append(tc.plainBuf, buf[:n]...)
|
||||
tc.plainMu.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
tc.lastErr.Store(err)
|
||||
}
|
||||
tc.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns available decrypted plaintext (non-blocking)
|
||||
func (tc *GNetTLSConn) Read() []byte {
|
||||
tc.plainMu.Lock()
|
||||
defer tc.plainMu.Unlock()
|
||||
|
||||
if len(tc.plainBuf) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Atomic buffer swap under mutex protection to prevent race condition
|
||||
data := tc.plainBuf
|
||||
tc.plainBuf = make([]byte, 0, cap(tc.plainBuf))
|
||||
return data
|
||||
}
|
||||
|
||||
// Encrypts plaintext and queues for network transmission
|
||||
func (tc *GNetTLSConn) Write(plaintext []byte) (int, error) {
|
||||
if tc.closed.Load() {
|
||||
return 0, ErrConnectionClosed
|
||||
}
|
||||
|
||||
if !tc.IsHandshakeDone() {
|
||||
return 0, errors.New("handshake not complete")
|
||||
}
|
||||
|
||||
return tc.tlsConn.Write(plaintext)
|
||||
}
|
||||
|
||||
// Initiates TLS handshake asynchronously
|
||||
func (tc *GNetTLSConn) Handshake() {
|
||||
tc.handshakeOnce.Do(func() {
|
||||
go func() {
|
||||
tc.handshakeErr = tc.tlsConn.Handshake()
|
||||
close(tc.handshakeDone)
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// Checks if handshake is complete
|
||||
func (tc *GNetTLSConn) IsHandshakeDone() bool {
|
||||
select {
|
||||
case <-tc.handshakeDone:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Waits for handshake completion
|
||||
func (tc *GNetTLSConn) HandshakeComplete() (<-chan struct{}, error) {
|
||||
<-tc.handshakeDone
|
||||
return tc.handshakeDone, tc.handshakeErr
|
||||
}
|
||||
|
||||
// Shuts down the bridge
|
||||
func (tc *GNetTLSConn) Close() error {
|
||||
tc.closeOnce.Do(func() {
|
||||
tc.closed.Store(true)
|
||||
|
||||
// Close TLS connection
|
||||
tc.tlsConn.Close()
|
||||
|
||||
// Close channels to stop pumps
|
||||
close(tc.incomingCipher)
|
||||
close(tc.outgoingCipher)
|
||||
})
|
||||
|
||||
// Wait for pumps to finish
|
||||
tc.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns TLS connection state
|
||||
func (tc *GNetTLSConn) GetConnectionState() tls.ConnectionState {
|
||||
return tc.tlsConn.ConnectionState()
|
||||
}
|
||||
|
||||
// Returns last error
|
||||
func (tc *GNetTLSConn) GetError() error {
|
||||
if err, ok := tc.lastErr.Load().(error); ok {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Implements net.Conn over channels
|
||||
type channelConn struct {
|
||||
incoming <-chan []byte
|
||||
outgoing chan<- []byte
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
tc *GNetTLSConn
|
||||
readBuf []byte
|
||||
}
|
||||
|
||||
func (c *channelConn) Read(b []byte) (int, error) {
|
||||
// Use buffered read for efficiency
|
||||
if len(c.readBuf) > 0 {
|
||||
n := copy(b, c.readBuf)
|
||||
c.readBuf = c.readBuf[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Wait for new data
|
||||
select {
|
||||
case data, ok := <-c.incoming:
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(b, data)
|
||||
if n < len(data) {
|
||||
c.readBuf = data[n:] // Buffer remainder
|
||||
}
|
||||
return n, nil
|
||||
case <-time.After(30 * time.Second):
|
||||
return 0, errors.New("read timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *channelConn) Write(b []byte) (int, error) {
|
||||
if c.tc.closed.Load() {
|
||||
return 0, ErrConnectionClosed
|
||||
}
|
||||
|
||||
// Make a copy since TLS may hold reference
|
||||
data := make([]byte, len(b))
|
||||
copy(data, b)
|
||||
|
||||
select {
|
||||
case c.outgoing <- data:
|
||||
return len(b), nil
|
||||
case <-time.After(5 * time.Second):
|
||||
return 0, errors.New("write timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *channelConn) Close() error { return nil }
|
||||
func (c *channelConn) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *channelConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *channelConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (c *channelConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (c *channelConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
@ -117,18 +117,6 @@ func (m *Manager) GetHTTPConfig() *tls.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Returns TLS config for raw TCP connections
|
||||
func (m *Manager) GetTCPConfig() *tls.Config {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cfg := m.tlsConfig.Clone()
|
||||
// No ALPN for raw TCP
|
||||
cfg.NextProtos = nil
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Validates a client certificate for mTLS
|
||||
func (m *Manager) ValidateClientCert(rawCerts [][]byte) error {
|
||||
if m == nil || !m.config.ClientAuth {
|
||||
@ -174,6 +162,21 @@ func (m *Manager) ValidateClientCert(rawCerts [][]byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns TLS statistics
|
||||
func (m *Manager) GetStats() map[string]any {
|
||||
if m == nil {
|
||||
return map[string]any{"enabled": false}
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"enabled": true,
|
||||
"min_version": tlsVersionString(m.tlsConfig.MinVersion),
|
||||
"max_version": tlsVersionString(m.tlsConfig.MaxVersion),
|
||||
"client_auth": m.config.ClientAuth,
|
||||
"cipher_suites": len(m.tlsConfig.CipherSuites),
|
||||
}
|
||||
}
|
||||
|
||||
func parseTLSVersion(version string, defaultVersion uint16) uint16 {
|
||||
switch strings.ToUpper(version) {
|
||||
case "TLS1.0", "TLS10":
|
||||
@ -217,21 +220,6 @@ func parseCipherSuites(suites string) []uint16 {
|
||||
return result
|
||||
}
|
||||
|
||||
// Returns TLS statistics
|
||||
func (m *Manager) GetStats() map[string]any {
|
||||
if m == nil {
|
||||
return map[string]any{"enabled": false}
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"enabled": true,
|
||||
"min_version": tlsVersionString(m.tlsConfig.MinVersion),
|
||||
"max_version": tlsVersionString(m.tlsConfig.MaxVersion),
|
||||
"client_auth": m.config.ClientAuth,
|
||||
"cipher_suites": len(m.tlsConfig.CipherSuites),
|
||||
}
|
||||
}
|
||||
|
||||
func tlsVersionString(version uint16) string {
|
||||
switch version {
|
||||
case tls.VersionTLS10:
|
||||
|
||||
Reference in New Issue
Block a user