v0.5.0 removed tcp tls, basic auth hash changed to argon2, refactor
This commit is contained in:
@ -37,7 +37,6 @@ type = "directory"
|
|||||||
path = "./" # Directory to monitor
|
path = "./" # Directory to monitor
|
||||||
pattern = "*.log" # Glob pattern
|
pattern = "*.log" # Glob pattern
|
||||||
check_interval_ms = 100 # Scan interval
|
check_interval_ms = 100 # Scan interval
|
||||||
read_from_beginning = false # Start position
|
|
||||||
|
|
||||||
### Console Sources
|
### Console Sources
|
||||||
# [[pipelines.sources]]
|
# [[pipelines.sources]]
|
||||||
@ -74,11 +73,10 @@ read_from_beginning = false # Start position
|
|||||||
# ip_blacklist = [] # Blocked IPs/CIDRs
|
# ip_blacklist = [] # Blocked IPs/CIDRs
|
||||||
# requests_per_second = 100.0 # Rate limit per client
|
# requests_per_second = 100.0 # Rate limit per client
|
||||||
# burst_size = 100 # Burst capacity
|
# burst_size = 100 # Burst capacity
|
||||||
# limit_by = "ip" # ip|user|token|global
|
|
||||||
# response_code = 429 # HTTP status when limited
|
# response_code = 429 # HTTP status when limited
|
||||||
# response_message = "Rate limit exceeded"
|
# response_message = "Rate limit exceeded"
|
||||||
# max_connections_per_ip = 10 # Max concurrent per IP
|
# max_connections_per_ip = 10 # Max concurrent per IP
|
||||||
# max_total_connections = 1000 # Max total connections
|
# max_connections_total = 1000 # Max total connections
|
||||||
|
|
||||||
### TCP Sources
|
### TCP Sources
|
||||||
# [[pipelines.sources]]
|
# [[pipelines.sources]]
|
||||||
@ -88,30 +86,18 @@ read_from_beginning = false # Start position
|
|||||||
# host = "0.0.0.0" # Listen address
|
# host = "0.0.0.0" # Listen address
|
||||||
# port = 9091 # Listen port
|
# port = 9091 # Listen port
|
||||||
|
|
||||||
# [pipelines.sources.options.tls]
|
|
||||||
# enabled = false # Enable TLS
|
|
||||||
# cert_file = "" # TLS certificate
|
|
||||||
# key_file = "" # TLS key
|
|
||||||
# client_auth = false # Require client certs
|
|
||||||
# client_ca_file = "" # Client CA cert
|
|
||||||
# verify_client_cert = false # Verify client certs
|
|
||||||
# insecure_skip_verify = false # Skip verification
|
|
||||||
# ca_file = "" # Custom CA file
|
|
||||||
# min_version = "TLS1.2" # Min TLS version
|
|
||||||
# max_version = "TLS1.3" # Max TLS version
|
|
||||||
# cipher_suites = "" # Comma-separated list
|
|
||||||
|
|
||||||
# [pipelines.sources.options.net_limit]
|
# [pipelines.sources.options.net_limit]
|
||||||
# enabled = false # Enable rate limiting
|
# enabled = false # Enable rate limiting
|
||||||
# ip_whitelist = [] # Allowed IPs/CIDRs
|
# ip_whitelist = [] # Allowed IPs/CIDRs
|
||||||
# ip_blacklist = [] # Blocked IPs/CIDRs
|
# ip_blacklist = [] # Blocked IPs/CIDRs
|
||||||
# requests_per_second = 100.0 # Rate limit per client
|
# requests_per_second = 100.0 # Rate limit per client
|
||||||
# burst_size = 100 # Burst capacity
|
# burst_size = 100 # Burst capacity
|
||||||
# limit_by = "ip" # ip|user|token|global
|
|
||||||
# response_code = 429 # Response code when limited
|
# response_code = 429 # Response code when limited
|
||||||
# response_message = "Rate limit exceeded"
|
# response_message = "Rate limit exceeded"
|
||||||
# max_connections_per_ip = 10 # Max concurrent per IP
|
# max_connections_per_ip = 10 # Max concurrent per IP
|
||||||
# max_total_connections = 1000 # Max total connections
|
# max_connections_per_user = 10 # Max concurrent per user
|
||||||
|
# max_connections_per_token = 10 # Max concurrent per token
|
||||||
|
# max_connections_total = 1000 # Max total connections
|
||||||
|
|
||||||
### Rate limiting
|
### Rate limiting
|
||||||
# [pipelines.rate_limit]
|
# [pipelines.rate_limit]
|
||||||
@ -126,6 +112,21 @@ read_from_beginning = false # Start position
|
|||||||
# logic = "or" # or|and
|
# logic = "or" # or|and
|
||||||
# patterns = [] # Regex patterns
|
# patterns = [] # Regex patterns
|
||||||
|
|
||||||
|
## Examples of filter patterns:
|
||||||
|
## Include only error or fatal logs containing "database":
|
||||||
|
## type = "include"
|
||||||
|
## logic = "and"
|
||||||
|
## patterns = ["(?i)(error|fatal)", "database"]
|
||||||
|
##
|
||||||
|
## Exclude debug logs from test environment:
|
||||||
|
## type = "exclude"
|
||||||
|
## logic = "or"
|
||||||
|
## patterns = ["(?i)debug", "test-env"]
|
||||||
|
##
|
||||||
|
## Include only JSON formatted logs:
|
||||||
|
## type = "include"
|
||||||
|
## patterns = ["^\\{.*\\}$"]
|
||||||
|
|
||||||
### Format
|
### Format
|
||||||
|
|
||||||
### Raw formatter (default)
|
### Raw formatter (default)
|
||||||
@ -174,7 +175,6 @@ format = "comment" # comment|message
|
|||||||
# verify_client_cert = false # Verify client certs
|
# verify_client_cert = false # Verify client certs
|
||||||
# insecure_skip_verify = false # Skip verification
|
# insecure_skip_verify = false # Skip verification
|
||||||
# ca_file = "" # Custom CA file
|
# ca_file = "" # Custom CA file
|
||||||
# server_name = "" # Expected server name
|
|
||||||
# min_version = "TLS1.2" # Min TLS version
|
# min_version = "TLS1.2" # Min TLS version
|
||||||
# max_version = "TLS1.3" # Max TLS version
|
# max_version = "TLS1.3" # Max TLS version
|
||||||
# cipher_suites = "" # Comma-separated list
|
# cipher_suites = "" # Comma-separated list
|
||||||
@ -185,11 +185,10 @@ format = "comment" # comment|message
|
|||||||
# ip_blacklist = [] # Blocked IPs/CIDRs
|
# ip_blacklist = [] # Blocked IPs/CIDRs
|
||||||
# requests_per_second = 100.0 # Rate limit per client
|
# requests_per_second = 100.0 # Rate limit per client
|
||||||
# burst_size = 100 # Burst capacity
|
# burst_size = 100 # Burst capacity
|
||||||
# limit_by = "ip" # ip|user|token|global
|
|
||||||
# response_code = 429 # HTTP status when limited
|
# response_code = 429 # HTTP status when limited
|
||||||
# response_message = "Rate limit exceeded"
|
# response_message = "Rate limit exceeded"
|
||||||
# max_connections_per_ip = 10 # Max concurrent per IP
|
# max_connections_per_ip = 10 # Max concurrent per IP
|
||||||
# max_total_connections = 1000 # Max total connections
|
# max_connections_total = 1000 # Max total connections
|
||||||
|
|
||||||
### TCP Sinks
|
### TCP Sinks
|
||||||
# [[pipelines.sinks]]
|
# [[pipelines.sinks]]
|
||||||
@ -207,31 +206,18 @@ format = "comment" # comment|message
|
|||||||
# include_stats = false # Include statistics
|
# include_stats = false # Include statistics
|
||||||
# format = "comment" # comment|message
|
# format = "comment" # comment|message
|
||||||
|
|
||||||
# [pipelines.sinks.options.tls]
|
|
||||||
# enabled = false # Enable TLS
|
|
||||||
# cert_file = "" # TLS certificate
|
|
||||||
# key_file = "" # TLS key
|
|
||||||
# client_auth = false # Require client certs
|
|
||||||
# client_ca_file = "" # Client CA cert
|
|
||||||
# verify_client_cert = false # Verify client certs
|
|
||||||
# insecure_skip_verify = false # Skip verification
|
|
||||||
# ca_file = "" # Custom CA file
|
|
||||||
# server_name = "" # Expected server name
|
|
||||||
# min_version = "TLS1.2" # Min TLS version
|
|
||||||
# max_version = "TLS1.3" # Max TLS version
|
|
||||||
# cipher_suites = "" # Comma-separated list
|
|
||||||
|
|
||||||
# [pipelines.sinks.options.net_limit]
|
# [pipelines.sinks.options.net_limit]
|
||||||
# enabled = false # Enable rate limiting
|
# enabled = false # Enable rate limiting
|
||||||
# ip_whitelist = [] # Allowed IPs/CIDRs
|
# ip_whitelist = [] # Allowed IPs/CIDRs
|
||||||
# ip_blacklist = [] # Blocked IPs/CIDRs
|
# ip_blacklist = [] # Blocked IPs/CIDRs
|
||||||
# requests_per_second = 100.0 # Rate limit per client
|
# requests_per_second = 100.0 # Rate limit per client
|
||||||
# burst_size = 100 # Burst capacity
|
# burst_size = 100 # Burst capacity
|
||||||
# limit_by = "ip" # ip|user|token|global
|
|
||||||
# response_code = 429 # HTTP status when limited
|
# response_code = 429 # HTTP status when limited
|
||||||
# response_message = "Rate limit exceeded"
|
# response_message = "Rate limit exceeded"
|
||||||
# max_connections_per_ip = 10 # Max concurrent per IP
|
# max_connections_per_ip = 10 # Max concurrent per IP
|
||||||
# max_total_connections = 1000 # Max total connections
|
# max_connections_per_user = 10 # Max concurrent per user
|
||||||
|
# max_connections_per_token = 10 # Max concurrent per token
|
||||||
|
# max_connections_total = 1000 # Max total connections
|
||||||
|
|
||||||
### HTTP Client Sinks
|
### HTTP Client Sinks
|
||||||
# [[pipelines.sinks]]
|
# [[pipelines.sinks]]
|
||||||
@ -283,6 +269,7 @@ format = "comment" # comment|message
|
|||||||
# [pipelines.sinks.options]
|
# [pipelines.sinks.options]
|
||||||
# directory = "" # Output dir (required)
|
# directory = "" # Output dir (required)
|
||||||
# name = "" # Base name (required)
|
# name = "" # Base name (required)
|
||||||
|
# buffer_size = 1000 # Input channel buffer size
|
||||||
# max_size_mb = 100 # Rotation size
|
# max_size_mb = 100 # Rotation size
|
||||||
# max_total_size_mb = 0 # Total limit (0=unlimited)
|
# max_total_size_mb = 0 # Total limit (0=unlimited)
|
||||||
# retention_hours = 0.0 # Retention (0=disabled)
|
# retention_hours = 0.0 # Retention (0=disabled)
|
||||||
|
|||||||
6
go.mod
6
go.mod
@ -5,9 +5,9 @@ go 1.25.1
|
|||||||
require (
|
require (
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||||
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3
|
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3
|
||||||
github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208
|
github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e
|
||||||
github.com/panjf2000/gnet/v2 v2.9.3
|
github.com/panjf2000/gnet/v2 v2.9.4
|
||||||
github.com/valyala/fasthttp v1.65.0
|
github.com/valyala/fasthttp v1.66.0
|
||||||
golang.org/x/crypto v0.42.0
|
golang.org/x/crypto v0.42.0
|
||||||
golang.org/x/term v0.35.0
|
golang.org/x/term v0.35.0
|
||||||
golang.org/x/time v0.13.0
|
golang.org/x/time v0.13.0
|
||||||
|
|||||||
12
go.sum
12
go.sum
@ -12,20 +12,20 @@ github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zt
|
|||||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||||
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0=
|
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0=
|
||||||
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0=
|
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0=
|
||||||
github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208 h1:IB1O/HLv9VR/4mL1Tkjlr91lk+r8anP6bab7rYdS/oE=
|
github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e h1:/XWCqFdSOiUf0/a5a63GHsvEdpglsYfn3qieNxTeyDc=
|
||||||
github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0=
|
github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0=
|
||||||
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
|
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
|
||||||
github.com/panjf2000/ants/v2 v2.11.3/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek=
|
github.com/panjf2000/ants/v2 v2.11.3/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek=
|
||||||
github.com/panjf2000/gnet/v2 v2.9.3 h1:auV3/A9Na3jiBDmYAAU00rPhFKnsAI+TnI1F7YUJMHQ=
|
github.com/panjf2000/gnet/v2 v2.9.4 h1:XvPCcaFwO4XWg4IgSfZnNV4dfDy5g++HIEx7sH0ldHc=
|
||||||
github.com/panjf2000/gnet/v2 v2.9.3/go.mod h1:WQTxDWYuQ/hz3eccH0FN32IVuvZ19HewEWx0l62fx7E=
|
github.com/panjf2000/gnet/v2 v2.9.4/go.mod h1:WQTxDWYuQ/hz3eccH0FN32IVuvZ19HewEWx0l62fx7E=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8=
|
github.com/valyala/fasthttp v1.66.0 h1:M87A0Z7EayeyNaV6pfO3tUTUiYO0dZfEJnRGXTVNuyU=
|
||||||
github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4=
|
github.com/valyala/fasthttp v1.66.0/go.mod h1:Y4eC+zwoocmXSVCB1JmhNbYtS7tZPRI2ztPB72EVObs=
|
||||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
|
|||||||
@ -60,7 +60,7 @@ func initializeLogger(cfg *config.Config) error {
|
|||||||
// In quiet mode, disable ALL logging output
|
// In quiet mode, disable ALL logging output
|
||||||
logCfg.Level = 255 // A level that disables all output
|
logCfg.Level = 255 // A level that disables all output
|
||||||
logCfg.DisableFile = true
|
logCfg.DisableFile = true
|
||||||
logCfg.EnableStdout = false
|
logCfg.EnableConsole = false
|
||||||
return logger.ApplyConfig(logCfg)
|
return logger.ApplyConfig(logCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,29 +74,24 @@ func initializeLogger(cfg *config.Config) error {
|
|||||||
// Configure based on output mode
|
// Configure based on output mode
|
||||||
switch cfg.Logging.Output {
|
switch cfg.Logging.Output {
|
||||||
case "none":
|
case "none":
|
||||||
logCfg.EnableStdout = false
|
logCfg.EnableConsole = false
|
||||||
case "stdout":
|
case "stdout":
|
||||||
logCfg.EnableStdout = true
|
logCfg.EnableConsole = true
|
||||||
logCfg.StdoutTarget = "stdout"
|
logCfg.ConsoleTarget = "stdout"
|
||||||
case "stderr":
|
case "stderr":
|
||||||
logCfg.EnableStdout = true
|
logCfg.EnableConsole = true
|
||||||
logCfg.StdoutTarget = "stderr"
|
logCfg.ConsoleTarget = "stderr"
|
||||||
case "split":
|
case "split":
|
||||||
logCfg.EnableStdout = true
|
logCfg.EnableConsole = true
|
||||||
logCfg.StdoutTarget = "split"
|
logCfg.ConsoleTarget = "split"
|
||||||
case "file":
|
case "file":
|
||||||
logCfg.DisableFile = false
|
logCfg.DisableFile = false
|
||||||
logCfg.EnableStdout = false
|
logCfg.EnableConsole = false
|
||||||
configureFileLogging(logCfg, cfg)
|
|
||||||
case "both":
|
|
||||||
logCfg.DisableFile = false
|
|
||||||
logCfg.EnableStdout = true
|
|
||||||
logCfg.StdoutTarget = "stdout"
|
|
||||||
configureFileLogging(logCfg, cfg)
|
configureFileLogging(logCfg, cfg)
|
||||||
case "all":
|
case "all":
|
||||||
logCfg.DisableFile = false
|
logCfg.DisableFile = false
|
||||||
logCfg.EnableStdout = true
|
logCfg.EnableConsole = true
|
||||||
logCfg.StdoutTarget = "split"
|
logCfg.ConsoleTarget = "split"
|
||||||
configureFileLogging(logCfg, cfg)
|
configureFileLogging(logCfg, cfg)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("invalid log output mode: %s", cfg.Logging.Output)
|
return fmt.Errorf("invalid log output mode: %s", cfg.Logging.Output)
|
||||||
|
|||||||
@ -127,13 +127,13 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
|
|||||||
"listen", fmt.Sprintf("%s:%d", host, port))
|
"listen", fmt.Sprintf("%s:%d", host, port))
|
||||||
|
|
||||||
// Display net limit info if configured
|
// Display net limit info if configured
|
||||||
if rl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
|
if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
|
||||||
if enabled, ok := rl["enabled"].(bool); ok && enabled {
|
if enabled, ok := nl["enabled"].(bool); ok && enabled {
|
||||||
logger.Info("msg", "TCP net limiting enabled",
|
logger.Info("msg", "TCP net limiting enabled",
|
||||||
"pipeline", cfg.Name,
|
"pipeline", cfg.Name,
|
||||||
"sink_index", i,
|
"sink_index", i,
|
||||||
"requests_per_second", rl["requests_per_second"],
|
"requests_per_second", nl["requests_per_second"],
|
||||||
"burst_size", rl["burst_size"])
|
"burst_size", nl["burst_size"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -162,14 +162,13 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
|
|||||||
"status_url", fmt.Sprintf("http://%s:%d%s", host, port, statusPath))
|
"status_url", fmt.Sprintf("http://%s:%d%s", host, port, statusPath))
|
||||||
|
|
||||||
// Display net limit info if configured
|
// Display net limit info if configured
|
||||||
if rl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
|
if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
|
||||||
if enabled, ok := rl["enabled"].(bool); ok && enabled {
|
if enabled, ok := nl["enabled"].(bool); ok && enabled {
|
||||||
logger.Info("msg", "HTTP net limiting enabled",
|
logger.Info("msg", "HTTP net limiting enabled",
|
||||||
"pipeline", cfg.Name,
|
"pipeline", cfg.Name,
|
||||||
"sink_index", i,
|
"sink_index", i,
|
||||||
"requests_per_second", rl["requests_per_second"],
|
"requests_per_second", nl["requests_per_second"],
|
||||||
"burst_size", rl["burst_size"],
|
"burst_size", nl["burst_size"])
|
||||||
"limit_by", rl["limit_by"])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,6 +4,7 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@ -16,7 +17,7 @@ import (
|
|||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/lixenwraith/log"
|
"github.com/lixenwraith/log"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/argon2"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -378,12 +379,14 @@ func (a *Authenticator) validateBasicAuth(username, password, remoteAddr string)
|
|||||||
a.mu.RUnlock()
|
a.mu.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
// Perform bcrypt anyway to prevent timing attacks
|
// Perform argon2 anyway to prevent timing attacks
|
||||||
bcrypt.CompareHashAndPassword([]byte("$2a$10$dummy.hash.to.prevent.timing.attacks"), []byte(password))
|
dummySalt := make([]byte, 16)
|
||||||
|
argon2.IDKey([]byte(password), dummySalt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)
|
||||||
return nil, fmt.Errorf("invalid credentials")
|
return nil, fmt.Errorf("invalid credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(expectedHash), []byte(password)); err != nil {
|
// Parse PHC format hash
|
||||||
|
if !verifyArgon2idHash(password, expectedHash) {
|
||||||
return nil, fmt.Errorf("invalid credentials")
|
return nil, fmt.Errorf("invalid credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -400,6 +403,43 @@ func (a *Authenticator) validateBasicAuth(username, password, remoteAddr string)
|
|||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verify Argon2id hashes
|
||||||
|
func verifyArgon2idHash(password, hash string) bool {
|
||||||
|
// Parse PHC format: $argon2id$v=19$m=65536,t=3,p=4$salt$hash
|
||||||
|
parts := strings.Split(hash, "$")
|
||||||
|
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse version
|
||||||
|
var version int
|
||||||
|
fmt.Sscanf(parts[2], "v=%d", &version)
|
||||||
|
if version != argon2.Version {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse parameters
|
||||||
|
var memory, time, threads uint32
|
||||||
|
fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
|
||||||
|
|
||||||
|
// Decode salt and hash
|
||||||
|
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute hash
|
||||||
|
computedHash := argon2.IDKey([]byte(password), salt, time, memory, uint8(threads), uint32(len(expectedHash)))
|
||||||
|
|
||||||
|
// Constant time comparison
|
||||||
|
return subtle.ConstantTimeCompare(computedHash, expectedHash) == 1
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) {
|
func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) {
|
||||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||||
return nil, fmt.Errorf("invalid bearer auth header")
|
return nil, fmt.Errorf("invalid bearer auth header")
|
||||||
|
|||||||
@ -10,17 +10,24 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/argon2"
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handles auth credential generation
|
// Argon2id parameters
|
||||||
|
const (
|
||||||
|
argon2Time = 3
|
||||||
|
argon2Memory = 64 * 1024 // 64 MB
|
||||||
|
argon2Threads = 4
|
||||||
|
argon2SaltLen = 16
|
||||||
|
argon2KeyLen = 32
|
||||||
|
)
|
||||||
|
|
||||||
type GeneratorCommand struct {
|
type GeneratorCommand struct {
|
||||||
output io.Writer
|
output io.Writer
|
||||||
errOut io.Writer
|
errOut io.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a new auth generator command handler
|
|
||||||
func NewGeneratorCommand() *GeneratorCommand {
|
func NewGeneratorCommand() *GeneratorCommand {
|
||||||
return &GeneratorCommand{
|
return &GeneratorCommand{
|
||||||
output: os.Stdout,
|
output: os.Stdout,
|
||||||
@ -28,7 +35,6 @@ func NewGeneratorCommand() *GeneratorCommand {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Runs the auth generation command with provided arguments
|
|
||||||
func (g *GeneratorCommand) Execute(args []string) error {
|
func (g *GeneratorCommand) Execute(args []string) error {
|
||||||
cmd := flag.NewFlagSet("auth", flag.ContinueOnError)
|
cmd := flag.NewFlagSet("auth", flag.ContinueOnError)
|
||||||
cmd.SetOutput(g.errOut)
|
cmd.SetOutput(g.errOut)
|
||||||
@ -36,7 +42,6 @@ func (g *GeneratorCommand) Execute(args []string) error {
|
|||||||
var (
|
var (
|
||||||
username = cmd.String("u", "", "Username for basic auth")
|
username = cmd.String("u", "", "Username for basic auth")
|
||||||
password = cmd.String("p", "", "Password to hash (will prompt if not provided)")
|
password = cmd.String("p", "", "Password to hash (will prompt if not provided)")
|
||||||
cost = cmd.Int("c", 10, "Bcrypt cost (10-31)")
|
|
||||||
genToken = cmd.Bool("t", false, "Generate random bearer token")
|
genToken = cmd.Bool("t", false, "Generate random bearer token")
|
||||||
tokenLen = cmd.Int("l", 32, "Token length in bytes")
|
tokenLen = cmd.Int("l", 32, "Token length in bytes")
|
||||||
)
|
)
|
||||||
@ -45,7 +50,7 @@ func (g *GeneratorCommand) Execute(args []string) error {
|
|||||||
fmt.Fprintln(g.errOut, "Generate authentication credentials for LogWisp")
|
fmt.Fprintln(g.errOut, "Generate authentication credentials for LogWisp")
|
||||||
fmt.Fprintln(g.errOut, "\nUsage: logwisp auth [options]")
|
fmt.Fprintln(g.errOut, "\nUsage: logwisp auth [options]")
|
||||||
fmt.Fprintln(g.errOut, "\nExamples:")
|
fmt.Fprintln(g.errOut, "\nExamples:")
|
||||||
fmt.Fprintln(g.errOut, " # Generate bcrypt hash for user")
|
fmt.Fprintln(g.errOut, " # Generate Argon2id hash for user")
|
||||||
fmt.Fprintln(g.errOut, " logwisp auth -u admin")
|
fmt.Fprintln(g.errOut, " logwisp auth -u admin")
|
||||||
fmt.Fprintln(g.errOut, " ")
|
fmt.Fprintln(g.errOut, " ")
|
||||||
fmt.Fprintln(g.errOut, " # Generate 64-byte bearer token")
|
fmt.Fprintln(g.errOut, " # Generate 64-byte bearer token")
|
||||||
@ -58,26 +63,19 @@ func (g *GeneratorCommand) Execute(args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Token generation mode
|
|
||||||
if *genToken {
|
if *genToken {
|
||||||
return g.generateToken(*tokenLen)
|
return g.generateToken(*tokenLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Password hash generation mode
|
|
||||||
if *username == "" {
|
if *username == "" {
|
||||||
cmd.Usage()
|
cmd.Usage()
|
||||||
return fmt.Errorf("username required for password hash generation")
|
return fmt.Errorf("username required for password hash generation")
|
||||||
}
|
}
|
||||||
|
|
||||||
return g.generatePasswordHash(*username, *password, *cost)
|
return g.generatePasswordHash(*username, *password)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GeneratorCommand) generatePasswordHash(username, password string, cost int) error {
|
func (g *GeneratorCommand) generatePasswordHash(username, password string) error {
|
||||||
// Validate cost
|
|
||||||
if cost < 10 || cost > 31 {
|
|
||||||
return fmt.Errorf("bcrypt cost must be between 10 and 31")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get password if not provided
|
// Get password if not provided
|
||||||
if password == "" {
|
if password == "" {
|
||||||
pass1 := g.promptPassword("Enter password: ")
|
pass1 := g.promptPassword("Enter password: ")
|
||||||
@ -88,20 +86,29 @@ func (g *GeneratorCommand) generatePasswordHash(username, password string, cost
|
|||||||
password = pass1
|
password = pass1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate hash
|
// Generate salt
|
||||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), cost)
|
salt := make([]byte, argon2SaltLen)
|
||||||
if err != nil {
|
if _, err := rand.Read(salt); err != nil {
|
||||||
return fmt.Errorf("failed to generate hash: %w", err)
|
return fmt.Errorf("failed to generate salt: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate Argon2id hash
|
||||||
|
hash := argon2.IDKey([]byte(password), salt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)
|
||||||
|
|
||||||
|
// Encode in PHC format: $argon2id$v=19$m=65536,t=3,p=4$salt$hash
|
||||||
|
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
||||||
|
hashB64 := base64.RawStdEncoding.EncodeToString(hash)
|
||||||
|
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||||
|
argon2.Version, argon2Memory, argon2Time, argon2Threads, saltB64, hashB64)
|
||||||
|
|
||||||
// Output configuration snippets
|
// Output configuration snippets
|
||||||
fmt.Fprintln(g.output, "\n# TOML Configuration (add to logwisp.toml):")
|
fmt.Fprintln(g.output, "\n# TOML Configuration (add to logwisp.toml):")
|
||||||
fmt.Fprintln(g.output, "[[pipelines.auth.basic_auth.users]]")
|
fmt.Fprintln(g.output, "[[pipelines.auth.basic_auth.users]]")
|
||||||
fmt.Fprintf(g.output, "username = %q\n", username)
|
fmt.Fprintf(g.output, "username = %q\n", username)
|
||||||
fmt.Fprintf(g.output, "password_hash = %q\n\n", string(hash))
|
fmt.Fprintf(g.output, "password_hash = %q\n\n", phcHash)
|
||||||
|
|
||||||
fmt.Fprintln(g.output, "# Users File Format (for external auth file):")
|
fmt.Fprintln(g.output, "# Users File Format (for external auth file):")
|
||||||
fmt.Fprintf(g.output, "%s:%s\n", username, hash)
|
fmt.Fprintf(g.output, "%s:%s\n", username, phcHash)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -119,11 +126,9 @@ func (g *GeneratorCommand) generateToken(length int) error {
|
|||||||
return fmt.Errorf("failed to generate random bytes: %w", err)
|
return fmt.Errorf("failed to generate random bytes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate both encodings
|
|
||||||
b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token)
|
b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token)
|
||||||
hex := fmt.Sprintf("%x", token)
|
hex := fmt.Sprintf("%x", token)
|
||||||
|
|
||||||
// Output configuration
|
|
||||||
fmt.Fprintln(g.output, "\n# TOML Configuration (add to logwisp.toml):")
|
fmt.Fprintln(g.output, "\n# TOML Configuration (add to logwisp.toml):")
|
||||||
fmt.Fprintf(g.output, "tokens = [%q]\n\n", b64)
|
fmt.Fprintf(g.output, "tokens = [%q]\n\n", b64)
|
||||||
|
|
||||||
@ -139,7 +144,6 @@ func (g *GeneratorCommand) promptPassword(prompt string) string {
|
|||||||
password, err := term.ReadPassword(int(syscall.Stdin))
|
password, err := term.ReadPassword(int(syscall.Stdin))
|
||||||
fmt.Fprintln(g.errOut)
|
fmt.Fprintln(g.errOut)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Fatal error - can't continue without password
|
|
||||||
fmt.Fprintf(g.errOut, "Failed to read password: %v\n", err)
|
fmt.Fprintf(g.errOut, "Failed to read password: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,7 +29,7 @@ type BasicAuthConfig struct {
|
|||||||
|
|
||||||
type BasicAuthUser struct {
|
type BasicAuthUser struct {
|
||||||
Username string `toml:"username"`
|
Username string `toml:"username"`
|
||||||
// Password hash (bcrypt)
|
// Password hash (Argon2id)
|
||||||
PasswordHash string `toml:"password_hash"`
|
PasswordHash string `toml:"password_hash"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -5,13 +5,13 @@ import "fmt"
|
|||||||
|
|
||||||
// Represents logging configuration for LogWisp
|
// Represents logging configuration for LogWisp
|
||||||
type LogConfig struct {
|
type LogConfig struct {
|
||||||
// Output mode: "file", "stdout", "stderr", "both", "none"
|
// Output mode: "file", "stdout", "stderr", "split", "all", "none"
|
||||||
Output string `toml:"output"`
|
Output string `toml:"output"`
|
||||||
|
|
||||||
// Log level: "debug", "info", "warn", "error"
|
// Log level: "debug", "info", "warn", "error"
|
||||||
Level string `toml:"level"`
|
Level string `toml:"level"`
|
||||||
|
|
||||||
// File output settings (when Output includes "file" or "both")
|
// File output settings (when Output includes "file" or "all")
|
||||||
File *LogFileConfig `toml:"file"`
|
File *LogFileConfig `toml:"file"`
|
||||||
|
|
||||||
// Console output settings
|
// Console output settings
|
||||||
@ -66,7 +66,7 @@ func DefaultLogConfig() *LogConfig {
|
|||||||
func validateLogConfig(cfg *LogConfig) error {
|
func validateLogConfig(cfg *LogConfig) error {
|
||||||
validOutputs := map[string]bool{
|
validOutputs := map[string]bool{
|
||||||
"file": true, "stdout": true, "stderr": true,
|
"file": true, "stdout": true, "stderr": true,
|
||||||
"both": true, "all": true, "none": true,
|
"split": true, "all": true, "none": true,
|
||||||
}
|
}
|
||||||
if !validOutputs[cfg.Output] {
|
if !validOutputs[cfg.Output] {
|
||||||
return fmt.Errorf("invalid log output mode: %s", cfg.Output)
|
return fmt.Errorf("invalid log output mode: %s", cfg.Output)
|
||||||
|
|||||||
@ -131,8 +131,8 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate net_limit
|
// Validate net_limit
|
||||||
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||||
if err := validateNetLimitOptions("HTTP source", pipelineName, sourceIndex, rl); err != nil {
|
if err := validateNetLimitOptions("HTTP source", pipelineName, sourceIndex, nl); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -161,15 +161,8 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate net_limit
|
// Validate net_limit
|
||||||
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||||
if err := validateNetLimitOptions("TCP source", pipelineName, sourceIndex, rl); err != nil {
|
if err := validateNetLimitOptions("TCP source", pipelineName, sourceIndex, nl); err != nil {
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate TLS
|
|
||||||
if tls, ok := cfg.Options["tls"].(map[string]any); ok {
|
|
||||||
if err := validateTLSOptions("TCP source", pipelineName, sourceIndex, tls); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -196,7 +189,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
|||||||
pipelineName, sinkIndex)
|
pipelineName, sinkIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate host if provided
|
// Validate host
|
||||||
if host, ok := cfg.Options["host"].(string); ok && host != "" {
|
if host, ok := cfg.Options["host"].(string); ok && host != "" {
|
||||||
if net.ParseIP(host) == nil {
|
if net.ParseIP(host) == nil {
|
||||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
|
return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
|
||||||
@ -219,7 +212,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate paths if provided
|
// Validate paths
|
||||||
if streamPath, ok := cfg.Options["stream_path"].(string); ok {
|
if streamPath, ok := cfg.Options["stream_path"].(string); ok {
|
||||||
if !strings.HasPrefix(streamPath, "/") {
|
if !strings.HasPrefix(streamPath, "/") {
|
||||||
return fmt.Errorf("pipeline '%s' sink[%d]: stream path must start with /: %s",
|
return fmt.Errorf("pipeline '%s' sink[%d]: stream path must start with /: %s",
|
||||||
@ -234,7 +227,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate heartbeat if present
|
// Validate heartbeat
|
||||||
if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
|
if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
|
||||||
if err := validateHeartbeatOptions("HTTP", pipelineName, sinkIndex, hb); err != nil {
|
if err := validateHeartbeatOptions("HTTP", pipelineName, sinkIndex, hb); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -248,9 +241,9 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate net limit if present
|
// Validate net limit
|
||||||
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||||
if err := validateNetLimitOptions("HTTP", pipelineName, sinkIndex, rl); err != nil {
|
if err := validateNetLimitOptions("HTTP", pipelineName, sinkIndex, nl); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -263,7 +256,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
|||||||
pipelineName, sinkIndex)
|
pipelineName, sinkIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate host if provided
|
// Validate host
|
||||||
if host, ok := cfg.Options["host"].(string); ok && host != "" {
|
if host, ok := cfg.Options["host"].(string); ok && host != "" {
|
||||||
if net.ParseIP(host) == nil {
|
if net.ParseIP(host) == nil {
|
||||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
|
return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
|
||||||
@ -286,23 +279,16 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate heartbeat if present
|
// Validate heartbeat
|
||||||
if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
|
if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
|
||||||
if err := validateHeartbeatOptions("TCP", pipelineName, sinkIndex, hb); err != nil {
|
if err := validateHeartbeatOptions("TCP", pipelineName, sinkIndex, hb); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate TLS if present
|
// Validate net limit
|
||||||
if tls, ok := cfg.Options["tls"].(map[string]any); ok {
|
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||||
if err := validateTLSOptions("TCP", pipelineName, sinkIndex, tls); err != nil {
|
if err := validateNetLimitOptions("TCP", pipelineName, sinkIndex, nl); err != nil {
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate net limit if present
|
|
||||||
if rl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
|
||||||
if err := validateNetLimitOptions("TCP", pipelineName, sinkIndex, rl); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -12,9 +12,6 @@ type TCPConfig struct {
|
|||||||
Port int64 `toml:"port"`
|
Port int64 `toml:"port"`
|
||||||
BufferSize int64 `toml:"buffer_size"`
|
BufferSize int64 `toml:"buffer_size"`
|
||||||
|
|
||||||
// TLS Configuration
|
|
||||||
TLS *TLSConfig `toml:"tls"`
|
|
||||||
|
|
||||||
// Net limiting
|
// Net limiting
|
||||||
NetLimit *NetLimitConfig `toml:"net_limit"`
|
NetLimit *NetLimitConfig `toml:"net_limit"`
|
||||||
|
|
||||||
@ -63,16 +60,15 @@ type NetLimitConfig struct {
|
|||||||
// Burst size (token bucket)
|
// Burst size (token bucket)
|
||||||
BurstSize int64 `toml:"burst_size"`
|
BurstSize int64 `toml:"burst_size"`
|
||||||
|
|
||||||
// Net limit by: "ip", "user", "token", "global"
|
|
||||||
LimitBy string `toml:"limit_by"`
|
|
||||||
|
|
||||||
// Response when net limited
|
// Response when net limited
|
||||||
ResponseCode int64 `toml:"response_code"` // Default: 429
|
ResponseCode int64 `toml:"response_code"` // Default: 429
|
||||||
ResponseMessage string `toml:"response_message"` // Default: "Net limit exceeded"
|
ResponseMessage string `toml:"response_message"` // Default: "Net limit exceeded"
|
||||||
|
|
||||||
// Connection limits
|
// Connection limits
|
||||||
MaxConnectionsPerIP int64 `toml:"max_connections_per_ip"`
|
MaxConnectionsPerIP int64 `toml:"max_connections_per_ip"`
|
||||||
MaxTotalConnections int64 `toml:"max_total_connections"`
|
MaxConnectionsPerUser int64 `toml:"max_connections_per_user"`
|
||||||
|
MaxConnectionsPerToken int64 `toml:"max_connections_per_token"`
|
||||||
|
MaxConnectionsTotal int64 `toml:"max_connections_total"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateHeartbeatOptions(serverType, pipelineName string, sinkIndex int, hb map[string]any) error {
|
func validateHeartbeatOptions(serverType, pipelineName string, sinkIndex int, hb map[string]any) error {
|
||||||
@ -93,13 +89,13 @@ func validateHeartbeatOptions(serverType, pipelineName string, sinkIndex int, hb
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl map[string]any) error {
|
func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, nl map[string]any) error {
|
||||||
if enabled, ok := rl["enabled"].(bool); !ok || !enabled {
|
if enabled, ok := nl["enabled"].(bool); !ok || !enabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate IP lists if present
|
// Validate IP lists if present
|
||||||
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
|
||||||
for i, entry := range ipWhitelist {
|
for i, entry := range ipWhitelist {
|
||||||
entryStr, ok := entry.(string)
|
entryStr, ok := entry.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -112,7 +108,7 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
|
||||||
for i, entry := range ipBlacklist {
|
for i, entry := range ipBlacklist {
|
||||||
entryStr, ok := entry.(string)
|
entryStr, ok := entry.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -126,30 +122,21 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate requests per second
|
// Validate requests per second
|
||||||
rps, ok := rl["requests_per_second"].(float64)
|
rps, ok := nl["requests_per_second"].(float64)
|
||||||
if !ok || rps <= 0 {
|
if !ok || rps <= 0 {
|
||||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: requests_per_second must be positive",
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: requests_per_second must be positive",
|
||||||
pipelineName, sinkIndex, serverType)
|
pipelineName, sinkIndex, serverType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate burst size
|
// Validate burst size
|
||||||
burst, ok := rl["burst_size"].(int64)
|
burst, ok := nl["burst_size"].(int64)
|
||||||
if !ok || burst < 1 {
|
if !ok || burst < 1 {
|
||||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: burst_size must be at least 1",
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: burst_size must be at least 1",
|
||||||
pipelineName, sinkIndex, serverType)
|
pipelineName, sinkIndex, serverType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate limit_by
|
|
||||||
if limitBy, ok := rl["limit_by"].(string); ok && limitBy != "" {
|
|
||||||
validLimitBy := map[string]bool{"ip": true, "global": true}
|
|
||||||
if !validLimitBy[limitBy] {
|
|
||||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: invalid limit_by value: %s (must be 'ip' or 'global')",
|
|
||||||
pipelineName, sinkIndex, serverType, limitBy)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate response code
|
// Validate response code
|
||||||
if respCode, ok := rl["response_code"].(int64); ok {
|
if respCode, ok := nl["response_code"].(int64); ok {
|
||||||
if respCode > 0 && (respCode < 400 || respCode >= 600) {
|
if respCode > 0 && (respCode < 400 || respCode >= 600) {
|
||||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: response_code must be 4xx or 5xx: %d",
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: response_code must be 4xx or 5xx: %d",
|
||||||
pipelineName, sinkIndex, serverType, respCode)
|
pipelineName, sinkIndex, serverType, respCode)
|
||||||
@ -157,14 +144,25 @@ func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, rl
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate connection limits
|
// Validate connection limits
|
||||||
maxPerIP, perIPOk := rl["max_connections_per_ip"].(int64)
|
maxPerIP, perIPOk := nl["max_connections_per_ip"].(int64)
|
||||||
maxTotal, totalOk := rl["max_total_connections"].(int64)
|
maxPerUser, perUserOk := nl["max_connections_per_user"].(int64)
|
||||||
|
maxPerToken, perTokenOk := nl["max_connections_per_token"].(int64)
|
||||||
|
maxTotal, totalOk := nl["max_connections_total"].(int64)
|
||||||
|
|
||||||
if perIPOk && totalOk && maxPerIP > 0 && maxTotal > 0 {
|
if perIPOk && perUserOk && perTokenOk && totalOk &&
|
||||||
|
maxPerIP > 0 && maxPerUser > 0 && maxPerToken > 0 && maxTotal > 0 {
|
||||||
if maxPerIP > maxTotal {
|
if maxPerIP > maxTotal {
|
||||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_ip (%d) cannot exceed max_total_connections (%d)",
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_ip (%d) cannot exceed max_connections_total (%d)",
|
||||||
pipelineName, sinkIndex, serverType, maxPerIP, maxTotal)
|
pipelineName, sinkIndex, serverType, maxPerIP, maxTotal)
|
||||||
}
|
}
|
||||||
|
if maxPerUser > maxTotal {
|
||||||
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_user (%d) cannot exceed max_connections_total (%d)",
|
||||||
|
pipelineName, sinkIndex, serverType, maxPerUser, maxTotal)
|
||||||
|
}
|
||||||
|
if maxPerToken > maxTotal {
|
||||||
|
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_token (%d) cannot exceed max_connections_total (%d)",
|
||||||
|
pipelineName, sinkIndex, serverType, maxPerToken, maxTotal)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -49,8 +49,11 @@ type NetLimiter struct {
|
|||||||
globalLimiter *TokenBucket
|
globalLimiter *TokenBucket
|
||||||
|
|
||||||
// Connection tracking
|
// Connection tracking
|
||||||
ipConnections map[string]*connTracker
|
ipConnections map[string]*connTracker
|
||||||
connMu sync.RWMutex
|
userConnections map[string]*connTracker
|
||||||
|
tokenConnections map[string]*connTracker
|
||||||
|
totalConnections atomic.Int64
|
||||||
|
connMu sync.RWMutex
|
||||||
|
|
||||||
// Statistics
|
// Statistics
|
||||||
totalRequests atomic.Uint64
|
totalRequests atomic.Uint64
|
||||||
@ -102,29 +105,23 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
l := &NetLimiter{
|
l := &NetLimiter{
|
||||||
config: cfg,
|
config: cfg,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
ipWhitelist: make([]*net.IPNet, 0),
|
ipWhitelist: make([]*net.IPNet, 0),
|
||||||
ipBlacklist: make([]*net.IPNet, 0),
|
ipBlacklist: make([]*net.IPNet, 0),
|
||||||
ipLimiters: make(map[string]*ipLimiter),
|
ipLimiters: make(map[string]*ipLimiter),
|
||||||
ipConnections: make(map[string]*connTracker),
|
ipConnections: make(map[string]*connTracker),
|
||||||
lastCleanup: time.Now(),
|
userConnections: make(map[string]*connTracker),
|
||||||
ctx: ctx,
|
tokenConnections: make(map[string]*connTracker),
|
||||||
cancel: cancel,
|
lastCleanup: time.Now(),
|
||||||
cleanupDone: make(chan struct{}),
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
cleanupDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse IP lists
|
// Parse IP lists
|
||||||
l.parseIPLists(cfg)
|
l.parseIPLists(cfg)
|
||||||
|
|
||||||
// Create global limiter if configured
|
|
||||||
if cfg.Enabled && cfg.LimitBy == "global" {
|
|
||||||
l.globalLimiter = NewTokenBucket(
|
|
||||||
float64(cfg.BurstSize),
|
|
||||||
cfg.RequestsPerSecond,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start cleanup goroutine only if rate limiting is enabled
|
// Start cleanup goroutine only if rate limiting is enabled
|
||||||
if cfg.Enabled {
|
if cfg.Enabled {
|
||||||
go l.cleanupLoop()
|
go l.cleanupLoop()
|
||||||
@ -138,7 +135,10 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
|||||||
"blacklist_rules", len(l.ipBlacklist),
|
"blacklist_rules", len(l.ipBlacklist),
|
||||||
"requests_per_second", cfg.RequestsPerSecond,
|
"requests_per_second", cfg.RequestsPerSecond,
|
||||||
"burst_size", cfg.BurstSize,
|
"burst_size", cfg.BurstSize,
|
||||||
"limit_by", cfg.LimitBy)
|
"max_connections_per_ip", cfg.MaxConnectionsPerIP,
|
||||||
|
"max_connections_per_user", cfg.MaxConnectionsPerUser,
|
||||||
|
"max_connections_per_token", cfg.MaxConnectionsPerToken,
|
||||||
|
"max_connections_total", cfg.MaxConnectionsTotal)
|
||||||
|
|
||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
@ -276,7 +276,7 @@ func (l *NetLimiter) Shutdown() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks if an HTTP request should be allowed
|
// Checks if an HTTP request should be allowed: IP access control + connection limits (IP only) + calls
|
||||||
func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int64, message string) {
|
func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int64, message string) {
|
||||||
if l == nil {
|
if l == nil {
|
||||||
return true, 0, ""
|
return true, 0, ""
|
||||||
@ -343,7 +343,7 @@ func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int6
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check rate limit
|
// Check rate limit
|
||||||
if !l.checkLimit(ipStr) {
|
if !l.checkIPLimit(ipStr) {
|
||||||
l.blockedByRateLimit.Add(1)
|
l.blockedByRateLimit.Add(1)
|
||||||
statusCode = l.config.ResponseCode
|
statusCode = l.config.ResponseCode
|
||||||
if statusCode == 0 {
|
if statusCode == 0 {
|
||||||
@ -372,7 +372,7 @@ func (l *NetLimiter) updateConnectionActivity(ip string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks if a TCP connection should be allowed
|
// Checks if a TCP connection should be allowed: IP access control + calls checkIPLimit()
|
||||||
func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool {
|
func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool {
|
||||||
if l == nil {
|
if l == nil {
|
||||||
return true
|
return true
|
||||||
@ -412,7 +412,7 @@ func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool {
|
|||||||
|
|
||||||
// Check rate limit
|
// Check rate limit
|
||||||
ipStr := tcpAddr.IP.String()
|
ipStr := tcpAddr.IP.String()
|
||||||
if !l.checkLimit(ipStr) {
|
if !l.checkIPLimit(ipStr) {
|
||||||
l.blockedByRateLimit.Add(1)
|
l.blockedByRateLimit.Add(1)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -531,17 +531,40 @@ func (l *NetLimiter) GetStats() map[string]any {
|
|||||||
return map[string]any{"enabled": false}
|
return map[string]any{"enabled": false}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get active rate limiters count
|
||||||
l.ipMu.RLock()
|
l.ipMu.RLock()
|
||||||
activeIPs := len(l.ipLimiters)
|
activeIPs := len(l.ipLimiters)
|
||||||
l.ipMu.RUnlock()
|
l.ipMu.RUnlock()
|
||||||
|
|
||||||
|
// Get connection tracker counts and calculate total active connections
|
||||||
l.connMu.RLock()
|
l.connMu.RLock()
|
||||||
totalConnections := 0
|
ipConnTrackers := len(l.ipConnections)
|
||||||
|
userConnTrackers := len(l.userConnections)
|
||||||
|
tokenConnTrackers := len(l.tokenConnections)
|
||||||
|
|
||||||
|
// Calculate actual connection count by summing all IP connections
|
||||||
|
// Potentially more accurate than totalConnections counter which might drift
|
||||||
|
// TODO: test and refactor if they match
|
||||||
|
actualIPConnections := 0
|
||||||
for _, tracker := range l.ipConnections {
|
for _, tracker := range l.ipConnections {
|
||||||
totalConnections += int(tracker.connections.Load())
|
actualIPConnections += int(tracker.connections.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
actualUserConnections := 0
|
||||||
|
for _, tracker := range l.userConnections {
|
||||||
|
actualUserConnections += int(tracker.connections.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
actualTokenConnections := 0
|
||||||
|
for _, tracker := range l.tokenConnections {
|
||||||
|
actualTokenConnections += int(tracker.connections.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the counter for total (should match actualIPConnections in most cases)
|
||||||
|
totalConns := l.totalConnections.Load()
|
||||||
l.connMu.RUnlock()
|
l.connMu.RUnlock()
|
||||||
|
|
||||||
|
// Calculate total blocked
|
||||||
totalBlocked := l.blockedByBlacklist.Load() +
|
totalBlocked := l.blockedByBlacklist.Load() +
|
||||||
l.blockedByWhitelist.Load() +
|
l.blockedByWhitelist.Load() +
|
||||||
l.blockedByRateLimit.Load() +
|
l.blockedByRateLimit.Load() +
|
||||||
@ -559,23 +582,39 @@ func (l *NetLimiter) GetStats() map[string]any {
|
|||||||
"conn_limit": l.blockedByConnLimit.Load(),
|
"conn_limit": l.blockedByConnLimit.Load(),
|
||||||
"invalid_ip": l.blockedByInvalidIP.Load(),
|
"invalid_ip": l.blockedByInvalidIP.Load(),
|
||||||
},
|
},
|
||||||
"active_ips": activeIPs,
|
"rate_limiting": map[string]any{
|
||||||
"total_connections": totalConnections,
|
|
||||||
"acl": map[string]int{
|
|
||||||
"whitelist_rules": len(l.ipWhitelist),
|
|
||||||
"blacklist_rules": len(l.ipBlacklist),
|
|
||||||
},
|
|
||||||
"rate_limit": map[string]any{
|
|
||||||
"enabled": l.config.Enabled,
|
"enabled": l.config.Enabled,
|
||||||
"requests_per_second": l.config.RequestsPerSecond,
|
"requests_per_second": l.config.RequestsPerSecond,
|
||||||
"burst_size": l.config.BurstSize,
|
"burst_size": l.config.BurstSize,
|
||||||
"limit_by": l.config.LimitBy,
|
"active_ip_limiters": activeIPs, // IPs being rate-limited
|
||||||
|
},
|
||||||
|
"access_control": map[string]any{
|
||||||
|
"whitelist_rules": len(l.ipWhitelist),
|
||||||
|
"blacklist_rules": len(l.ipBlacklist),
|
||||||
|
},
|
||||||
|
"connections": map[string]any{
|
||||||
|
// Actual counts
|
||||||
|
"total_active": totalConns, // Counter-based total
|
||||||
|
"active_ip_connections": actualIPConnections, // Sum of all IP connections
|
||||||
|
"active_user_connections": actualUserConnections, // Sum of all user connections
|
||||||
|
"active_token_connections": actualTokenConnections, // Sum of all token connections
|
||||||
|
|
||||||
|
// Tracker counts (number of unique IPs/users/tokens being tracked)
|
||||||
|
"tracked_ips": ipConnTrackers,
|
||||||
|
"tracked_users": userConnTrackers,
|
||||||
|
"tracked_tokens": tokenConnTrackers,
|
||||||
|
|
||||||
|
// Configuration limits (0 = disabled)
|
||||||
|
"limit_per_ip": l.config.MaxConnectionsPerIP,
|
||||||
|
"limit_per_user": l.config.MaxConnectionsPerUser,
|
||||||
|
"limit_per_token": l.config.MaxConnectionsPerToken,
|
||||||
|
"limit_total": l.config.MaxConnectionsTotal,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Performs the actual net limit check
|
// Performs IP net limit check (req/sec)
|
||||||
func (l *NetLimiter) checkLimit(ip string) bool {
|
func (l *NetLimiter) checkIPLimit(ip string) bool {
|
||||||
// Validate IP format
|
// Validate IP format
|
||||||
parsedIP := net.ParseIP(ip)
|
parsedIP := net.ParseIP(ip)
|
||||||
if parsedIP == nil || !isIPv4(parsedIP) {
|
if parsedIP == nil || !isIPv4(parsedIP) {
|
||||||
@ -588,53 +627,36 @@ func (l *NetLimiter) checkLimit(ip string) bool {
|
|||||||
// Maybe run cleanup
|
// Maybe run cleanup
|
||||||
l.maybeCleanup()
|
l.maybeCleanup()
|
||||||
|
|
||||||
switch l.config.LimitBy {
|
// IP limit
|
||||||
case "global":
|
l.ipMu.Lock()
|
||||||
return l.globalLimiter.Allow()
|
lim, exists := l.ipLimiters[ip]
|
||||||
|
if !exists {
|
||||||
case "ip", "":
|
// Create new limiter for this IP
|
||||||
// Default to per-IP limiting
|
lim = &ipLimiter{
|
||||||
l.ipMu.Lock()
|
bucket: NewTokenBucket(
|
||||||
lim, exists := l.ipLimiters[ip]
|
float64(l.config.BurstSize),
|
||||||
if !exists {
|
l.config.RequestsPerSecond,
|
||||||
// Create new limiter for this IP
|
),
|
||||||
lim = &ipLimiter{
|
lastSeen: time.Now(),
|
||||||
bucket: NewTokenBucket(
|
|
||||||
float64(l.config.BurstSize),
|
|
||||||
l.config.RequestsPerSecond,
|
|
||||||
),
|
|
||||||
lastSeen: time.Now(),
|
|
||||||
}
|
|
||||||
l.ipLimiters[ip] = lim
|
|
||||||
l.uniqueIPs.Add(1)
|
|
||||||
|
|
||||||
l.logger.Debug("msg", "Created new IP limiter",
|
|
||||||
"ip", ip,
|
|
||||||
"total_ips", l.uniqueIPs.Load())
|
|
||||||
} else {
|
|
||||||
lim.lastSeen = time.Now()
|
|
||||||
}
|
}
|
||||||
l.ipMu.Unlock()
|
l.ipLimiters[ip] = lim
|
||||||
|
l.uniqueIPs.Add(1)
|
||||||
|
|
||||||
// Check connection limit if configured
|
l.logger.Debug("msg", "Created new IP limiter",
|
||||||
if l.config.MaxConnectionsPerIP > 0 {
|
"ip", ip,
|
||||||
l.connMu.RLock()
|
"total_ips", l.uniqueIPs.Load())
|
||||||
tracker, exists := l.ipConnections[ip]
|
} else {
|
||||||
l.connMu.RUnlock()
|
lim.lastSeen = time.Now()
|
||||||
|
|
||||||
if exists && tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return lim.bucket.Allow()
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Unknown limit_by value, allow by default
|
|
||||||
l.logger.Warn("msg", "Unknown limit_by value",
|
|
||||||
"limit_by", l.config.LimitBy)
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
l.ipMu.Unlock()
|
||||||
|
|
||||||
|
// Rate limit check
|
||||||
|
allowed := lim.bucket.Allow()
|
||||||
|
if !allowed {
|
||||||
|
l.blockedByRateLimit.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return allowed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Runs cleanup if enough time has passed
|
// Runs cleanup if enough time has passed
|
||||||
@ -691,25 +713,57 @@ func (l *NetLimiter) cleanup() {
|
|||||||
|
|
||||||
// Clean up stale connection trackers
|
// Clean up stale connection trackers
|
||||||
l.connMu.Lock()
|
l.connMu.Lock()
|
||||||
connCleaned := 0
|
|
||||||
|
// Clean IP connections
|
||||||
|
ipCleaned := 0
|
||||||
for ip, tracker := range l.ipConnections {
|
for ip, tracker := range l.ipConnections {
|
||||||
tracker.mu.Lock()
|
tracker.mu.Lock()
|
||||||
lastSeen := tracker.lastSeen
|
lastSeen := tracker.lastSeen
|
||||||
tracker.mu.Unlock()
|
tracker.mu.Unlock()
|
||||||
|
|
||||||
// Remove if no activity for 5 minutes AND no active connections
|
|
||||||
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
|
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
|
||||||
delete(l.ipConnections, ip)
|
delete(l.ipConnections, ip)
|
||||||
connCleaned++
|
ipCleaned++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clean user connections
|
||||||
|
userCleaned := 0
|
||||||
|
for user, tracker := range l.userConnections {
|
||||||
|
tracker.mu.Lock()
|
||||||
|
lastSeen := tracker.lastSeen
|
||||||
|
tracker.mu.Unlock()
|
||||||
|
|
||||||
|
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
|
||||||
|
delete(l.userConnections, user)
|
||||||
|
userCleaned++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean token connections
|
||||||
|
tokenCleaned := 0
|
||||||
|
for token, tracker := range l.tokenConnections {
|
||||||
|
tracker.mu.Lock()
|
||||||
|
lastSeen := tracker.lastSeen
|
||||||
|
tracker.mu.Unlock()
|
||||||
|
|
||||||
|
if now.Sub(lastSeen) > staleTimeout && tracker.connections.Load() <= 0 {
|
||||||
|
delete(l.tokenConnections, token)
|
||||||
|
tokenCleaned++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
l.connMu.Unlock()
|
l.connMu.Unlock()
|
||||||
|
|
||||||
if connCleaned > 0 {
|
if ipCleaned > 0 || userCleaned > 0 || tokenCleaned > 0 {
|
||||||
l.logger.Debug("msg", "Cleaned up stale connection trackers",
|
l.logger.Debug("msg", "Cleaned up stale connection trackers",
|
||||||
"component", "netlimit",
|
"component", "netlimit",
|
||||||
"cleaned", connCleaned,
|
"ip_cleaned", ipCleaned,
|
||||||
"remaining", len(l.ipConnections))
|
"user_cleaned", userCleaned,
|
||||||
|
"token_cleaned", tokenCleaned,
|
||||||
|
"ip_remaining", len(l.ipConnections),
|
||||||
|
"user_remaining", len(l.userConnections),
|
||||||
|
"token_remaining", len(l.tokenConnections))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -731,3 +785,163 @@ func (l *NetLimiter) cleanupLoop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tracks a new connection with optional user/token info: Connection limits (IP/user/token/total) for TCP only
|
||||||
|
func (l *NetLimiter) TrackConnection(ip string, user string, token string) bool {
|
||||||
|
if l == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
l.connMu.Lock()
|
||||||
|
defer l.connMu.Unlock()
|
||||||
|
|
||||||
|
// Check total connections limit (0 = disabled)
|
||||||
|
if l.config.MaxConnectionsTotal > 0 {
|
||||||
|
currentTotal := l.totalConnections.Load()
|
||||||
|
if currentTotal >= l.config.MaxConnectionsTotal {
|
||||||
|
l.blockedByConnLimit.Add(1)
|
||||||
|
l.logger.Debug("msg", "TCP connection blocked by total limit",
|
||||||
|
"component", "netlimit",
|
||||||
|
"current_total", currentTotal,
|
||||||
|
"max_total", l.config.MaxConnectionsTotal)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check per-IP connection limit (0 = disabled)
|
||||||
|
if l.config.MaxConnectionsPerIP > 0 && ip != "" {
|
||||||
|
tracker, exists := l.ipConnections[ip]
|
||||||
|
if !exists {
|
||||||
|
tracker = &connTracker{lastSeen: time.Now()}
|
||||||
|
l.ipConnections[ip] = tracker
|
||||||
|
}
|
||||||
|
if tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
|
||||||
|
l.blockedByConnLimit.Add(1)
|
||||||
|
l.logger.Debug("msg", "TCP connection blocked by IP limit",
|
||||||
|
"component", "netlimit",
|
||||||
|
"ip", ip,
|
||||||
|
"current", tracker.connections.Load(),
|
||||||
|
"max", l.config.MaxConnectionsPerIP)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check per-user connection limit (0 = disabled)
|
||||||
|
if l.config.MaxConnectionsPerUser > 0 && user != "" {
|
||||||
|
tracker, exists := l.userConnections[user]
|
||||||
|
if !exists {
|
||||||
|
tracker = &connTracker{lastSeen: time.Now()}
|
||||||
|
l.userConnections[user] = tracker
|
||||||
|
}
|
||||||
|
if tracker.connections.Load() >= l.config.MaxConnectionsPerUser {
|
||||||
|
l.blockedByConnLimit.Add(1)
|
||||||
|
l.logger.Debug("msg", "TCP connection blocked by user limit",
|
||||||
|
"component", "netlimit",
|
||||||
|
"user", user,
|
||||||
|
"current", tracker.connections.Load(),
|
||||||
|
"max", l.config.MaxConnectionsPerUser)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check per-token connection limit (0 = disabled)
|
||||||
|
if l.config.MaxConnectionsPerToken > 0 && token != "" {
|
||||||
|
tracker, exists := l.tokenConnections[token]
|
||||||
|
if !exists {
|
||||||
|
tracker = &connTracker{lastSeen: time.Now()}
|
||||||
|
l.tokenConnections[token] = tracker
|
||||||
|
}
|
||||||
|
if tracker.connections.Load() >= l.config.MaxConnectionsPerToken {
|
||||||
|
l.blockedByConnLimit.Add(1)
|
||||||
|
l.logger.Debug("msg", "TCP connection blocked by token limit",
|
||||||
|
"component", "netlimit",
|
||||||
|
"token", token,
|
||||||
|
"current", tracker.connections.Load(),
|
||||||
|
"max", l.config.MaxConnectionsPerToken)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All checks passed, increment counters
|
||||||
|
l.totalConnections.Add(1)
|
||||||
|
|
||||||
|
if ip != "" && l.config.MaxConnectionsPerIP > 0 {
|
||||||
|
if tracker, exists := l.ipConnections[ip]; exists {
|
||||||
|
tracker.connections.Add(1)
|
||||||
|
tracker.mu.Lock()
|
||||||
|
tracker.lastSeen = time.Now()
|
||||||
|
tracker.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if user != "" && l.config.MaxConnectionsPerUser > 0 {
|
||||||
|
if tracker, exists := l.userConnections[user]; exists {
|
||||||
|
tracker.connections.Add(1)
|
||||||
|
tracker.mu.Lock()
|
||||||
|
tracker.lastSeen = time.Now()
|
||||||
|
tracker.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token != "" && l.config.MaxConnectionsPerToken > 0 {
|
||||||
|
if tracker, exists := l.tokenConnections[token]; exists {
|
||||||
|
tracker.connections.Add(1)
|
||||||
|
tracker.mu.Lock()
|
||||||
|
tracker.lastSeen = time.Now()
|
||||||
|
tracker.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Releases a tracked connection
|
||||||
|
func (l *NetLimiter) ReleaseConnection(ip string, user string, token string) {
|
||||||
|
if l == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
l.connMu.Lock()
|
||||||
|
defer l.connMu.Unlock()
|
||||||
|
|
||||||
|
// Decrement total
|
||||||
|
if l.totalConnections.Load() > 0 {
|
||||||
|
l.totalConnections.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrement IP counter
|
||||||
|
if ip != "" {
|
||||||
|
if tracker, exists := l.ipConnections[ip]; exists {
|
||||||
|
if tracker.connections.Load() > 0 {
|
||||||
|
tracker.connections.Add(-1)
|
||||||
|
}
|
||||||
|
tracker.mu.Lock()
|
||||||
|
tracker.lastSeen = time.Now()
|
||||||
|
tracker.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrement user counter
|
||||||
|
if user != "" {
|
||||||
|
if tracker, exists := l.userConnections[user]; exists {
|
||||||
|
if tracker.connections.Load() > 0 {
|
||||||
|
tracker.connections.Add(-1)
|
||||||
|
}
|
||||||
|
tracker.mu.Lock()
|
||||||
|
tracker.lastSeen = time.Now()
|
||||||
|
tracker.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrement token counter
|
||||||
|
if token != "" {
|
||||||
|
if tracker, exists := l.tokenConnections[token]; exists {
|
||||||
|
if tracker.connections.Load() > 0 {
|
||||||
|
tracker.connections.Add(-1)
|
||||||
|
}
|
||||||
|
tracker.mu.Lock()
|
||||||
|
tracker.lastSeen = time.Now()
|
||||||
|
tracker.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -129,6 +129,11 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure authentication for sources that support it
|
||||||
|
for _, sourceInst := range pipeline.Sources {
|
||||||
|
sourceInst.SetAuth(cfg.Auth)
|
||||||
|
}
|
||||||
|
|
||||||
// Start all sinks
|
// Start all sinks
|
||||||
for i, sinkInst := range pipeline.Sinks {
|
for i, sinkInst := range pipeline.Sinks {
|
||||||
if err := sinkInst.Start(pipelineCtx); err != nil {
|
if err := sinkInst.Start(pipelineCtx); err != nil {
|
||||||
@ -139,9 +144,7 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
|
|||||||
|
|
||||||
// Configure authentication for sinks that support it
|
// Configure authentication for sinks that support it
|
||||||
for _, sinkInst := range pipeline.Sinks {
|
for _, sinkInst := range pipeline.Sinks {
|
||||||
if setter, ok := sinkInst.(sink.AuthSetter); ok {
|
sinkInst.SetAuth(cfg.Auth)
|
||||||
setter.SetAuthConfig(cfg.Auth)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wire sources to sinks through filters
|
// Wire sources to sinks through filters
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"logwisp/src/internal/config"
|
||||||
"logwisp/src/internal/core"
|
"logwisp/src/internal/core"
|
||||||
"logwisp/src/internal/format"
|
"logwisp/src/internal/format"
|
||||||
|
|
||||||
@ -121,7 +122,9 @@ func (s *StdoutSink) processLoop(ctx context.Context) {
|
|||||||
// Format and write
|
// Format and write
|
||||||
formatted, err := s.formatter.Format(entry)
|
formatted, err := s.formatter.Format(entry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("msg", "Failed to format log entry for stdout", "error", err)
|
s.logger.Error("msg", "Failed to format log entry for stdout",
|
||||||
|
"component", "stdout_sink",
|
||||||
|
"error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.output.Write(formatted)
|
s.output.Write(formatted)
|
||||||
@ -234,7 +237,9 @@ func (s *StderrSink) processLoop(ctx context.Context) {
|
|||||||
// Format and write
|
// Format and write
|
||||||
formatted, err := s.formatter.Format(entry)
|
formatted, err := s.formatter.Format(entry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("msg", "Failed to format log entry for stderr", "error", err)
|
s.logger.Error("msg", "Failed to format log entry for stderr",
|
||||||
|
"component", "stderr_sink",
|
||||||
|
"error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.output.Write(formatted)
|
s.output.Write(formatted)
|
||||||
@ -246,3 +251,11 @@ func (s *StderrSink) processLoop(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StdoutSink) SetAuth(auth *config.AuthConfig) {
|
||||||
|
// Authentication does not apply to stdout sink
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StderrSink) SetAuth(auth *config.AuthConfig) {
|
||||||
|
// Authentication does not apply to stderr sink
|
||||||
|
}
|
||||||
@ -4,6 +4,7 @@ package sink
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"logwisp/src/internal/config"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -43,7 +44,7 @@ func NewFileSink(options map[string]any, logger *log.Logger, formatter format.Fo
|
|||||||
writerConfig := log.DefaultConfig()
|
writerConfig := log.DefaultConfig()
|
||||||
writerConfig.Directory = directory
|
writerConfig.Directory = directory
|
||||||
writerConfig.Name = name
|
writerConfig.Name = name
|
||||||
writerConfig.EnableStdout = false // File only
|
writerConfig.EnableConsole = false // File only
|
||||||
writerConfig.ShowTimestamp = false // We already have timestamps in entries
|
writerConfig.ShowTimestamp = false // We already have timestamps in entries
|
||||||
writerConfig.ShowLevel = false // We already have levels in entries
|
writerConfig.ShowLevel = false // We already have levels in entries
|
||||||
|
|
||||||
@ -165,3 +166,7 @@ func (fs *FileSink) processLoop(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (fs *FileSink) SetAuth(auth *config.AuthConfig) {
|
||||||
|
// Authentication does not apply to file sink
|
||||||
|
}
|
||||||
@ -142,31 +142,28 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Extract net limit config
|
// Extract net limit config
|
||||||
if rl, ok := options["net_limit"].(map[string]any); ok {
|
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||||
cfg.NetLimit = &config.NetLimitConfig{}
|
cfg.NetLimit = &config.NetLimitConfig{}
|
||||||
cfg.NetLimit.Enabled, _ = rl["enabled"].(bool)
|
cfg.NetLimit.Enabled, _ = nl["enabled"].(bool)
|
||||||
if rps, ok := rl["requests_per_second"].(float64); ok {
|
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||||
cfg.NetLimit.RequestsPerSecond = rps
|
cfg.NetLimit.RequestsPerSecond = rps
|
||||||
}
|
}
|
||||||
if burst, ok := rl["burst_size"].(int64); ok {
|
if burst, ok := nl["burst_size"].(int64); ok {
|
||||||
cfg.NetLimit.BurstSize = burst
|
cfg.NetLimit.BurstSize = burst
|
||||||
}
|
}
|
||||||
if limitBy, ok := rl["limit_by"].(string); ok {
|
if respCode, ok := nl["response_code"].(int64); ok {
|
||||||
cfg.NetLimit.LimitBy = limitBy
|
|
||||||
}
|
|
||||||
if respCode, ok := rl["response_code"].(int64); ok {
|
|
||||||
cfg.NetLimit.ResponseCode = respCode
|
cfg.NetLimit.ResponseCode = respCode
|
||||||
}
|
}
|
||||||
if msg, ok := rl["response_message"].(string); ok {
|
if msg, ok := nl["response_message"].(string); ok {
|
||||||
cfg.NetLimit.ResponseMessage = msg
|
cfg.NetLimit.ResponseMessage = msg
|
||||||
}
|
}
|
||||||
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
|
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||||
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
|
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
|
||||||
}
|
}
|
||||||
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||||
cfg.NetLimit.MaxTotalConnections = maxTotal
|
cfg.NetLimit.MaxConnectionsTotal = maxTotal
|
||||||
}
|
}
|
||||||
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
|
||||||
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
||||||
for _, entry := range ipWhitelist {
|
for _, entry := range ipWhitelist {
|
||||||
if str, ok := entry.(string); ok {
|
if str, ok := entry.(string); ok {
|
||||||
@ -174,7 +171,7 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
|
||||||
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
||||||
for _, entry := range ipBlacklist {
|
for _, entry := range ipBlacklist {
|
||||||
if str, ok := entry.(string); ok {
|
if str, ok := entry.(string); ok {
|
||||||
@ -806,8 +803,8 @@ func (h *HTTPSink) GetHost() string {
|
|||||||
return h.config.Host
|
return h.config.Host
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configures http sink authentication
|
// Configures http sink auth
|
||||||
func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
func (h *HTTPSink) SetAuth(authCfg *config.AuthConfig) {
|
||||||
if authCfg == nil || authCfg.Type == "none" {
|
if authCfg == nil || authCfg.Type == "none" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,7 +6,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"logwisp/src/internal/auth"
|
||||||
|
"logwisp/src/internal/config"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
@ -23,16 +26,17 @@ import (
|
|||||||
|
|
||||||
// Forwards log entries to a remote HTTP endpoint
|
// Forwards log entries to a remote HTTP endpoint
|
||||||
type HTTPClientSink struct {
|
type HTTPClientSink struct {
|
||||||
input chan core.LogEntry
|
input chan core.LogEntry
|
||||||
config HTTPClientConfig
|
config HTTPClientConfig
|
||||||
client *fasthttp.Client
|
client *fasthttp.Client
|
||||||
batch []core.LogEntry
|
batch []core.LogEntry
|
||||||
batchMu sync.Mutex
|
batchMu sync.Mutex
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
formatter format.Formatter
|
formatter format.Formatter
|
||||||
|
authenticator *auth.Authenticator
|
||||||
|
|
||||||
// Statistics
|
// Statistics
|
||||||
totalProcessed atomic.Uint64
|
totalProcessed atomic.Uint64
|
||||||
@ -44,7 +48,9 @@ type HTTPClientSink struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Holds HTTP client sink configuration
|
// Holds HTTP client sink configuration
|
||||||
|
// TODO: missing toml tags
|
||||||
type HTTPClientConfig struct {
|
type HTTPClientConfig struct {
|
||||||
|
// Config
|
||||||
URL string
|
URL string
|
||||||
BufferSize int64
|
BufferSize int64
|
||||||
BatchSize int64
|
BatchSize int64
|
||||||
@ -57,6 +63,10 @@ type HTTPClientConfig struct {
|
|||||||
RetryDelay time.Duration
|
RetryDelay time.Duration
|
||||||
RetryBackoff float64 // Multiplier for exponential backoff
|
RetryBackoff float64 // Multiplier for exponential backoff
|
||||||
|
|
||||||
|
// Security
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
|
||||||
// TLS configuration
|
// TLS configuration
|
||||||
InsecureSkipVerify bool
|
InsecureSkipVerify bool
|
||||||
CAFile string
|
CAFile string
|
||||||
@ -118,6 +128,12 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
|
|||||||
if insecure, ok := options["insecure_skip_verify"].(bool); ok {
|
if insecure, ok := options["insecure_skip_verify"].(bool); ok {
|
||||||
cfg.InsecureSkipVerify = insecure
|
cfg.InsecureSkipVerify = insecure
|
||||||
}
|
}
|
||||||
|
if username, ok := options["username"].(string); ok {
|
||||||
|
cfg.Username = username
|
||||||
|
}
|
||||||
|
if password, ok := options["password"].(string); ok {
|
||||||
|
cfg.Password = password // TODO: change to Argon2 hashed password
|
||||||
|
}
|
||||||
|
|
||||||
// Extract headers
|
// Extract headers
|
||||||
if headers, ok := options["headers"].(map[string]any); ok {
|
if headers, ok := options["headers"].(map[string]any); ok {
|
||||||
@ -422,8 +438,16 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
|||||||
|
|
||||||
req.SetRequestURI(h.config.URL)
|
req.SetRequestURI(h.config.URL)
|
||||||
req.Header.SetMethod("POST")
|
req.Header.SetMethod("POST")
|
||||||
|
req.Header.SetContentType("application/json")
|
||||||
req.SetBody(body)
|
req.SetBody(body)
|
||||||
|
|
||||||
|
// Add Basic Auth header if credentials configured
|
||||||
|
if h.config.Username != "" && h.config.Password != "" {
|
||||||
|
creds := h.config.Username + ":" + h.config.Password
|
||||||
|
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
|
||||||
|
req.Header.Set("Authorization", "Basic "+encodedCreds)
|
||||||
|
}
|
||||||
|
|
||||||
// Set headers
|
// Set headers
|
||||||
for k, v := range h.config.Headers {
|
for k, v := range h.config.Headers {
|
||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
@ -495,3 +519,9 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
|||||||
"last_error", lastErr)
|
"last_error", lastErr)
|
||||||
h.failedBatches.Add(1)
|
h.failedBatches.Add(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Not applicable, Clients authenticate to remote servers using Username/Password in config
|
||||||
|
func (h *HTTPClientSink) SetAuth(authCfg *config.AuthConfig) {
|
||||||
|
// No-op: client sinks don't validate incoming connections
|
||||||
|
// They authenticate to remote servers using Username/Password fields
|
||||||
|
}
|
||||||
@ -9,19 +9,22 @@ import (
|
|||||||
"logwisp/src/internal/core"
|
"logwisp/src/internal/core"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Represents an output destination for log entries
|
// Represents an output data stream
|
||||||
type Sink interface {
|
type Sink interface {
|
||||||
// Input returns the channel for sending log entries to this sink
|
// Returns the channel for sending log entries to this sink
|
||||||
Input() chan<- core.LogEntry
|
Input() chan<- core.LogEntry
|
||||||
|
|
||||||
// Start begins processing log entries
|
// Begins processing log entries
|
||||||
Start(ctx context.Context) error
|
Start(ctx context.Context) error
|
||||||
|
|
||||||
// Stop gracefully shuts down the sink
|
// Gracefully shuts down the sink
|
||||||
Stop()
|
Stop()
|
||||||
|
|
||||||
// GetStats returns sink statistics
|
// Returns sink statistics
|
||||||
GetStats() SinkStats
|
GetStats() SinkStats
|
||||||
|
|
||||||
|
// Configure authentication
|
||||||
|
SetAuth(auth *config.AuthConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Contains statistics about a sink
|
// Contains statistics about a sink
|
||||||
@ -33,8 +36,3 @@ type SinkStats struct {
|
|||||||
LastProcessed time.Time
|
LastProcessed time.Time
|
||||||
Details map[string]any
|
Details map[string]any
|
||||||
}
|
}
|
||||||
|
|
||||||
// Interface for sinks that can accept an AuthConfig
|
|
||||||
type AuthSetter interface {
|
|
||||||
SetAuthConfig(auth *config.AuthConfig)
|
|
||||||
}
|
|
||||||
@ -17,7 +17,6 @@ import (
|
|||||||
"logwisp/src/internal/core"
|
"logwisp/src/internal/core"
|
||||||
"logwisp/src/internal/format"
|
"logwisp/src/internal/format"
|
||||||
"logwisp/src/internal/limit"
|
"logwisp/src/internal/limit"
|
||||||
"logwisp/src/internal/tls"
|
|
||||||
|
|
||||||
"github.com/lixenwraith/log"
|
"github.com/lixenwraith/log"
|
||||||
"github.com/lixenwraith/log/compat"
|
"github.com/lixenwraith/log/compat"
|
||||||
@ -26,23 +25,20 @@ import (
|
|||||||
|
|
||||||
// Streams log entries via TCP
|
// Streams log entries via TCP
|
||||||
type TCPSink struct {
|
type TCPSink struct {
|
||||||
input chan core.LogEntry
|
// C
|
||||||
config TCPConfig
|
input chan core.LogEntry
|
||||||
server *tcpServer
|
config TCPConfig
|
||||||
done chan struct{}
|
server *tcpServer
|
||||||
activeConns atomic.Int64
|
done chan struct{}
|
||||||
startTime time.Time
|
activeConns atomic.Int64
|
||||||
engine *gnet.Engine
|
startTime time.Time
|
||||||
engineMu sync.Mutex
|
engine *gnet.Engine
|
||||||
wg sync.WaitGroup
|
engineMu sync.Mutex
|
||||||
netLimiter *limit.NetLimiter
|
wg sync.WaitGroup
|
||||||
logger *log.Logger
|
netLimiter *limit.NetLimiter
|
||||||
formatter format.Formatter
|
logger *log.Logger
|
||||||
|
formatter format.Formatter
|
||||||
// Security components
|
|
||||||
authenticator *auth.Authenticator
|
authenticator *auth.Authenticator
|
||||||
tlsManager *tls.Manager
|
|
||||||
authConfig *config.AuthConfig
|
|
||||||
|
|
||||||
// Statistics
|
// Statistics
|
||||||
totalProcessed atomic.Uint64
|
totalProcessed atomic.Uint64
|
||||||
@ -62,7 +58,6 @@ type TCPConfig struct {
|
|||||||
Port int64
|
Port int64
|
||||||
BufferSize int64
|
BufferSize int64
|
||||||
Heartbeat *config.HeartbeatConfig
|
Heartbeat *config.HeartbeatConfig
|
||||||
TLS *config.TLSConfig
|
|
||||||
NetLimit *config.NetLimitConfig
|
NetLimit *config.NetLimitConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,58 +94,35 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract TLS config
|
|
||||||
if tc, ok := options["tls"].(map[string]any); ok {
|
|
||||||
cfg.TLS = &config.TLSConfig{}
|
|
||||||
cfg.TLS.Enabled, _ = tc["enabled"].(bool)
|
|
||||||
if certFile, ok := tc["cert_file"].(string); ok {
|
|
||||||
cfg.TLS.CertFile = certFile
|
|
||||||
}
|
|
||||||
if keyFile, ok := tc["key_file"].(string); ok {
|
|
||||||
cfg.TLS.KeyFile = keyFile
|
|
||||||
}
|
|
||||||
cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool)
|
|
||||||
if caFile, ok := tc["client_ca_file"].(string); ok {
|
|
||||||
cfg.TLS.ClientCAFile = caFile
|
|
||||||
}
|
|
||||||
cfg.TLS.VerifyClientCert, _ = tc["verify_client_cert"].(bool)
|
|
||||||
if minVer, ok := tc["min_version"].(string); ok {
|
|
||||||
cfg.TLS.MinVersion = minVer
|
|
||||||
}
|
|
||||||
if maxVer, ok := tc["max_version"].(string); ok {
|
|
||||||
cfg.TLS.MaxVersion = maxVer
|
|
||||||
}
|
|
||||||
if ciphers, ok := tc["cipher_suites"].(string); ok {
|
|
||||||
cfg.TLS.CipherSuites = ciphers
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract net limit config
|
// Extract net limit config
|
||||||
if rl, ok := options["net_limit"].(map[string]any); ok {
|
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||||
cfg.NetLimit = &config.NetLimitConfig{}
|
cfg.NetLimit = &config.NetLimitConfig{}
|
||||||
cfg.NetLimit.Enabled, _ = rl["enabled"].(bool)
|
cfg.NetLimit.Enabled, _ = nl["enabled"].(bool)
|
||||||
if rps, ok := rl["requests_per_second"].(float64); ok {
|
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||||
cfg.NetLimit.RequestsPerSecond = rps
|
cfg.NetLimit.RequestsPerSecond = rps
|
||||||
}
|
}
|
||||||
if burst, ok := rl["burst_size"].(int64); ok {
|
if burst, ok := nl["burst_size"].(int64); ok {
|
||||||
cfg.NetLimit.BurstSize = burst
|
cfg.NetLimit.BurstSize = burst
|
||||||
}
|
}
|
||||||
if limitBy, ok := rl["limit_by"].(string); ok {
|
if respCode, ok := nl["response_code"].(int64); ok {
|
||||||
cfg.NetLimit.LimitBy = limitBy
|
|
||||||
}
|
|
||||||
if respCode, ok := rl["response_code"].(int64); ok {
|
|
||||||
cfg.NetLimit.ResponseCode = respCode
|
cfg.NetLimit.ResponseCode = respCode
|
||||||
}
|
}
|
||||||
if msg, ok := rl["response_message"].(string); ok {
|
if msg, ok := nl["response_message"].(string); ok {
|
||||||
cfg.NetLimit.ResponseMessage = msg
|
cfg.NetLimit.ResponseMessage = msg
|
||||||
}
|
}
|
||||||
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
|
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||||
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
|
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
|
||||||
}
|
}
|
||||||
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
|
||||||
cfg.NetLimit.MaxTotalConnections = maxTotal
|
cfg.NetLimit.MaxConnectionsPerUser = maxPerUser
|
||||||
}
|
}
|
||||||
if ipWhitelist, ok := rl["ip_whitelist"].([]any); ok {
|
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
|
||||||
|
cfg.NetLimit.MaxConnectionsPerToken = maxPerToken
|
||||||
|
}
|
||||||
|
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||||
|
cfg.NetLimit.MaxConnectionsTotal = maxTotal
|
||||||
|
}
|
||||||
|
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
|
||||||
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
||||||
for _, entry := range ipWhitelist {
|
for _, entry := range ipWhitelist {
|
||||||
if str, ok := entry.(string); ok {
|
if str, ok := entry.(string); ok {
|
||||||
@ -158,7 +130,7 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if ipBlacklist, ok := rl["ip_blacklist"].([]any); ok {
|
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
|
||||||
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
||||||
for _, entry := range ipBlacklist {
|
for _, entry := range ipBlacklist {
|
||||||
if str, ok := entry.(string); ok {
|
if str, ok := entry.(string); ok {
|
||||||
@ -290,18 +262,6 @@ func (t *TCPSink) GetStats() SinkStats {
|
|||||||
netLimitStats = t.netLimiter.GetStats()
|
netLimitStats = t.netLimiter.GetStats()
|
||||||
}
|
}
|
||||||
|
|
||||||
var authStats map[string]any
|
|
||||||
if t.authenticator != nil {
|
|
||||||
authStats = t.authenticator.GetStats()
|
|
||||||
authStats["failures"] = t.authFailures.Load()
|
|
||||||
authStats["successes"] = t.authSuccesses.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
var tlsStats map[string]any
|
|
||||||
if t.tlsManager != nil {
|
|
||||||
tlsStats = t.tlsManager.GetStats()
|
|
||||||
}
|
|
||||||
|
|
||||||
return SinkStats{
|
return SinkStats{
|
||||||
Type: "tcp",
|
Type: "tcp",
|
||||||
TotalProcessed: t.totalProcessed.Load(),
|
TotalProcessed: t.totalProcessed.Load(),
|
||||||
@ -312,8 +272,7 @@ func (t *TCPSink) GetStats() SinkStats {
|
|||||||
"port": t.config.Port,
|
"port": t.config.Port,
|
||||||
"buffer_size": t.config.BufferSize,
|
"buffer_size": t.config.BufferSize,
|
||||||
"net_limit": netLimitStats,
|
"net_limit": netLimitStats,
|
||||||
"auth": authStats,
|
"auth": map[string]any{"enabled": false},
|
||||||
"tls": tlsStats,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -347,37 +306,7 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
|||||||
"entry_source", entry.Source)
|
"entry_source", entry.Source)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
t.broadcastData(data)
|
||||||
// Broadcast only to authenticated clients
|
|
||||||
t.server.mu.RLock()
|
|
||||||
for conn, client := range t.server.clients {
|
|
||||||
if client.authenticated {
|
|
||||||
// Send through TLS bridge if present
|
|
||||||
if client.tlsBridge != nil {
|
|
||||||
if _, err := client.tlsBridge.Write(data); err != nil {
|
|
||||||
// TLS write failed, connection likely dead
|
|
||||||
t.logger.Debug("msg", "TLS write failed",
|
|
||||||
"component", "tcp_sink",
|
|
||||||
"error", err)
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
t.writeErrors.Add(1)
|
|
||||||
t.handleWriteError(c, err)
|
|
||||||
} else {
|
|
||||||
// Reset consecutive error count on success
|
|
||||||
t.errorMu.Lock()
|
|
||||||
delete(t.consecutiveWriteErrors, c)
|
|
||||||
t.errorMu.Unlock()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.server.mu.RUnlock()
|
|
||||||
|
|
||||||
case <-tickerChan:
|
case <-tickerChan:
|
||||||
heartbeatEntry := t.createHeartbeatEntry()
|
heartbeatEntry := t.createHeartbeatEntry()
|
||||||
@ -388,37 +317,7 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
|||||||
"error", err)
|
"error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
t.broadcastData(data)
|
||||||
t.server.mu.RLock()
|
|
||||||
for conn, client := range t.server.clients {
|
|
||||||
if client.authenticated {
|
|
||||||
// Validate session is still active
|
|
||||||
if t.authenticator != nil && client.session != nil {
|
|
||||||
if !t.authenticator.ValidateSession(client.session.ID) {
|
|
||||||
// Session expired, close connection
|
|
||||||
conn.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if client.tlsBridge != nil {
|
|
||||||
if _, err := client.tlsBridge.Write(data); err != nil {
|
|
||||||
t.logger.Debug("msg", "TLS heartbeat write failed",
|
|
||||||
"component", "tcp_sink",
|
|
||||||
"error", err)
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
t.writeErrors.Add(1)
|
|
||||||
t.handleWriteError(c, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.server.mu.RUnlock()
|
|
||||||
|
|
||||||
case <-t.done:
|
case <-t.done:
|
||||||
return
|
return
|
||||||
@ -426,6 +325,28 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TCPSink) broadcastData(data []byte) {
|
||||||
|
t.server.mu.RLock()
|
||||||
|
defer t.server.mu.RUnlock()
|
||||||
|
|
||||||
|
for conn, client := range t.server.clients {
|
||||||
|
if client.authenticated {
|
||||||
|
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
t.writeErrors.Add(1)
|
||||||
|
t.handleWriteError(c, err)
|
||||||
|
} else {
|
||||||
|
// Reset consecutive error count on success
|
||||||
|
t.errorMu.Lock()
|
||||||
|
delete(t.consecutiveWriteErrors, c)
|
||||||
|
t.errorMu.Unlock()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Handle write errors with threshold-based connection termination
|
// Handle write errors with threshold-based connection termination
|
||||||
func (t *TCPSink) handleWriteError(c gnet.Conn, err error) {
|
func (t *TCPSink) handleWriteError(c gnet.Conn, err error) {
|
||||||
t.errorMu.Lock()
|
t.errorMu.Lock()
|
||||||
@ -487,13 +408,11 @@ func (t *TCPSink) GetActiveConnections() int64 {
|
|||||||
|
|
||||||
// Represents a connected TCP client with auth state
|
// Represents a connected TCP client with auth state
|
||||||
type tcpClient struct {
|
type tcpClient struct {
|
||||||
conn gnet.Conn
|
conn gnet.Conn
|
||||||
buffer bytes.Buffer
|
buffer bytes.Buffer
|
||||||
authenticated bool
|
authenticated bool
|
||||||
session *auth.Session
|
authTimeout time.Time
|
||||||
authTimeout time.Time
|
session *auth.Session
|
||||||
tlsBridge *tls.GNetTLSConn
|
|
||||||
authTimeoutSet bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handles gnet events with authentication
|
// Handles gnet events with authentication
|
||||||
@ -550,24 +469,12 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
|||||||
|
|
||||||
// Create client state without auth timeout initially
|
// Create client state without auth timeout initially
|
||||||
client := &tcpClient{
|
client := &tcpClient{
|
||||||
conn: c,
|
conn: c,
|
||||||
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
|
authenticated: s.sink.authenticator == nil,
|
||||||
authTimeoutSet: false, // Auth timeout not started yet
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize TLS bridge if enabled
|
if s.sink.authenticator != nil {
|
||||||
if s.sink.tlsManager != nil {
|
client.authTimeout = time.Now().Add(30 * time.Second)
|
||||||
tlsConfig := s.sink.tlsManager.GetTCPConfig()
|
|
||||||
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
|
|
||||||
client.tlsBridge.Handshake() // Start async handshake
|
|
||||||
|
|
||||||
s.sink.logger.Debug("msg", "TLS handshake initiated",
|
|
||||||
"component", "tcp_sink",
|
|
||||||
"remote_addr", remoteAddr)
|
|
||||||
} else if s.sink.authenticator != nil {
|
|
||||||
// Only set auth timeout if no TLS (plain connection)
|
|
||||||
client.authTimeout = time.Now().Add(30 * time.Second) // TODO: configurable or non-hardcoded timer
|
|
||||||
client.authTimeoutSet = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
@ -578,12 +485,11 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
|||||||
s.sink.logger.Debug("msg", "TCP connection opened",
|
s.sink.logger.Debug("msg", "TCP connection opened",
|
||||||
"remote_addr", remoteAddr,
|
"remote_addr", remoteAddr,
|
||||||
"active_connections", newCount,
|
"active_connections", newCount,
|
||||||
"requires_auth", s.sink.authenticator != nil)
|
"auth_enabled", s.sink.authenticator != nil)
|
||||||
|
|
||||||
// Send auth prompt if authentication is required
|
// Send auth prompt if authentication is required
|
||||||
if s.sink.authenticator != nil && s.sink.tlsManager == nil {
|
if s.sink.authenticator != nil {
|
||||||
authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH <method> <credentials>\nMethods: basic, token\n")
|
return []byte("AUTH_REQUIRED\n"), gnet.None
|
||||||
return authPrompt, gnet.None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, gnet.None
|
return nil, gnet.None
|
||||||
@ -594,17 +500,9 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
|||||||
|
|
||||||
// Remove client state
|
// Remove client state
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
client := s.clients[c]
|
|
||||||
delete(s.clients, c)
|
delete(s.clients, c)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
// Clean up TLS bridge if present
|
|
||||||
if client != nil && client.tlsBridge != nil {
|
|
||||||
client.tlsBridge.Close()
|
|
||||||
s.sink.logger.Debug("msg", "TLS connection closed",
|
|
||||||
"remote_addr", remoteAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up write error tracking
|
// Clean up write error tracking
|
||||||
s.sink.errorMu.Lock()
|
s.sink.errorMu.Lock()
|
||||||
delete(s.sink.consecutiveWriteErrors, c)
|
delete(s.sink.consecutiveWriteErrors, c)
|
||||||
@ -632,98 +530,34 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
|||||||
return gnet.Close
|
return gnet.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read all available data
|
// Authentication phase
|
||||||
data, err := c.Next(-1)
|
if !client.authenticated {
|
||||||
if err != nil {
|
// Check auth timeout
|
||||||
s.sink.logger.Error("msg", "Error reading from connection",
|
if time.Now().After(client.authTimeout) {
|
||||||
"component", "tcp_sink",
|
s.sink.logger.Warn("msg", "Authentication timeout",
|
||||||
"error", err)
|
|
||||||
return gnet.Close
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process through TLS bridge if present
|
|
||||||
if client.tlsBridge != nil {
|
|
||||||
// Feed encrypted data into TLS engine
|
|
||||||
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
|
|
||||||
s.sink.logger.Error("msg", "TLS processing error",
|
|
||||||
"component", "tcp_sink",
|
|
||||||
"remote_addr", c.RemoteAddr().String(),
|
|
||||||
"error", err)
|
|
||||||
return gnet.Close
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if handshake is complete
|
|
||||||
if !client.tlsBridge.IsHandshakeDone() {
|
|
||||||
// Still handshaking, wait for more data
|
|
||||||
return gnet.None
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check handshake result
|
|
||||||
_, hsErr := client.tlsBridge.HandshakeComplete()
|
|
||||||
if hsErr != nil {
|
|
||||||
s.sink.logger.Error("msg", "TLS handshake failed",
|
|
||||||
"component", "tcp_sink",
|
|
||||||
"remote_addr", c.RemoteAddr().String(),
|
|
||||||
"error", hsErr)
|
|
||||||
return gnet.Close
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set auth timeout only after TLS handshake completes
|
|
||||||
if !client.authTimeoutSet && s.sink.authenticator != nil && !client.authenticated {
|
|
||||||
client.authTimeout = time.Now().Add(30 * time.Second)
|
|
||||||
client.authTimeoutSet = true
|
|
||||||
s.sink.logger.Debug("msg", "Auth timeout started after TLS handshake",
|
|
||||||
"component", "tcp_sink",
|
"component", "tcp_sink",
|
||||||
"remote_addr", c.RemoteAddr().String())
|
"remote_addr", c.RemoteAddr().String())
|
||||||
|
return gnet.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read decrypted plaintext
|
// Read auth data
|
||||||
data = client.tlsBridge.Read()
|
data, _ := c.Next(-1)
|
||||||
if data == nil || len(data) == 0 {
|
if len(data) == 0 {
|
||||||
// No plaintext available yet
|
|
||||||
return gnet.None
|
return gnet.None
|
||||||
}
|
}
|
||||||
|
|
||||||
// First data after TLS handshake - send auth prompt if needed
|
|
||||||
if s.sink.authenticator != nil && !client.authenticated &&
|
|
||||||
len(client.buffer.Bytes()) == 0 {
|
|
||||||
authPrompt := []byte("AUTH REQUIRED\n")
|
|
||||||
client.tlsBridge.Write(authPrompt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only check auth timeout if it has been set
|
|
||||||
if !client.authenticated && client.authTimeoutSet && time.Now().After(client.authTimeout) {
|
|
||||||
s.sink.logger.Warn("msg", "Authentication timeout",
|
|
||||||
"component", "tcp_sink",
|
|
||||||
"remote_addr", c.RemoteAddr().String())
|
|
||||||
if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() {
|
|
||||||
client.tlsBridge.Write([]byte("AUTH TIMEOUT\n"))
|
|
||||||
} else if client.tlsBridge == nil {
|
|
||||||
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
|
|
||||||
}
|
|
||||||
return gnet.Close
|
|
||||||
}
|
|
||||||
|
|
||||||
// If not authenticated, expect auth command
|
|
||||||
if !client.authenticated {
|
|
||||||
client.buffer.Write(data)
|
client.buffer.Write(data)
|
||||||
|
|
||||||
// Look for complete auth line
|
// Look for complete auth line
|
||||||
if line, err := client.buffer.ReadBytes('\n'); err == nil {
|
if idx := bytes.IndexByte(client.buffer.Bytes(), '\n'); idx >= 0 {
|
||||||
line = bytes.TrimSpace(line)
|
line := client.buffer.Bytes()[:idx]
|
||||||
|
client.buffer.Next(idx + 1)
|
||||||
|
|
||||||
// Parse AUTH command: AUTH <method> <credentials>
|
// Parse AUTH command: AUTH <method> <credentials>
|
||||||
parts := strings.SplitN(string(line), " ", 3)
|
parts := strings.SplitN(string(line), " ", 3)
|
||||||
if len(parts) != 3 || parts[0] != "AUTH" {
|
if len(parts) != 3 || parts[0] != "AUTH" {
|
||||||
// Send error through TLS if enabled
|
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
|
||||||
errMsg := []byte("AUTH FAILED\n")
|
return gnet.Close
|
||||||
if client.tlsBridge != nil {
|
|
||||||
client.tlsBridge.Write(errMsg)
|
|
||||||
} else {
|
|
||||||
c.AsyncWrite(errMsg, nil)
|
|
||||||
}
|
|
||||||
return gnet.None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate
|
// Authenticate
|
||||||
@ -734,13 +568,7 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
|||||||
"remote_addr", c.RemoteAddr().String(),
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
"method", parts[1],
|
"method", parts[1],
|
||||||
"error", err)
|
"error", err)
|
||||||
// Send error through TLS if enabled
|
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
|
||||||
errMsg := []byte("AUTH FAILED\n")
|
|
||||||
if client.tlsBridge != nil {
|
|
||||||
client.tlsBridge.Write(errMsg)
|
|
||||||
} else {
|
|
||||||
c.AsyncWrite(errMsg, nil)
|
|
||||||
}
|
|
||||||
return gnet.Close
|
return gnet.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -755,35 +583,25 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
|||||||
"component", "tcp_sink",
|
"component", "tcp_sink",
|
||||||
"remote_addr", c.RemoteAddr().String(),
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
"username", session.Username,
|
"username", session.Username,
|
||||||
"method", session.Method,
|
"method", session.Method)
|
||||||
"tls", client.tlsBridge != nil)
|
|
||||||
|
|
||||||
// Send success through TLS if enabled
|
c.AsyncWrite([]byte("AUTH_OK\n"), nil)
|
||||||
successMsg := []byte("AUTH OK\n")
|
|
||||||
if client.tlsBridge != nil {
|
|
||||||
client.tlsBridge.Write(successMsg)
|
|
||||||
} else {
|
|
||||||
c.AsyncWrite(successMsg, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear buffer after auth
|
|
||||||
client.buffer.Reset()
|
client.buffer.Reset()
|
||||||
}
|
}
|
||||||
return gnet.None
|
return gnet.None
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticated clients shouldn't send data, just discard
|
// Clients shouldn't send data, just discard
|
||||||
c.Discard(-1)
|
c.Discard(-1)
|
||||||
return gnet.None
|
return gnet.None
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configures tcp sink authentication
|
// Configures tcp sink auth
|
||||||
func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
func (t *TCPSink) SetAuth(authCfg *config.AuthConfig) {
|
||||||
if authCfg == nil || authCfg.Type == "none" {
|
if authCfg == nil || authCfg.Type == "none" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
t.authConfig = authCfg
|
|
||||||
authenticator, err := auth.New(authCfg, t.logger)
|
authenticator, err := auth.New(authCfg, t.logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.logger.Error("msg", "Failed to initialize authenticator for TCP sink",
|
t.logger.Error("msg", "Failed to initialize authenticator for TCP sink",
|
||||||
@ -793,22 +611,7 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
|
|||||||
}
|
}
|
||||||
t.authenticator = authenticator
|
t.authenticator = authenticator
|
||||||
|
|
||||||
// Initialize TLS manager if TLS is configured
|
|
||||||
if t.config.TLS != nil && t.config.TLS.Enabled {
|
|
||||||
tlsManager, err := tls.NewManager(t.config.TLS, t.logger)
|
|
||||||
if err != nil {
|
|
||||||
t.logger.Error("msg", "Failed to create TLS manager",
|
|
||||||
"component", "tcp_sink",
|
|
||||||
"error", err)
|
|
||||||
// Continue without TLS
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.tlsManager = tlsManager
|
|
||||||
}
|
|
||||||
|
|
||||||
t.logger.Info("msg", "Authentication configured for TCP sink",
|
t.logger.Info("msg", "Authentication configured for TCP sink",
|
||||||
"component", "tcp_sink",
|
"component", "tcp_sink",
|
||||||
"auth_type", authCfg.Type,
|
"auth_type", authCfg.Type)
|
||||||
"tls_enabled", t.tlsManager != nil,
|
|
||||||
"tls_bridge", t.tlsManager != nil)
|
|
||||||
}
|
}
|
||||||
@ -2,41 +2,37 @@
|
|||||||
package sink
|
package sink
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"encoding/base64"
|
||||||
"crypto/x509"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"logwisp/src/internal/auth"
|
||||||
"logwisp/src/internal/config"
|
"logwisp/src/internal/config"
|
||||||
"logwisp/src/internal/core"
|
"logwisp/src/internal/core"
|
||||||
"logwisp/src/internal/format"
|
"logwisp/src/internal/format"
|
||||||
tlspkg "logwisp/src/internal/tls"
|
|
||||||
|
|
||||||
"github.com/lixenwraith/log"
|
"github.com/lixenwraith/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Forwards log entries to a remote TCP endpoint
|
// Forwards log entries to a remote TCP endpoint
|
||||||
type TCPClientSink struct {
|
type TCPClientSink struct {
|
||||||
input chan core.LogEntry
|
input chan core.LogEntry
|
||||||
config TCPClientConfig
|
config TCPClientConfig
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
connMu sync.RWMutex
|
connMu sync.RWMutex
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
formatter format.Formatter
|
formatter format.Formatter
|
||||||
|
authenticator *auth.Authenticator
|
||||||
// TLS support
|
|
||||||
tlsManager *tlspkg.Manager
|
|
||||||
tlsConfig *tls.Config
|
|
||||||
|
|
||||||
// Reconnection state
|
// Reconnection state
|
||||||
reconnecting atomic.Bool
|
reconnecting atomic.Bool
|
||||||
@ -60,6 +56,10 @@ type TCPClientConfig struct {
|
|||||||
ReadTimeout time.Duration
|
ReadTimeout time.Duration
|
||||||
KeepAlive time.Duration
|
KeepAlive time.Duration
|
||||||
|
|
||||||
|
// Security
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
|
||||||
// Reconnection settings
|
// Reconnection settings
|
||||||
ReconnectDelay time.Duration
|
ReconnectDelay time.Duration
|
||||||
MaxReconnectDelay time.Duration
|
MaxReconnectDelay time.Duration
|
||||||
@ -120,27 +120,11 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
|
|||||||
if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 {
|
if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 {
|
||||||
cfg.ReconnectBackoff = backoff
|
cfg.ReconnectBackoff = backoff
|
||||||
}
|
}
|
||||||
|
if username, ok := options["username"].(string); ok {
|
||||||
// Extract TLS config
|
cfg.Username = username
|
||||||
if tc, ok := options["tls"].(map[string]any); ok {
|
}
|
||||||
cfg.TLS = &config.TLSConfig{}
|
if password, ok := options["password"].(string); ok {
|
||||||
cfg.TLS.Enabled, _ = tc["enabled"].(bool)
|
cfg.Password = password
|
||||||
if certFile, ok := tc["cert_file"].(string); ok {
|
|
||||||
cfg.TLS.CertFile = certFile
|
|
||||||
}
|
|
||||||
if keyFile, ok := tc["key_file"].(string); ok {
|
|
||||||
cfg.TLS.KeyFile = keyFile
|
|
||||||
}
|
|
||||||
cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool)
|
|
||||||
if caFile, ok := tc["client_ca_file"].(string); ok {
|
|
||||||
cfg.TLS.ClientCAFile = caFile
|
|
||||||
}
|
|
||||||
if insecure, ok := tc["insecure_skip_verify"].(bool); ok {
|
|
||||||
cfg.TLS.InsecureSkipVerify = insecure
|
|
||||||
}
|
|
||||||
if caFile, ok := tc["ca_file"].(string); ok {
|
|
||||||
cfg.TLS.CAFile = caFile
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &TCPClientSink{
|
t := &TCPClientSink{
|
||||||
@ -154,62 +138,6 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
|
|||||||
t.lastProcessed.Store(time.Time{})
|
t.lastProcessed.Store(time.Time{})
|
||||||
t.connectionUptime.Store(time.Duration(0))
|
t.connectionUptime.Store(time.Duration(0))
|
||||||
|
|
||||||
// Initialize TLS manager if TLS is configured
|
|
||||||
if cfg.TLS != nil && cfg.TLS.Enabled {
|
|
||||||
// Build custom TLS config for client
|
|
||||||
t.tlsConfig = &tls.Config{
|
|
||||||
InsecureSkipVerify: cfg.TLS.InsecureSkipVerify,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract server name from address for SNI
|
|
||||||
host, _, err := net.SplitHostPort(cfg.Address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse address for SNI: %w", err)
|
|
||||||
}
|
|
||||||
t.tlsConfig.ServerName = host
|
|
||||||
|
|
||||||
// Load custom CA for server verification
|
|
||||||
if cfg.TLS.CAFile != "" {
|
|
||||||
caCert, err := os.ReadFile(cfg.TLS.CAFile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read CA file '%s': %w", cfg.TLS.CAFile, err)
|
|
||||||
}
|
|
||||||
caCertPool := x509.NewCertPool()
|
|
||||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
|
||||||
return nil, fmt.Errorf("failed to parse CA certificate from '%s'", cfg.TLS.CAFile)
|
|
||||||
}
|
|
||||||
t.tlsConfig.RootCAs = caCertPool
|
|
||||||
logger.Debug("msg", "Custom CA loaded for server verification",
|
|
||||||
"component", "tcp_client_sink",
|
|
||||||
"ca_file", cfg.TLS.CAFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load client certificate for mTLS
|
|
||||||
if cfg.TLS.CertFile != "" && cfg.TLS.KeyFile != "" {
|
|
||||||
cert, err := tls.LoadX509KeyPair(cfg.TLS.CertFile, cfg.TLS.KeyFile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load client certificate: %w", err)
|
|
||||||
}
|
|
||||||
t.tlsConfig.Certificates = []tls.Certificate{cert}
|
|
||||||
logger.Info("msg", "Client certificate loaded for mTLS",
|
|
||||||
"component", "tcp_client_sink",
|
|
||||||
"cert_file", cfg.TLS.CertFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set minimum TLS version if configured
|
|
||||||
if cfg.TLS.MinVersion != "" {
|
|
||||||
t.tlsConfig.MinVersion = parseTLSVersion(cfg.TLS.MinVersion, tls.VersionTLS12)
|
|
||||||
} else {
|
|
||||||
t.tlsConfig.MinVersion = tls.VersionTLS12 // Default minimum
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("msg", "TLS enabled for TCP client",
|
|
||||||
"component", "tcp_client_sink",
|
|
||||||
"address", cfg.Address,
|
|
||||||
"server_name", host,
|
|
||||||
"insecure", cfg.TLS.InsecureSkipVerify,
|
|
||||||
"mtls", cfg.TLS.CertFile != "")
|
|
||||||
}
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -376,33 +304,44 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
|
|||||||
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
|
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap with TLS if configured
|
// Handle authentication if credentials configured
|
||||||
if t.tlsConfig != nil {
|
if t.config.Username != "" && t.config.Password != "" {
|
||||||
t.logger.Debug("msg", "Initiating TLS handshake",
|
// Read auth challenge
|
||||||
"component", "tcp_client_sink",
|
reader := bufio.NewReader(conn)
|
||||||
"address", t.config.Address)
|
challenge, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
tlsConn := tls.Client(conn, t.tlsConfig)
|
|
||||||
|
|
||||||
// Perform handshake with timeout
|
|
||||||
handshakeCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := tlsConn.HandshakeContext(handshakeCtx); err != nil {
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
return nil, fmt.Errorf("failed to read auth challenge: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log connection details
|
if strings.TrimSpace(challenge) == "AUTH_REQUIRED" {
|
||||||
state := tlsConn.ConnectionState()
|
// Send credentials
|
||||||
t.logger.Info("msg", "TLS connection established",
|
creds := t.config.Username + ":" + t.config.Password
|
||||||
"component", "tcp_client_sink",
|
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
|
||||||
"address", t.config.Address,
|
authCmd := fmt.Sprintf("AUTH basic %s\n", encodedCreds)
|
||||||
"tls_version", tlsVersionString(state.Version),
|
|
||||||
"cipher_suite", tls.CipherSuiteName(state.CipherSuite),
|
|
||||||
"server_name", state.ServerName)
|
|
||||||
|
|
||||||
return tlsConn, nil
|
if _, err := conn.Write([]byte(authCmd)); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("failed to send auth: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
response, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("failed to read auth response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(response) != "AUTH_OK" {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("authentication failed: %s", response)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.logger.Debug("msg", "TCP authentication successful",
|
||||||
|
"component", "tcp_client_sink",
|
||||||
|
"address", t.config.Address,
|
||||||
|
"username", t.config.Username)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
@ -504,34 +443,8 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns human-readable TLS version
|
// Not applicable, Clients authenticate to remote servers using Username/Password in config
|
||||||
func tlsVersionString(version uint16) string {
|
func (h *TCPClientSink) SetAuth(authCfg *config.AuthConfig) {
|
||||||
switch version {
|
// No-op: client sinks don't validate incoming connections
|
||||||
case tls.VersionTLS10:
|
// They authenticate to remote servers using Username/Password fields
|
||||||
return "TLS1.0"
|
|
||||||
case tls.VersionTLS11:
|
|
||||||
return "TLS1.1"
|
|
||||||
case tls.VersionTLS12:
|
|
||||||
return "TLS1.2"
|
|
||||||
case tls.VersionTLS13:
|
|
||||||
return "TLS1.3"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("0x%04x", version)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Converts string to TLS version constant
|
|
||||||
func parseTLSVersion(version string, defaultVersion uint16) uint16 {
|
|
||||||
switch strings.ToUpper(version) {
|
|
||||||
case "TLS1.0", "TLS10":
|
|
||||||
return tls.VersionTLS10
|
|
||||||
case "TLS1.1", "TLS11":
|
|
||||||
return tls.VersionTLS11
|
|
||||||
case "TLS1.2", "TLS12":
|
|
||||||
return tls.VersionTLS12
|
|
||||||
case "TLS1.3", "TLS13":
|
|
||||||
return tls.VersionTLS13
|
|
||||||
default:
|
|
||||||
return defaultVersion
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
@ -13,6 +13,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"logwisp/src/internal/config"
|
||||||
"logwisp/src/internal/core"
|
"logwisp/src/internal/core"
|
||||||
|
|
||||||
"github.com/lixenwraith/log"
|
"github.com/lixenwraith/log"
|
||||||
@ -287,3 +288,7 @@ func globToRegex(glob string) string {
|
|||||||
regex = strings.ReplaceAll(regex, `\?`, `.`)
|
regex = strings.ReplaceAll(regex, `\?`, `.`)
|
||||||
return "^" + regex + "$"
|
return "^" + regex + "$"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ds *DirectorySource) SetAuth(auth *config.AuthConfig) {
|
||||||
|
// Authentication does not apply to directory source
|
||||||
|
}
|
||||||
@ -4,6 +4,7 @@ package source
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"logwisp/src/internal/auth"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -20,21 +21,31 @@ import (
|
|||||||
|
|
||||||
// Receives log entries via HTTP POST requests
|
// Receives log entries via HTTP POST requests
|
||||||
type HTTPSource struct {
|
type HTTPSource struct {
|
||||||
host string
|
// Config
|
||||||
port int64
|
host string
|
||||||
path string
|
port int64
|
||||||
bufferSize int64
|
path string
|
||||||
|
bufferSize int64
|
||||||
|
maxRequestBodySize int64
|
||||||
|
|
||||||
|
// Application
|
||||||
server *fasthttp.Server
|
server *fasthttp.Server
|
||||||
subscribers []chan core.LogEntry
|
subscribers []chan core.LogEntry
|
||||||
mu sync.RWMutex
|
|
||||||
done chan struct{}
|
|
||||||
wg sync.WaitGroup
|
|
||||||
netLimiter *limit.NetLimiter
|
netLimiter *limit.NetLimiter
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
|
|
||||||
// Add TLS support
|
// Runtime
|
||||||
tlsManager *tls.Manager
|
mu sync.RWMutex
|
||||||
tlsConfig *config.TLSConfig
|
done chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Security
|
||||||
|
authenticator *auth.Authenticator
|
||||||
|
authConfig *config.AuthConfig
|
||||||
|
authFailures atomic.Uint64
|
||||||
|
authSuccesses atomic.Uint64
|
||||||
|
tlsManager *tls.Manager
|
||||||
|
tlsConfig *config.TLSConfig
|
||||||
|
|
||||||
// Statistics
|
// Statistics
|
||||||
totalEntries atomic.Uint64
|
totalEntries atomic.Uint64
|
||||||
@ -66,42 +77,54 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err
|
|||||||
bufferSize = bufSize
|
bufferSize = bufSize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
maxRequestBodySize := int64(10 * 1024 * 1024) // fasthttp default 10MB
|
||||||
|
if maxBodySize, ok := options["max_body_size"].(int64); ok && maxBodySize > 0 && maxBodySize < maxRequestBodySize {
|
||||||
|
maxRequestBodySize = maxBodySize
|
||||||
|
}
|
||||||
|
|
||||||
h := &HTTPSource{
|
h := &HTTPSource{
|
||||||
host: host,
|
host: host,
|
||||||
port: port,
|
port: port,
|
||||||
path: ingestPath,
|
path: ingestPath,
|
||||||
bufferSize: bufferSize,
|
bufferSize: bufferSize,
|
||||||
done: make(chan struct{}),
|
maxRequestBodySize: maxRequestBodySize,
|
||||||
startTime: time.Now(),
|
done: make(chan struct{}),
|
||||||
logger: logger,
|
startTime: time.Now(),
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
h.lastEntryTime.Store(time.Time{})
|
h.lastEntryTime.Store(time.Time{})
|
||||||
|
|
||||||
// Initialize net limiter if configured
|
// Initialize net limiter if configured
|
||||||
if rl, ok := options["net_limit"].(map[string]any); ok {
|
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||||
if enabled, _ := rl["enabled"].(bool); enabled {
|
if enabled, _ := nl["enabled"].(bool); enabled {
|
||||||
cfg := config.NetLimitConfig{
|
cfg := config.NetLimitConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rps, ok := rl["requests_per_second"].(float64); ok {
|
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||||
cfg.RequestsPerSecond = rps
|
cfg.RequestsPerSecond = rps
|
||||||
}
|
}
|
||||||
if burst, ok := rl["burst_size"].(int64); ok {
|
if burst, ok := nl["burst_size"].(int64); ok {
|
||||||
cfg.BurstSize = burst
|
cfg.BurstSize = burst
|
||||||
}
|
}
|
||||||
if limitBy, ok := rl["limit_by"].(string); ok {
|
if respCode, ok := nl["response_code"].(int64); ok {
|
||||||
cfg.LimitBy = limitBy
|
|
||||||
}
|
|
||||||
if respCode, ok := rl["response_code"].(int64); ok {
|
|
||||||
cfg.ResponseCode = respCode
|
cfg.ResponseCode = respCode
|
||||||
}
|
}
|
||||||
if msg, ok := rl["response_message"].(string); ok {
|
if msg, ok := nl["response_message"].(string); ok {
|
||||||
cfg.ResponseMessage = msg
|
cfg.ResponseMessage = msg
|
||||||
}
|
}
|
||||||
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
|
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||||
cfg.MaxConnectionsPerIP = maxPerIP
|
cfg.MaxConnectionsPerIP = maxPerIP
|
||||||
}
|
}
|
||||||
|
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
|
||||||
|
cfg.MaxConnectionsPerUser = maxPerUser
|
||||||
|
}
|
||||||
|
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
|
||||||
|
cfg.MaxConnectionsPerToken = maxPerToken
|
||||||
|
}
|
||||||
|
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||||
|
cfg.MaxConnectionsTotal = maxTotal
|
||||||
|
}
|
||||||
|
|
||||||
h.netLimiter = limit.NewNetLimiter(cfg, logger)
|
h.netLimiter = limit.NewNetLimiter(cfg, logger)
|
||||||
}
|
}
|
||||||
@ -157,10 +180,11 @@ func (h *HTTPSource) Subscribe() <-chan core.LogEntry {
|
|||||||
|
|
||||||
func (h *HTTPSource) Start() error {
|
func (h *HTTPSource) Start() error {
|
||||||
h.server = &fasthttp.Server{
|
h.server = &fasthttp.Server{
|
||||||
Handler: h.requestHandler,
|
Handler: h.requestHandler,
|
||||||
DisableKeepalive: false,
|
DisableKeepalive: false,
|
||||||
StreamRequestBody: true,
|
StreamRequestBody: true,
|
||||||
CloseOnShutdown: true,
|
CloseOnShutdown: true,
|
||||||
|
MaxRequestBodySize: int(h.maxRequestBodySize),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use configured host and port
|
// Use configured host and port
|
||||||
@ -259,19 +283,9 @@ func (h *HTTPSource) GetStats() SourceStats {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||||
// Only handle POST to the configured ingest path
|
|
||||||
if string(ctx.Method()) != "POST" || string(ctx.Path()) != h.path {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
|
||||||
ctx.SetContentType("application/json")
|
|
||||||
json.NewEncoder(ctx).Encode(map[string]string{
|
|
||||||
"error": "Not Found",
|
|
||||||
"hint": fmt.Sprintf("POST logs to %s", h.path),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract and validate IP
|
|
||||||
remoteAddr := ctx.RemoteAddr().String()
|
remoteAddr := ctx.RemoteAddr().String()
|
||||||
|
|
||||||
|
// 1. IPv6 check (early reject)
|
||||||
ipStr, _, err := net.SplitHostPort(remoteAddr)
|
ipStr, _, err := net.SplitHostPort(remoteAddr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if ip := net.ParseIP(ipStr); ip != nil && ip.To4() == nil {
|
if ip := net.ParseIP(ipStr); ip != nil && ip.To4() == nil {
|
||||||
@ -284,7 +298,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check net limit
|
// 2. Net limit check (early reject)
|
||||||
if h.netLimiter != nil {
|
if h.netLimiter != nil {
|
||||||
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
|
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
|
||||||
ctx.SetStatusCode(int(statusCode))
|
ctx.SetStatusCode(int(statusCode))
|
||||||
@ -297,7 +311,65 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the request body
|
// 3. Path check (only process ingest path)
|
||||||
|
path := string(ctx.Path())
|
||||||
|
if path != h.path {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||||
|
ctx.SetContentType("application/json")
|
||||||
|
json.NewEncoder(ctx).Encode(map[string]string{
|
||||||
|
"error": "Not Found",
|
||||||
|
"hint": fmt.Sprintf("POST logs to %s", h.path),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Method check (only accept POST)
|
||||||
|
if string(ctx.Method()) != "POST" {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
|
||||||
|
ctx.SetContentType("application/json")
|
||||||
|
ctx.Response.Header.Set("Allow", "POST")
|
||||||
|
json.NewEncoder(ctx).Encode(map[string]string{
|
||||||
|
"error": "Method not allowed",
|
||||||
|
"hint": "Use POST to submit logs",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Authentication check (if configured)
|
||||||
|
if h.authenticator != nil {
|
||||||
|
authHeader := string(ctx.Request.Header.Peek("Authorization"))
|
||||||
|
session, err := h.authenticator.AuthenticateHTTP(authHeader, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
h.authFailures.Add(1)
|
||||||
|
h.logger.Warn("msg", "Authentication failed",
|
||||||
|
"component", "http_source",
|
||||||
|
"remote_addr", remoteAddr,
|
||||||
|
"error", err)
|
||||||
|
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
|
||||||
|
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
|
||||||
|
realm := h.authConfig.BasicAuth.Realm
|
||||||
|
if realm == "" {
|
||||||
|
realm = "Restricted"
|
||||||
|
}
|
||||||
|
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm))
|
||||||
|
} else if h.authConfig.Type == "bearer" {
|
||||||
|
ctx.Response.Header.Set("WWW-Authenticate", "Bearer")
|
||||||
|
}
|
||||||
|
ctx.SetContentType("application/json")
|
||||||
|
json.NewEncoder(ctx).Encode(map[string]string{
|
||||||
|
"error": "Unauthorized",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.authSuccesses.Add(1)
|
||||||
|
h.logger.Debug("msg", "Request authenticated",
|
||||||
|
"component", "http_source",
|
||||||
|
"remote_addr", remoteAddr,
|
||||||
|
"username", session.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Process request body
|
||||||
body := ctx.PostBody()
|
body := ctx.PostBody()
|
||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||||
@ -308,7 +380,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the log entries
|
// 7. Parse log entries
|
||||||
entries, err := h.parseEntries(body)
|
entries, err := h.parseEntries(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.invalidEntries.Add(1)
|
h.invalidEntries.Add(1)
|
||||||
@ -320,7 +392,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish entries
|
// 8. Publish entries to subscribers
|
||||||
accepted := 0
|
accepted := 0
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
if h.publish(entry) {
|
if h.publish(entry) {
|
||||||
@ -328,7 +400,7 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return success response
|
// 9. Return success response
|
||||||
ctx.SetStatusCode(fasthttp.StatusAccepted)
|
ctx.SetStatusCode(fasthttp.StatusAccepted)
|
||||||
ctx.SetContentType("application/json")
|
ctx.SetContentType("application/json")
|
||||||
json.NewEncoder(ctx).Encode(map[string]any{
|
json.NewEncoder(ctx).Encode(map[string]any{
|
||||||
@ -461,3 +533,24 @@ func splitLines(data []byte) [][]byte {
|
|||||||
|
|
||||||
return lines
|
return lines
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure HTTP source auth
|
||||||
|
func (h *HTTPSource) SetAuth(authCfg *config.AuthConfig) {
|
||||||
|
if authCfg == nil || authCfg.Type == "none" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.authConfig = authCfg
|
||||||
|
authenticator, err := auth.New(authCfg, h.logger)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("msg", "Failed to initialize authenticator for HTTP source",
|
||||||
|
"component", "http_source",
|
||||||
|
"error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.authenticator = authenticator
|
||||||
|
|
||||||
|
h.logger.Info("msg", "Authentication configured for HTTP source",
|
||||||
|
"component", "http_source",
|
||||||
|
"auth_type", authCfg.Type)
|
||||||
|
}
|
||||||
@ -4,22 +4,26 @@ package source
|
|||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"logwisp/src/internal/config"
|
||||||
"logwisp/src/internal/core"
|
"logwisp/src/internal/core"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Represents an input data stream
|
// Represents an input data stream
|
||||||
type Source interface {
|
type Source interface {
|
||||||
// Subscribe returns a channel that receives log entries
|
// Returns a channel that receives log entries
|
||||||
Subscribe() <-chan core.LogEntry
|
Subscribe() <-chan core.LogEntry
|
||||||
|
|
||||||
// Start begins reading from the source
|
// Begins reading from the source
|
||||||
Start() error
|
Start() error
|
||||||
|
|
||||||
// Stop gracefully shuts down the source
|
// Gracefully shuts down the source
|
||||||
Stop()
|
Stop()
|
||||||
|
|
||||||
// GetStats returns source statistics
|
// Returns source statistics
|
||||||
GetStats() SourceStats
|
GetStats() SourceStats
|
||||||
|
|
||||||
|
// Configure authentication
|
||||||
|
SetAuth(auth *config.AuthConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Contains statistics about a source
|
// Contains statistics about a source
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"logwisp/src/internal/config"
|
||||||
"logwisp/src/internal/core"
|
"logwisp/src/internal/core"
|
||||||
|
|
||||||
"github.com/lixenwraith/log"
|
"github.com/lixenwraith/log"
|
||||||
@ -119,3 +120,7 @@ func (s *StdinSource) publish(entry core.LogEntry) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StdinSource) SetAuth(auth *config.AuthConfig) {
|
||||||
|
// Authentication does not apply to stdin source
|
||||||
|
}
|
||||||
@ -5,9 +5,9 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@ -16,7 +16,6 @@ import (
|
|||||||
"logwisp/src/internal/config"
|
"logwisp/src/internal/config"
|
||||||
"logwisp/src/internal/core"
|
"logwisp/src/internal/core"
|
||||||
"logwisp/src/internal/limit"
|
"logwisp/src/internal/limit"
|
||||||
"logwisp/src/internal/tls"
|
|
||||||
|
|
||||||
"github.com/lixenwraith/log"
|
"github.com/lixenwraith/log"
|
||||||
"github.com/lixenwraith/log/compat"
|
"github.com/lixenwraith/log/compat"
|
||||||
@ -24,28 +23,25 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxClientBufferSize = 10 * 1024 * 1024 // 10MB max per client
|
maxClientBufferSize = 10 * 1024 * 1024 // 10MB max per client
|
||||||
maxLineLength = 1 * 1024 * 1024 // 1MB max per log line
|
maxLineLength = 1 * 1024 * 1024 // 1MB max per log line
|
||||||
maxEncryptedDataPerRead = 1 * 1024 * 1024 // 1MB max encrypted data per read
|
|
||||||
maxCumulativeEncrypted = 20 * 1024 * 1024 // 20MB total encrypted before processing
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Receives log entries via TCP connections
|
// Receives log entries via TCP connections
|
||||||
type TCPSource struct {
|
type TCPSource struct {
|
||||||
host string
|
host string
|
||||||
port int64
|
port int64
|
||||||
bufferSize int64
|
bufferSize int64
|
||||||
server *tcpSourceServer
|
server *tcpSourceServer
|
||||||
subscribers []chan core.LogEntry
|
subscribers []chan core.LogEntry
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
engine *gnet.Engine
|
engine *gnet.Engine
|
||||||
engineMu sync.Mutex
|
engineMu sync.Mutex
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
netLimiter *limit.NetLimiter
|
netLimiter *limit.NetLimiter
|
||||||
tlsManager *tls.Manager
|
logger *log.Logger
|
||||||
tlsConfig *config.TLSConfig
|
authenticator *auth.Authenticator
|
||||||
logger *log.Logger
|
|
||||||
|
|
||||||
// Statistics
|
// Statistics
|
||||||
totalEntries atomic.Uint64
|
totalEntries atomic.Uint64
|
||||||
@ -54,6 +50,8 @@ type TCPSource struct {
|
|||||||
activeConns atomic.Int64
|
activeConns atomic.Int64
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
lastEntryTime atomic.Value // time.Time
|
lastEntryTime atomic.Value // time.Time
|
||||||
|
authFailures atomic.Uint64
|
||||||
|
authSuccesses atomic.Uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a new TCP server source
|
// Creates a new TCP server source
|
||||||
@ -84,58 +82,35 @@ func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error
|
|||||||
t.lastEntryTime.Store(time.Time{})
|
t.lastEntryTime.Store(time.Time{})
|
||||||
|
|
||||||
// Initialize net limiter if configured
|
// Initialize net limiter if configured
|
||||||
if rl, ok := options["net_limit"].(map[string]any); ok {
|
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||||
if enabled, _ := rl["enabled"].(bool); enabled {
|
if enabled, _ := nl["enabled"].(bool); enabled {
|
||||||
cfg := config.NetLimitConfig{
|
cfg := config.NetLimitConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rps, ok := rl["requests_per_second"].(float64); ok {
|
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||||
cfg.RequestsPerSecond = rps
|
cfg.RequestsPerSecond = rps
|
||||||
}
|
}
|
||||||
if burst, ok := rl["burst_size"].(int64); ok {
|
if burst, ok := nl["burst_size"].(int64); ok {
|
||||||
cfg.BurstSize = burst
|
cfg.BurstSize = burst
|
||||||
}
|
}
|
||||||
if limitBy, ok := rl["limit_by"].(string); ok {
|
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||||
cfg.LimitBy = limitBy
|
|
||||||
}
|
|
||||||
if maxPerIP, ok := rl["max_connections_per_ip"].(int64); ok {
|
|
||||||
cfg.MaxConnectionsPerIP = maxPerIP
|
cfg.MaxConnectionsPerIP = maxPerIP
|
||||||
}
|
}
|
||||||
if maxTotal, ok := rl["max_total_connections"].(int64); ok {
|
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
|
||||||
cfg.MaxTotalConnections = maxTotal
|
cfg.MaxConnectionsPerUser = maxPerUser
|
||||||
|
}
|
||||||
|
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
|
||||||
|
cfg.MaxConnectionsPerToken = maxPerToken
|
||||||
|
}
|
||||||
|
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||||
|
cfg.MaxConnectionsTotal = maxTotal
|
||||||
}
|
}
|
||||||
|
|
||||||
t.netLimiter = limit.NewNetLimiter(cfg, logger)
|
t.netLimiter = limit.NewNetLimiter(cfg, logger)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract TLS config and initialize TLS manager
|
|
||||||
if tc, ok := options["tls"].(map[string]any); ok {
|
|
||||||
t.tlsConfig = &config.TLSConfig{}
|
|
||||||
t.tlsConfig.Enabled, _ = tc["enabled"].(bool)
|
|
||||||
if certFile, ok := tc["cert_file"].(string); ok {
|
|
||||||
t.tlsConfig.CertFile = certFile
|
|
||||||
}
|
|
||||||
if keyFile, ok := tc["key_file"].(string); ok {
|
|
||||||
t.tlsConfig.KeyFile = keyFile
|
|
||||||
}
|
|
||||||
t.tlsConfig.ClientAuth, _ = tc["client_auth"].(bool)
|
|
||||||
if caFile, ok := tc["client_ca_file"].(string); ok {
|
|
||||||
t.tlsConfig.ClientCAFile = caFile
|
|
||||||
}
|
|
||||||
t.tlsConfig.VerifyClientCert, _ = tc["verify_client_cert"].(bool)
|
|
||||||
|
|
||||||
// Create TLS manager if enabled
|
|
||||||
if t.tlsConfig.Enabled {
|
|
||||||
tlsManager, err := tls.NewManager(t.tlsConfig, logger)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
|
|
||||||
}
|
|
||||||
t.tlsManager = tlsManager
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -167,8 +142,7 @@ func (t *TCPSource) Start() error {
|
|||||||
defer t.wg.Done()
|
defer t.wg.Done()
|
||||||
t.logger.Info("msg", "TCP source server starting",
|
t.logger.Info("msg", "TCP source server starting",
|
||||||
"component", "tcp_source",
|
"component", "tcp_source",
|
||||||
"port", t.port,
|
"port", t.port)
|
||||||
"tls_enabled", t.tlsManager != nil)
|
|
||||||
|
|
||||||
err := gnet.Run(t.server, addr,
|
err := gnet.Run(t.server, addr,
|
||||||
gnet.WithLogger(gnetLogger),
|
gnet.WithLogger(gnetLogger),
|
||||||
@ -283,9 +257,8 @@ type tcpClient struct {
|
|||||||
conn gnet.Conn
|
conn gnet.Conn
|
||||||
buffer bytes.Buffer
|
buffer bytes.Buffer
|
||||||
authenticated bool
|
authenticated bool
|
||||||
session *auth.Session
|
|
||||||
authTimeout time.Time
|
authTimeout time.Time
|
||||||
tlsBridge *tls.GNetTLSConn
|
session *auth.Session
|
||||||
maxBufferSeen int
|
maxBufferSeen int
|
||||||
cumulativeEncrypted int64
|
cumulativeEncrypted int64
|
||||||
}
|
}
|
||||||
@ -339,22 +312,17 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create client state
|
// Create client state
|
||||||
client := &tcpClient{conn: c}
|
client := &tcpClient{
|
||||||
|
conn: c,
|
||||||
// Initialize TLS bridge if enabled
|
authenticated: s.source.authenticator == nil,
|
||||||
if s.source.tlsManager != nil {
|
}
|
||||||
tlsConfig := s.source.tlsManager.GetTCPConfig()
|
|
||||||
client.tlsBridge = tls.NewServerConn(c, tlsConfig)
|
if s.source.authenticator != nil {
|
||||||
client.tlsBridge.Handshake() // Start async handshake
|
client.authTimeout = time.Now().Add(30 * time.Second)
|
||||||
|
|
||||||
s.source.logger.Debug("msg", "TLS handshake initiated",
|
|
||||||
"component", "tcp_source",
|
|
||||||
"remote_addr", remoteAddr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create client state
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.clients[c] = &tcpClient{conn: c}
|
s.clients[c] = client
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
newCount := s.source.activeConns.Add(1)
|
newCount := s.source.activeConns.Add(1)
|
||||||
@ -362,7 +330,12 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
|||||||
"component", "tcp_source",
|
"component", "tcp_source",
|
||||||
"remote_addr", remoteAddr,
|
"remote_addr", remoteAddr,
|
||||||
"active_connections", newCount,
|
"active_connections", newCount,
|
||||||
"tls_enabled", s.source.tlsManager != nil)
|
"requires_auth", s.source.authenticator != nil)
|
||||||
|
|
||||||
|
// Send auth challenge if required
|
||||||
|
if s.source.authenticator != nil {
|
||||||
|
return []byte("AUTH_REQUIRED\n"), gnet.None
|
||||||
|
}
|
||||||
|
|
||||||
return nil, gnet.None
|
return nil, gnet.None
|
||||||
}
|
}
|
||||||
@ -372,18 +345,9 @@ func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
|||||||
|
|
||||||
// Remove client state
|
// Remove client state
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
client := s.clients[c]
|
|
||||||
delete(s.clients, c)
|
delete(s.clients, c)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
// Clean up TLS bridge if present
|
|
||||||
if client != nil && client.tlsBridge != nil {
|
|
||||||
client.tlsBridge.Close()
|
|
||||||
s.source.logger.Debug("msg", "TLS connection closed",
|
|
||||||
"component", "tcp_source",
|
|
||||||
"remote_addr", remoteAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove connection tracking
|
// Remove connection tracking
|
||||||
if s.source.netLimiter != nil {
|
if s.source.netLimiter != nil {
|
||||||
s.source.netLimiter.RemoveConnection(remoteAddr)
|
s.source.netLimiter.RemoveConnection(remoteAddr)
|
||||||
@ -416,79 +380,64 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
|||||||
return gnet.Close
|
return gnet.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check encrypted data size BEFORE processing through TLS
|
// Authentication phase
|
||||||
if len(data) > maxEncryptedDataPerRead {
|
if !client.authenticated {
|
||||||
s.source.logger.Warn("msg", "Encrypted data per read limit exceeded",
|
if time.Now().After(client.authTimeout) {
|
||||||
"component", "tcp_source",
|
s.source.logger.Warn("msg", "Authentication timeout",
|
||||||
"remote_addr", c.RemoteAddr().String(),
|
"component", "tcp_source",
|
||||||
"data_size", len(data),
|
"remote_addr", c.RemoteAddr().String())
|
||||||
"limit", maxEncryptedDataPerRead)
|
return gnet.Close
|
||||||
s.source.invalidEntries.Add(1)
|
}
|
||||||
return gnet.Close
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track cumulative encrypted data to prevent slow accumulation
|
client.buffer.Write(data)
|
||||||
client.cumulativeEncrypted += int64(len(data))
|
|
||||||
if client.cumulativeEncrypted > maxCumulativeEncrypted {
|
|
||||||
s.source.logger.Warn("msg", "Cumulative encrypted data limit exceeded",
|
|
||||||
"component", "tcp_source",
|
|
||||||
"remote_addr", c.RemoteAddr().String(),
|
|
||||||
"total_encrypted", client.cumulativeEncrypted,
|
|
||||||
"limit", maxCumulativeEncrypted)
|
|
||||||
s.source.invalidEntries.Add(1)
|
|
||||||
return gnet.Close
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process through TLS bridge if present
|
// Look for auth line
|
||||||
if client.tlsBridge != nil {
|
if idx := bytes.IndexByte(client.buffer.Bytes(), '\n'); idx >= 0 {
|
||||||
// Feed encrypted data into TLS engine
|
line := client.buffer.Bytes()[:idx]
|
||||||
if err := client.tlsBridge.ProcessIncoming(data); err != nil {
|
client.buffer.Next(idx + 1)
|
||||||
if errors.Is(err, tls.ErrTLSBackpressure) {
|
|
||||||
s.source.logger.Warn("msg", "TLS backpressure, closing slow client",
|
parts := strings.SplitN(string(line), " ", 3)
|
||||||
"component", "tcp_source",
|
if len(parts) != 3 || parts[0] != "AUTH" {
|
||||||
"remote_addr", c.RemoteAddr().String())
|
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
|
||||||
} else {
|
return gnet.Close
|
||||||
s.source.logger.Error("msg", "TLS processing error",
|
}
|
||||||
|
|
||||||
|
session, err := s.source.authenticator.AuthenticateTCP(parts[1], parts[2], c.RemoteAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
s.source.authFailures.Add(1)
|
||||||
|
s.source.logger.Warn("msg", "Authentication failed",
|
||||||
"component", "tcp_source",
|
"component", "tcp_source",
|
||||||
"remote_addr", c.RemoteAddr().String(),
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
"error", err)
|
"error", err)
|
||||||
|
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
|
||||||
|
return gnet.Close
|
||||||
}
|
}
|
||||||
return gnet.Close
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if handshake is complete
|
s.source.authSuccesses.Add(1)
|
||||||
if !client.tlsBridge.IsHandshakeDone() {
|
s.mu.Lock()
|
||||||
// Still handshaking, wait for more data
|
client.authenticated = true
|
||||||
return gnet.None
|
client.session = session
|
||||||
}
|
s.mu.Unlock()
|
||||||
|
|
||||||
// Check handshake result
|
s.source.logger.Info("msg", "TCP client authenticated",
|
||||||
_, hsErr := client.tlsBridge.HandshakeComplete()
|
|
||||||
if hsErr != nil {
|
|
||||||
s.source.logger.Error("msg", "TLS handshake failed",
|
|
||||||
"component", "tcp_source",
|
"component", "tcp_source",
|
||||||
"remote_addr", c.RemoteAddr().String(),
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
"error", hsErr)
|
"username", session.Username)
|
||||||
return gnet.Close
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read decrypted plaintext
|
c.AsyncWrite([]byte("AUTH_OK\n"), nil)
|
||||||
data = client.tlsBridge.Read()
|
client.buffer.Reset()
|
||||||
if data == nil || len(data) == 0 {
|
|
||||||
// No plaintext available yet
|
|
||||||
return gnet.None
|
|
||||||
}
|
}
|
||||||
// Reset cumulative counter after successful decryption and processing
|
return gnet.None
|
||||||
client.cumulativeEncrypted = 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check buffer size before appending
|
// Check if appending the new data would exceed the client buffer limit.
|
||||||
if client.buffer.Len()+len(data) > maxClientBufferSize {
|
if client.buffer.Len()+len(data) > maxClientBufferSize {
|
||||||
s.source.logger.Warn("msg", "Client buffer limit exceeded",
|
s.source.logger.Warn("msg", "Client buffer limit exceeded, closing connection.",
|
||||||
"component", "tcp_source",
|
"component", "tcp_source",
|
||||||
"remote_addr", c.RemoteAddr().String(),
|
"remote_addr", c.RemoteAddr().String(),
|
||||||
"buffer_size", client.buffer.Len(),
|
"buffer_size", client.buffer.Len(),
|
||||||
"incoming_size", len(data))
|
"incoming_size", len(data),
|
||||||
|
"limit", maxClientBufferSize)
|
||||||
s.source.invalidEntries.Add(1)
|
s.source.invalidEntries.Add(1)
|
||||||
return gnet.Close
|
return gnet.Close
|
||||||
}
|
}
|
||||||
@ -573,12 +522,22 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
|||||||
return gnet.None
|
return gnet.None
|
||||||
}
|
}
|
||||||
|
|
||||||
// noopLogger implements gnet's Logger interface but discards everything
|
// Configure TCP source auth
|
||||||
// type noopLogger struct{}
|
func (t *TCPSource) SetAuth(authCfg *config.AuthConfig) {
|
||||||
// func (n noopLogger) Debugf(format string, args ...any) {}
|
if authCfg == nil || authCfg.Type == "none" {
|
||||||
// func (n noopLogger) Infof(format string, args ...any) {}
|
return
|
||||||
// func (n noopLogger) Warnf(format string, args ...any) {}
|
}
|
||||||
// func (n noopLogger) Errorf(format string, args ...any) {}
|
|
||||||
// func (n noopLogger) Fatalf(format string, args ...any) {}
|
|
||||||
|
|
||||||
// Usage: gnet.Run(..., gnet.WithLogger(noopLogger{}), ...)
|
authenticator, err := auth.New(authCfg, t.logger)
|
||||||
|
if err != nil {
|
||||||
|
t.logger.Error("msg", "Failed to initialize authenticator for TCP source",
|
||||||
|
"component", "tcp_source",
|
||||||
|
"error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.authenticator = authenticator
|
||||||
|
|
||||||
|
t.logger.Info("msg", "Authentication configured for TCP source",
|
||||||
|
"component", "tcp_source",
|
||||||
|
"auth_type", authCfg.Type)
|
||||||
|
}
|
||||||
@ -1,341 +0,0 @@
|
|||||||
// FILE: src/internal/tls/gnet_bridge.go
|
|
||||||
package tls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/panjf2000/gnet/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrTLSBackpressure = errors.New("TLS processing backpressure")
|
|
||||||
ErrConnectionClosed = errors.New("connection closed")
|
|
||||||
ErrPlaintextBufferExceeded = errors.New("plaintext buffer size exceeded")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Maximum plaintext buffer size to prevent memory exhaustion
|
|
||||||
const maxPlaintextBufferSize = 32 * 1024 * 1024 // 32MB
|
|
||||||
|
|
||||||
// Bridges gnet.Conn with crypto/tls via io.Pipe
|
|
||||||
type GNetTLSConn struct {
|
|
||||||
gnetConn gnet.Conn
|
|
||||||
tlsConn *tls.Conn
|
|
||||||
config *tls.Config
|
|
||||||
|
|
||||||
// Buffered channels for non-blocking operation
|
|
||||||
incomingCipher chan []byte // Network → TLS (encrypted)
|
|
||||||
outgoingCipher chan []byte // TLS → Network (encrypted)
|
|
||||||
|
|
||||||
// Handshake state
|
|
||||||
handshakeOnce sync.Once
|
|
||||||
handshakeDone chan struct{}
|
|
||||||
handshakeErr error
|
|
||||||
|
|
||||||
// Decrypted data buffer
|
|
||||||
plainBuf []byte
|
|
||||||
plainMu sync.Mutex
|
|
||||||
|
|
||||||
// Lifecycle
|
|
||||||
closed atomic.Bool
|
|
||||||
closeOnce sync.Once
|
|
||||||
wg sync.WaitGroup
|
|
||||||
|
|
||||||
// Error tracking
|
|
||||||
lastErr atomic.Value // error
|
|
||||||
logger interface{ Warn(args ...any) } // Minimal logger interface
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a server-side TLS bridge
|
|
||||||
func NewServerConn(gnetConn gnet.Conn, config *tls.Config) *GNetTLSConn {
|
|
||||||
tc := &GNetTLSConn{
|
|
||||||
gnetConn: gnetConn,
|
|
||||||
config: config,
|
|
||||||
handshakeDone: make(chan struct{}),
|
|
||||||
// Buffered channels sized for throughput without blocking
|
|
||||||
incomingCipher: make(chan []byte, 128), // 128 packets buffer
|
|
||||||
outgoingCipher: make(chan []byte, 128),
|
|
||||||
plainBuf: make([]byte, 0, 65536), // 64KB initial capacity
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create TLS conn with channel-based transport
|
|
||||||
rawConn := &channelConn{
|
|
||||||
incoming: tc.incomingCipher,
|
|
||||||
outgoing: tc.outgoingCipher,
|
|
||||||
localAddr: gnetConn.LocalAddr(),
|
|
||||||
remoteAddr: gnetConn.RemoteAddr(),
|
|
||||||
tc: tc,
|
|
||||||
}
|
|
||||||
tc.tlsConn = tls.Server(rawConn, config)
|
|
||||||
|
|
||||||
// Start pump goroutines
|
|
||||||
tc.wg.Add(2)
|
|
||||||
go tc.pumpCipherToNetwork()
|
|
||||||
go tc.pumpPlaintextFromTLS()
|
|
||||||
|
|
||||||
return tc
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a client-side TLS bridge (similar changes)
|
|
||||||
func NewClientConn(gnetConn gnet.Conn, config *tls.Config, serverName string) *GNetTLSConn {
|
|
||||||
tc := &GNetTLSConn{
|
|
||||||
gnetConn: gnetConn,
|
|
||||||
config: config,
|
|
||||||
handshakeDone: make(chan struct{}),
|
|
||||||
incomingCipher: make(chan []byte, 128),
|
|
||||||
outgoingCipher: make(chan []byte, 128),
|
|
||||||
plainBuf: make([]byte, 0, 65536),
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ServerName == "" {
|
|
||||||
config = config.Clone()
|
|
||||||
config.ServerName = serverName
|
|
||||||
}
|
|
||||||
|
|
||||||
rawConn := &channelConn{
|
|
||||||
incoming: tc.incomingCipher,
|
|
||||||
outgoing: tc.outgoingCipher,
|
|
||||||
localAddr: gnetConn.LocalAddr(),
|
|
||||||
remoteAddr: gnetConn.RemoteAddr(),
|
|
||||||
tc: tc,
|
|
||||||
}
|
|
||||||
tc.tlsConn = tls.Client(rawConn, config)
|
|
||||||
|
|
||||||
tc.wg.Add(2)
|
|
||||||
go tc.pumpCipherToNetwork()
|
|
||||||
go tc.pumpPlaintextFromTLS()
|
|
||||||
|
|
||||||
return tc
|
|
||||||
}
|
|
||||||
|
|
||||||
// Feeds encrypted data from network into TLS engine (non-blocking)
|
|
||||||
func (tc *GNetTLSConn) ProcessIncoming(encryptedData []byte) error {
|
|
||||||
if tc.closed.Load() {
|
|
||||||
return ErrConnectionClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Non-blocking send with backpressure detection
|
|
||||||
select {
|
|
||||||
case tc.incomingCipher <- encryptedData:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
// Channel full - TLS processing can't keep up
|
|
||||||
// Drop connection under backpressure vs blocking event loop
|
|
||||||
if tc.logger != nil {
|
|
||||||
tc.logger.Warn("msg", "TLS backpressure, dropping data",
|
|
||||||
"remote_addr", tc.gnetConn.RemoteAddr())
|
|
||||||
}
|
|
||||||
return ErrTLSBackpressure
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sends TLS-encrypted data to network
|
|
||||||
func (tc *GNetTLSConn) pumpCipherToNetwork() {
|
|
||||||
defer tc.wg.Done()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case data, ok := <-tc.outgoingCipher:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Send to network
|
|
||||||
if err := tc.gnetConn.AsyncWrite(data, nil); err != nil {
|
|
||||||
tc.lastErr.Store(err)
|
|
||||||
tc.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-time.After(30 * time.Second):
|
|
||||||
// Keepalive/timeout check
|
|
||||||
if tc.closed.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reads decrypted data from TLS
|
|
||||||
func (tc *GNetTLSConn) pumpPlaintextFromTLS() {
|
|
||||||
defer tc.wg.Done()
|
|
||||||
buf := make([]byte, 32768) // 32KB read buffer
|
|
||||||
|
|
||||||
for {
|
|
||||||
n, err := tc.tlsConn.Read(buf)
|
|
||||||
if n > 0 {
|
|
||||||
tc.plainMu.Lock()
|
|
||||||
// Check buffer size limit before appending to prevent memory exhaustion
|
|
||||||
if len(tc.plainBuf)+n > maxPlaintextBufferSize {
|
|
||||||
tc.plainMu.Unlock()
|
|
||||||
// Log warning about buffer limit
|
|
||||||
if tc.logger != nil {
|
|
||||||
tc.logger.Warn("msg", "Plaintext buffer limit exceeded, closing connection",
|
|
||||||
"remote_addr", tc.gnetConn.RemoteAddr(),
|
|
||||||
"buffer_size", len(tc.plainBuf),
|
|
||||||
"incoming_size", n,
|
|
||||||
"limit", maxPlaintextBufferSize)
|
|
||||||
}
|
|
||||||
// Store error and close connection
|
|
||||||
tc.lastErr.Store(ErrPlaintextBufferExceeded)
|
|
||||||
tc.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tc.plainBuf = append(tc.plainBuf, buf[:n]...)
|
|
||||||
tc.plainMu.Unlock()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
if err != io.EOF {
|
|
||||||
tc.lastErr.Store(err)
|
|
||||||
}
|
|
||||||
tc.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns available decrypted plaintext (non-blocking)
|
|
||||||
func (tc *GNetTLSConn) Read() []byte {
|
|
||||||
tc.plainMu.Lock()
|
|
||||||
defer tc.plainMu.Unlock()
|
|
||||||
|
|
||||||
if len(tc.plainBuf) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Atomic buffer swap under mutex protection to prevent race condition
|
|
||||||
data := tc.plainBuf
|
|
||||||
tc.plainBuf = make([]byte, 0, cap(tc.plainBuf))
|
|
||||||
return data
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypts plaintext and queues for network transmission
|
|
||||||
func (tc *GNetTLSConn) Write(plaintext []byte) (int, error) {
|
|
||||||
if tc.closed.Load() {
|
|
||||||
return 0, ErrConnectionClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
if !tc.IsHandshakeDone() {
|
|
||||||
return 0, errors.New("handshake not complete")
|
|
||||||
}
|
|
||||||
|
|
||||||
return tc.tlsConn.Write(plaintext)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initiates TLS handshake asynchronously
|
|
||||||
func (tc *GNetTLSConn) Handshake() {
|
|
||||||
tc.handshakeOnce.Do(func() {
|
|
||||||
go func() {
|
|
||||||
tc.handshakeErr = tc.tlsConn.Handshake()
|
|
||||||
close(tc.handshakeDone)
|
|
||||||
}()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks if handshake is complete
|
|
||||||
func (tc *GNetTLSConn) IsHandshakeDone() bool {
|
|
||||||
select {
|
|
||||||
case <-tc.handshakeDone:
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Waits for handshake completion
|
|
||||||
func (tc *GNetTLSConn) HandshakeComplete() (<-chan struct{}, error) {
|
|
||||||
<-tc.handshakeDone
|
|
||||||
return tc.handshakeDone, tc.handshakeErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shuts down the bridge
|
|
||||||
func (tc *GNetTLSConn) Close() error {
|
|
||||||
tc.closeOnce.Do(func() {
|
|
||||||
tc.closed.Store(true)
|
|
||||||
|
|
||||||
// Close TLS connection
|
|
||||||
tc.tlsConn.Close()
|
|
||||||
|
|
||||||
// Close channels to stop pumps
|
|
||||||
close(tc.incomingCipher)
|
|
||||||
close(tc.outgoingCipher)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Wait for pumps to finish
|
|
||||||
tc.wg.Wait()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns TLS connection state
|
|
||||||
func (tc *GNetTLSConn) GetConnectionState() tls.ConnectionState {
|
|
||||||
return tc.tlsConn.ConnectionState()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns last error
|
|
||||||
func (tc *GNetTLSConn) GetError() error {
|
|
||||||
if err, ok := tc.lastErr.Load().(error); ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implements net.Conn over channels
|
|
||||||
type channelConn struct {
|
|
||||||
incoming <-chan []byte
|
|
||||||
outgoing chan<- []byte
|
|
||||||
localAddr net.Addr
|
|
||||||
remoteAddr net.Addr
|
|
||||||
tc *GNetTLSConn
|
|
||||||
readBuf []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channelConn) Read(b []byte) (int, error) {
|
|
||||||
// Use buffered read for efficiency
|
|
||||||
if len(c.readBuf) > 0 {
|
|
||||||
n := copy(b, c.readBuf)
|
|
||||||
c.readBuf = c.readBuf[n:]
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for new data
|
|
||||||
select {
|
|
||||||
case data, ok := <-c.incoming:
|
|
||||||
if !ok {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
n := copy(b, data)
|
|
||||||
if n < len(data) {
|
|
||||||
c.readBuf = data[n:] // Buffer remainder
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
case <-time.After(30 * time.Second):
|
|
||||||
return 0, errors.New("read timeout")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channelConn) Write(b []byte) (int, error) {
|
|
||||||
if c.tc.closed.Load() {
|
|
||||||
return 0, ErrConnectionClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make a copy since TLS may hold reference
|
|
||||||
data := make([]byte, len(b))
|
|
||||||
copy(data, b)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case c.outgoing <- data:
|
|
||||||
return len(b), nil
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
return 0, errors.New("write timeout")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channelConn) Close() error { return nil }
|
|
||||||
func (c *channelConn) LocalAddr() net.Addr { return c.localAddr }
|
|
||||||
func (c *channelConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
|
||||||
func (c *channelConn) SetDeadline(t time.Time) error { return nil }
|
|
||||||
func (c *channelConn) SetReadDeadline(t time.Time) error { return nil }
|
|
||||||
func (c *channelConn) SetWriteDeadline(t time.Time) error { return nil }
|
|
||||||
@ -117,18 +117,6 @@ func (m *Manager) GetHTTPConfig() *tls.Config {
|
|||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns TLS config for raw TCP connections
|
|
||||||
func (m *Manager) GetTCPConfig() *tls.Config {
|
|
||||||
if m == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := m.tlsConfig.Clone()
|
|
||||||
// No ALPN for raw TCP
|
|
||||||
cfg.NextProtos = nil
|
|
||||||
return cfg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates a client certificate for mTLS
|
// Validates a client certificate for mTLS
|
||||||
func (m *Manager) ValidateClientCert(rawCerts [][]byte) error {
|
func (m *Manager) ValidateClientCert(rawCerts [][]byte) error {
|
||||||
if m == nil || !m.config.ClientAuth {
|
if m == nil || !m.config.ClientAuth {
|
||||||
@ -174,6 +162,21 @@ func (m *Manager) ValidateClientCert(rawCerts [][]byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns TLS statistics
|
||||||
|
func (m *Manager) GetStats() map[string]any {
|
||||||
|
if m == nil {
|
||||||
|
return map[string]any{"enabled": false}
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]any{
|
||||||
|
"enabled": true,
|
||||||
|
"min_version": tlsVersionString(m.tlsConfig.MinVersion),
|
||||||
|
"max_version": tlsVersionString(m.tlsConfig.MaxVersion),
|
||||||
|
"client_auth": m.config.ClientAuth,
|
||||||
|
"cipher_suites": len(m.tlsConfig.CipherSuites),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func parseTLSVersion(version string, defaultVersion uint16) uint16 {
|
func parseTLSVersion(version string, defaultVersion uint16) uint16 {
|
||||||
switch strings.ToUpper(version) {
|
switch strings.ToUpper(version) {
|
||||||
case "TLS1.0", "TLS10":
|
case "TLS1.0", "TLS10":
|
||||||
@ -217,21 +220,6 @@ func parseCipherSuites(suites string) []uint16 {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns TLS statistics
|
|
||||||
func (m *Manager) GetStats() map[string]any {
|
|
||||||
if m == nil {
|
|
||||||
return map[string]any{"enabled": false}
|
|
||||||
}
|
|
||||||
|
|
||||||
return map[string]any{
|
|
||||||
"enabled": true,
|
|
||||||
"min_version": tlsVersionString(m.tlsConfig.MinVersion),
|
|
||||||
"max_version": tlsVersionString(m.tlsConfig.MaxVersion),
|
|
||||||
"client_auth": m.config.ClientAuth,
|
|
||||||
"cipher_suites": len(m.tlsConfig.CipherSuites),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func tlsVersionString(version uint16) string {
|
func tlsVersionString(version uint16) string {
|
||||||
switch version {
|
switch version {
|
||||||
case tls.VersionTLS10:
|
case tls.VersionTLS10:
|
||||||
|
|||||||
Reference in New Issue
Block a user