v0.6.0 auth restructuring, scram auth added, more tests added

This commit is contained in:
2025-10-02 17:16:43 -04:00
parent 3047e556f7
commit 490fb777ab
37 changed files with 2283 additions and 888 deletions

View File

@ -3,16 +3,16 @@
### Configuration Precedence: CLI flags > Environment > File > Defaults ### Configuration Precedence: CLI flags > Environment > File > Defaults
### Default values shown - uncommented lines represent active configuration ### Default values shown - uncommented lines represent active configuration
### Global settings ### Global Settings
background = false # Run as daemon background = false # Run as daemon
quiet = false # Suppress console output quiet = false # Suppress console output
disable_status_reporter = false # Status logging disable_status_reporter = false # Disable status logging
config_auto_reload = false # File change detection config_auto_reload = false # Reload config on file change
config_save_on_exit = false # Persist runtime changes config_save_on_exit = false # Persist runtime changes
### Logging Configuration ### Logging Configuration
[logging] [logging]
output = "stdout" # file|stdout|stderr|split|all|none output = "stdout" # file|stdout|stderr|both|none
level = "info" # debug|info|warn|error level = "info" # debug|info|warn|error
[logging.file] [logging.file]
@ -20,7 +20,7 @@ directory = "./log" # Log directory path
name = "logwisp" # Base filename name = "logwisp" # Base filename
max_size_mb = 100 # Rotation threshold max_size_mb = 100 # Rotation threshold
max_total_size_mb = 1000 # Total size limit max_total_size_mb = 1000 # Total size limit
retention_hours = 168.0 # Delete logs older than retention_hours = 168.0 # Delete logs older than (7 days)
[logging.console] [logging.console]
target = "stdout" # stdout|stderr|split target = "stdout" # stdout|stderr|split
@ -30,38 +30,56 @@ format = "txt" # txt|json
[[pipelines]] [[pipelines]]
name = "default" # Pipeline identifier name = "default" # Pipeline identifier
### Directory Sources ### Rate Limiting (Pipeline-level)
# [pipelines.rate_limit]
# rate = 0.0 # Entries per second (0=disabled)
# burst = 0.0 # Burst capacity (defaults to rate)
# policy = "pass" # pass|drop
# max_entry_size_bytes = 0 # Max entry size (0=unlimited)
### Filters
# [[pipelines.filters]]
# type = "include" # include|exclude
# logic = "or" # or|and
# patterns = [".*ERROR.*", ".*WARN.*"] # Regex patterns
### Sources
### Directory Source
[[pipelines.sources]] [[pipelines.sources]]
type = "directory" type = "directory"
[pipelines.sources.options] [pipelines.sources.options]
path = "./" # Directory to monitor path = "./" # Directory to monitor
pattern = "*.log" # Glob pattern pattern = "*.log" # Glob pattern
check_interval_ms = 100 # Scan interval check_interval_ms = 100 # Scan interval (min: 10ms)
### Console Sources ### Stdin Source
# [[pipelines.sources]] # [[pipelines.sources]]
# type = "stdin" # type = "stdin"
# [pipelines.sources.options] # [pipelines.sources.options]
# buffer_size = 1000 # Input buffer size # buffer_size = 1000 # Input buffer size
### HTTP Sources ### HTTP Source
# [[pipelines.sources]] # [[pipelines.sources]]
# type = "http" # type = "http"
# [pipelines.sources.options] # [pipelines.sources.options]
# host = "0.0.0.0" # Listen address # host = "0.0.0.0" # Listen address
# port = 8081 # Listen port # port = 8081 # Listen port
# path = "/ingest" # Ingest endpoint # ingest_path = "/ingest" # Ingest endpoint
# max_body_size = 1048576 # Max request size # buffer_size = 1000 # Input buffer size
# max_body_size = 1048576 # Max request size bytes
# [pipelines.sources.options.tls] # [pipelines.sources.options.tls]
# enabled = false # Enable TLS # enabled = false # Enable TLS
# cert_file = "" # TLS certificate # cert_file = "" # Server certificate
# key_file = "" # TLS key # key_file = "" # Server key
# client_auth = false # Require client certs # client_auth = false # Require client certs
# client_ca_file = "" # Client CA cert # client_ca_file = "" # Client CA cert
# verify_client_cert = false # Verify client certs # verify_client_cert = false # Verify client certs
# insecure_skip_verify = false # Skip verification # insecure_skip_verify = false # Skip verification (server-side)
# ca_file = "" # Custom CA file # ca_file = "" # Custom CA file
# min_version = "TLS1.2" # Min TLS version # min_version = "TLS1.2" # Min TLS version
# max_version = "TLS1.3" # Max TLS version # max_version = "TLS1.3" # Max TLS version
@ -69,8 +87,8 @@ check_interval_ms = 100 # Scan interval
# [pipelines.sources.options.net_limit] # [pipelines.sources.options.net_limit]
# enabled = false # Enable rate limiting # enabled = false # Enable rate limiting
# ip_whitelist = [] # Allowed IPs/CIDRs # ip_whitelist = [] # Allowed IPs/CIDRs (IPv4 only)
# ip_blacklist = [] # Blocked IPs/CIDRs # ip_blacklist = [] # Blocked IPs/CIDRs (IPv4 only)
# requests_per_second = 100.0 # Rate limit per client # requests_per_second = 100.0 # Rate limit per client
# burst_size = 100 # Burst capacity # burst_size = 100 # Burst capacity
# response_code = 429 # HTTP status when limited # response_code = 429 # HTTP status when limited
@ -78,59 +96,32 @@ check_interval_ms = 100 # Scan interval
# max_connections_per_ip = 10 # Max concurrent per IP # max_connections_per_ip = 10 # Max concurrent per IP
# max_connections_total = 1000 # Max total connections # max_connections_total = 1000 # Max total connections
### TCP Sources ### TCP Source
# [[pipelines.sources]] # [[pipelines.sources]]
# type = "tcp" # type = "tcp"
# [pipelines.sources.options] # [pipelines.sources.options]
# host = "0.0.0.0" # Listen address # host = "0.0.0.0" # Listen address
# port = 9091 # Listen port # port = 9091 # Listen port
# buffer_size = 1000 # Input buffer size
# [pipelines.sources.options.net_limit] # [pipelines.sources.options.net_limit]
# enabled = false # Enable rate limiting # enabled = false # Enable rate limiting
# ip_whitelist = [] # Allowed IPs/CIDRs # ip_whitelist = [] # Allowed IPs/CIDRs (IPv4 only)
# ip_blacklist = [] # Blocked IPs/CIDRs # ip_blacklist = [] # Blocked IPs/CIDRs (IPv4 only)
# requests_per_second = 100.0 # Rate limit per client # requests_per_second = 100.0 # Rate limit per client
# burst_size = 100 # Burst capacity # burst_size = 100 # Burst capacity
# response_code = 429 # Response code when limited # response_code = 429 # TCP rejection
# response_message = "Rate limit exceeded" # response_message = "Rate limit exceeded"
# max_connections_per_ip = 10 # Max concurrent per IP # max_connections_per_ip = 10 # Max concurrent per IP
# max_connections_per_user = 10 # Max concurrent per user # max_connections_per_user = 10 # Max concurrent per user
# max_connections_per_token = 10 # Max concurrent per token # max_connections_per_token = 10 # Max concurrent per token
# max_connections_total = 1000 # Max total connections # max_connections_total = 1000 # Max total connections
### Rate limiting ### Format Configuration
# [pipelines.rate_limit]
# rate = 0.0 # Entries/second (0=unlimited)
# burst = 0.0 # Burst capacity
# policy = "drop" # pass|drop
# max_entry_size_bytes = 0 # Entry size limit
### Filters ### Raw formatter (default - passes through unchanged)
# [[pipelines.filters]] # format = "raw"
# type = "include" # include|exclude
# logic = "or" # or|and
# patterns = [] # Regex patterns
## Examples of filter patterns:
## Include only error or fatal logs containing "database":
## type = "include"
## logic = "and"
## patterns = ["(?i)(error|fatal)", "database"]
##
## Exclude debug logs from test environment:
## type = "exclude"
## logic = "or"
## patterns = ["(?i)debug", "test-env"]
##
## Include only JSON formatted logs:
## type = "include"
## patterns = ["^\\{.*\\}$"]
### Format
### Raw formatter (default)
# format = "raw" # raw|json|text
### No options for raw formatter ### No options for raw formatter
### JSON formatter ### JSON formatter
@ -143,12 +134,14 @@ check_interval_ms = 100 # Scan interval
# source_field = "source" # Source field name # source_field = "source" # Source field name
### Text formatter ### Text formatter
# format = "text" # format = "txt"
# [pipelines.format_options] # [pipelines.format_options]
# template = "[{{.Timestamp | FmtTime}}] [{{.Level | ToUpper}}] {{.Source}} - {{.Message}}{{ if .Fields }} {{.Fields}}{{ end }}" # template = "[{{.Timestamp | FmtTime}}] [{{.Level | ToUpper}}] {{.Source}} - {{.Message}}{{ if .Fields }} {{.Fields}}{{ end }}"
# timestamp_format = "2006-01-02T15:04:05Z07:00" # Go time format # timestamp_format = "2006-01-02T15:04:05Z07:00" # Go time format
### HTTP Sinks ### Sinks
### HTTP Sink (SSE Server)
[[pipelines.sinks]] [[pipelines.sinks]]
type = "http" type = "http"
@ -162,14 +155,14 @@ status_path = "/status" # Status endpoint
[pipelines.sinks.options.heartbeat] [pipelines.sinks.options.heartbeat]
enabled = true # Send heartbeats enabled = true # Send heartbeats
interval_seconds = 30 # Heartbeat interval interval_seconds = 30 # Heartbeat interval
include_timestamp = true # Include time include_timestamp = true # Include timestamp
include_stats = false # Include statistics include_stats = false # Include statistics
format = "comment" # comment|message format = "comment" # comment|message
# [pipelines.sinks.options.tls] # [pipelines.sinks.options.tls]
# enabled = false # Enable TLS # enabled = false # Enable TLS
# cert_file = "" # TLS certificate # cert_file = "" # Server certificate
# key_file = "" # TLS key # key_file = "" # Server key
# client_auth = false # Require client certs # client_auth = false # Require client certs
# client_ca_file = "" # Client CA cert # client_ca_file = "" # Client CA cert
# verify_client_cert = false # Verify client certs # verify_client_cert = false # Verify client certs
@ -181,8 +174,8 @@ format = "comment" # comment|message
# [pipelines.sinks.options.net_limit] # [pipelines.sinks.options.net_limit]
# enabled = false # Enable rate limiting # enabled = false # Enable rate limiting
# ip_whitelist = [] # Allowed IPs/CIDRs # ip_whitelist = [] # Allowed IPs/CIDRs (IPv4 only)
# ip_blacklist = [] # Blocked IPs/CIDRs # ip_blacklist = [] # Blocked IPs/CIDRs (IPv4 only)
# requests_per_second = 100.0 # Rate limit per client # requests_per_second = 100.0 # Rate limit per client
# burst_size = 100 # Burst capacity # burst_size = 100 # Burst capacity
# response_code = 429 # HTTP status when limited # response_code = 429 # HTTP status when limited
@ -190,7 +183,7 @@ format = "comment" # comment|message
# max_connections_per_ip = 10 # Max concurrent per IP # max_connections_per_ip = 10 # Max concurrent per IP
# max_connections_total = 1000 # Max total connections # max_connections_total = 1000 # Max total connections
### TCP Sinks ### TCP Sink (TCP Server)
# [[pipelines.sinks]] # [[pipelines.sinks]]
# type = "tcp" # type = "tcp"
@ -198,28 +191,33 @@ format = "comment" # comment|message
# host = "0.0.0.0" # Listen address # host = "0.0.0.0" # Listen address
# port = 9090 # Server port # port = 9090 # Server port
# buffer_size = 1000 # Buffer size # buffer_size = 1000 # Buffer size
# auth_type = "none" # none|scram
# [pipelines.sinks.options.heartbeat] # [pipelines.sinks.options.heartbeat]
# enabled = false # Send heartbeats # enabled = false # Send heartbeats
# interval_seconds = 30 # Heartbeat interval # interval_seconds = 30 # Heartbeat interval
# include_timestamp = false # Include time # include_timestamp = false # Include timestamp
# include_stats = false # Include statistics # include_stats = false # Include statistics
# format = "comment" # comment|message # format = "comment" # comment|message
# [pipelines.sinks.options.net_limit] # [pipelines.sinks.options.net_limit]
# enabled = false # Enable rate limiting # enabled = false # Enable rate limiting
# ip_whitelist = [] # Allowed IPs/CIDRs # ip_whitelist = [] # Allowed IPs/CIDRs (IPv4 only)
# ip_blacklist = [] # Blocked IPs/CIDRs # ip_blacklist = [] # Blocked IPs/CIDRs (IPv4 only)
# requests_per_second = 100.0 # Rate limit per client # requests_per_second = 100.0 # Rate limit per client
# burst_size = 100 # Burst capacity # burst_size = 100 # Burst capacity
# response_code = 429 # HTTP status when limited # response_code = 429 # TCP rejection code
# response_message = "Rate limit exceeded" # response_message = "Rate limit exceeded"
# max_connections_per_ip = 10 # Max concurrent per IP # max_connections_per_ip = 10 # Max concurrent per IP
# max_connections_per_user = 10 # Max concurrent per user # max_connections_per_user = 10 # Max concurrent per user
# max_connections_per_token = 10 # Max concurrent per token # max_connections_per_token = 10 # Max concurrent per token
# max_connections_total = 1000 # Max total connections # max_connections_total = 1000 # Max total connections
### HTTP Client Sinks # [pipelines.sinks.options.scram]
# username = "" # SCRAM auth username
# password = "" # SCRAM auth password
### HTTP Client Sink (Forward to remote HTTP endpoint)
# [[pipelines.sinks]] # [[pipelines.sinks]]
# type = "http_client" # type = "http_client"
@ -231,16 +229,31 @@ format = "comment" # comment|message
# timeout_seconds = 30 # Request timeout # timeout_seconds = 30 # Request timeout
# max_retries = 3 # Retry attempts # max_retries = 3 # Retry attempts
# retry_delay_ms = 1000 # Initial retry delay # retry_delay_ms = 1000 # Initial retry delay
# retry_backoff = 2.0 # Exponential backoff # retry_backoff = 2.0 # Exponential backoff multiplier
# insecure_skip_verify = false # Skip TLS verification # insecure_skip_verify = false # Skip TLS verification
# ca_file = "" # Custom CA certificate # auth_type = "none" # none|basic|bearer|mtls
# headers = {} # Custom HTTP headers
# [pipelines.sinks.options.basic]
# username = "" # Basic auth username
# password_hash = "" # Argon2 password hash
# [pipelines.sinks.options.bearer]
# token = "" # Bearer token
====== not needed:
## Custom HTTP headers
# [pipelines.sinks.options.headers]
# Content-Type = "application/json"
# Authorization = "Bearer token"
## Client certificate for mTLS
# [pipelines.sinks.options.tls] # [pipelines.sinks.options.tls]
# ca_file = "" # Custom CA certificate
# cert_file = "" # Client certificate # cert_file = "" # Client certificate
# key_file = "" # Client key # key_file = "" # Client key
### TCP Client Sinks ### TCP Client Sink (Forward to remote TCP endpoint)
# [[pipelines.sinks]] # [[pipelines.sinks]]
# type = "tcp_client" # type = "tcp_client"
@ -253,23 +266,21 @@ format = "comment" # comment|message
# keep_alive_seconds = 30 # TCP keepalive # keep_alive_seconds = 30 # TCP keepalive
# reconnect_delay_ms = 1000 # Initial reconnect delay # reconnect_delay_ms = 1000 # Initial reconnect delay
# max_reconnect_delay_seconds = 30 # Max reconnect delay # max_reconnect_delay_seconds = 30 # Max reconnect delay
# reconnect_backoff = 1.5 # Exponential backoff # reconnect_backoff = 1.5 # Exponential backoff multiplier
# [pipelines.sinks.options.tls]
# enabled = false # Enable TLS
# insecure_skip_verify = false # Skip verification
# ca_file = "" # Custom CA certificate
# cert_file = "" # Client certificate
# key_file = "" # Client key
### File Sinks # [pipelines.sinks.options.scram]
# username = "" # Auth username
# password_hash = "" # Argon2 password hash
### File Sink
# [[pipelines.sinks]] # [[pipelines.sinks]]
# type = "file" # type = "file"
# [pipelines.sinks.options] # [pipelines.sinks.options]
# directory = "" # Output dir (required) # directory = "./" # Output dir
# name = "" # Base name (required) # name = "logwisp.output" # Base name
# buffer_size = 1000 # Input channel buffer size # buffer_size = 1000 # Input channel buffer
# max_size_mb = 100 # Rotation size # max_size_mb = 100 # Rotation size
# max_total_size_mb = 0 # Total limit (0=unlimited) # max_total_size_mb = 0 # Total limit (0=unlimited)
# retention_hours = 0.0 # Retention (0=disabled) # retention_hours = 0.0 # Retention (0=disabled)
@ -277,39 +288,32 @@ format = "comment" # comment|message
### Console Sinks ### Console Sinks
# [[pipelines.sinks]] # [[pipelines.sinks]]
# type = "stdout" # type = "console"
# [pipelines.sinks.options] # [pipelines.sinks.options]
# target = "stdout" # stdout|stderr|split
# buffer_size = 1000 # Buffer size # buffer_size = 1000 # Buffer size
# target = "stdout" # Override for split mode
# [[pipelines.sinks]] ### Authentication Configuration
# type = "stderr"
# [pipelines.sinks.options]
# buffer_size = 1000 # Buffer size
# target = "stderr" # Override for split mode
### Authentication
# [pipelines.auth] # [pipelines.auth]
# type = "none" # none|basic|bearer|mtls # type = "none" # none|basic|bearer|mtls
### Basic authentication ### Basic Authentication
# [pipelines.auth.basic_auth] # [pipelines.auth.basic_auth]
# realm = "LogWisp" # WWW-Authenticate realm # realm = "LogWisp" # WWW-Authenticate realm
# users_file = "" # External users file # users_file = "" # External users file path
# [[pipelines.auth.basic_auth.users]] # [[pipelines.auth.basic_auth.users]]
# username = "" # Username # username = "" # Username
# password_hash = "" # bcrypt hash # password_hash = "" # Argon2 password hash
### Bearer authentication ### Bearer Token Authentication
# [pipelines.auth.bearer_auth] # [pipelines.auth.bearer_auth]
# tokens = [] # Static bearer tokens # tokens = [] # Static bearer tokens
### JWT validation ### JWT Validation
# [pipelines.auth.bearer_auth.jwt] # [pipelines.auth.bearer_auth.jwt]
# jwks_url = "" # JWKS endpoint # jwks_url = "" # JWKS endpoint for key discovery
# signing_key = "" # Static signing key # signing_key = "" # Static signing key (if not using JWKS)
# issuer = "" # Expected issuer # issuer = "" # Expected issuer claim
# audience = "" # Expected audience # audience = "" # Expected audience claim

3
go.mod
View File

@ -3,10 +3,10 @@ module logwisp
go 1.25.1 go 1.25.1
require ( require (
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3
github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2 github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2
github.com/panjf2000/gnet/v2 v2.9.4 github.com/panjf2000/gnet/v2 v2.9.4
github.com/stretchr/testify v1.10.0
github.com/valyala/fasthttp v1.66.0 github.com/valyala/fasthttp v1.66.0
golang.org/x/crypto v0.42.0 golang.org/x/crypto v0.42.0
golang.org/x/term v0.35.0 golang.org/x/term v0.35.0
@ -20,6 +20,7 @@ require (
github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/compress v1.18.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/panjf2000/ants/v2 v2.11.3 // indirect github.com/panjf2000/ants/v2 v2.11.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect go.uber.org/zap v1.27.0 // indirect

4
go.sum
View File

@ -6,14 +6,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGLxofobw= github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGLxofobw=
github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw= github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0= github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0=
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0= github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0=
github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e h1:/XWCqFdSOiUf0/a5a63GHsvEdpglsYfn3qieNxTeyDc=
github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0=
github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2 h1:9Qf+BR83sKjok2E1Nct+3Sfzoj2dLGwC/zyQDVNmmqs= github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2 h1:9Qf+BR83sKjok2E1Nct+3Sfzoj2dLGwC/zyQDVNmmqs=
github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0= github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0=
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=

View File

@ -105,7 +105,7 @@ func (c *helpCommand) Description() string {
type authCommand struct{} type authCommand struct{}
func (c *authCommand) Execute(args []string) error { func (c *authCommand) Execute(args []string) error {
gen := auth.NewGeneratorCommand() gen := auth.NewAuthGeneratorCommand()
return gen.Execute(args) return gen.Execute(args)
} }

View File

@ -183,11 +183,13 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
"name", name) "name", name)
} }
case "stdout", "stderr": case "console":
if target, ok := sinkCfg.Options["target"].(string); ok {
logger.Info("msg", "Console sink configured", logger.Info("msg", "Console sink configured",
"pipeline", cfg.Name, "pipeline", cfg.Name,
"sink_index", i, "sink_index", i,
"type", sinkCfg.Type) "target", target)
}
} }
} }

View File

@ -2,22 +2,17 @@
package auth package auth
import ( import (
"bufio"
"crypto/rand" "crypto/rand"
"crypto/subtle"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net" "net"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
"logwisp/src/internal/config" "logwisp/src/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
"golang.org/x/crypto/argon2"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
@ -28,10 +23,7 @@ const maxAuthTrackedIPs = 10000
type Authenticator struct { type Authenticator struct {
config *config.AuthConfig config *config.AuthConfig
logger *log.Logger logger *log.Logger
basicUsers map[string]string // username -> password hash
bearerTokens map[string]bool // token -> valid bearerTokens map[string]bool // token -> valid
jwtParser *jwt.Parser
jwtKeyFunc jwt.Keyfunc
mu sync.RWMutex mu sync.RWMutex
// Session tracking // Session tracking
@ -55,69 +47,32 @@ type ipAuthState struct {
type Session struct { type Session struct {
ID string ID string
Username string Username string
Method string // basic, bearer, jwt, mtls Method string // basic, bearer, mtls
RemoteAddr string RemoteAddr string
CreatedAt time.Time CreatedAt time.Time
LastActivity time.Time LastActivity time.Time
Metadata map[string]any
} }
// Creates a new authenticator from config // Creates a new authenticator from config
func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) { func NewAuthenticator(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) {
if cfg == nil || cfg.Type == "none" { // SCRAM is handled by ScramManager in sources
if cfg == nil || cfg.Type == "none" || cfg.Type == "scram" {
return nil, nil return nil, nil
} }
a := &Authenticator{ a := &Authenticator{
config: cfg, config: cfg,
logger: logger, logger: logger,
basicUsers: make(map[string]string),
bearerTokens: make(map[string]bool), bearerTokens: make(map[string]bool),
sessions: make(map[string]*Session), sessions: make(map[string]*Session),
ipAuthAttempts: make(map[string]*ipAuthState), ipAuthAttempts: make(map[string]*ipAuthState),
} }
// Initialize Basic Auth users
if cfg.Type == "basic" && cfg.BasicAuth != nil {
for _, user := range cfg.BasicAuth.Users {
a.basicUsers[user.Username] = user.PasswordHash
}
// Load users from file if specified
if cfg.BasicAuth.UsersFile != "" {
if err := a.loadUsersFile(cfg.BasicAuth.UsersFile); err != nil {
return nil, fmt.Errorf("failed to load users file: %w", err)
}
}
}
// Initialize Bearer tokens // Initialize Bearer tokens
if cfg.Type == "bearer" && cfg.BearerAuth != nil { if cfg.Type == "bearer" && cfg.BearerAuth != nil {
for _, token := range cfg.BearerAuth.Tokens { for _, token := range cfg.BearerAuth.Tokens {
a.bearerTokens[token] = true a.bearerTokens[token] = true
} }
// Setup JWT validation if configured
if cfg.BearerAuth.JWT != nil {
a.jwtParser = jwt.NewParser(
jwt.WithValidMethods([]string{"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}),
jwt.WithLeeway(5*time.Second),
jwt.WithExpirationRequired(),
)
// Setup key function
if cfg.BearerAuth.JWT.SigningKey != "" {
// Static key
key := []byte(cfg.BearerAuth.JWT.SigningKey)
a.jwtKeyFunc = func(token *jwt.Token) (any, error) {
return key, nil
}
} else if cfg.BearerAuth.JWT.JWKSURL != "" {
// JWKS support would require additional implementation
// ☢ SECURITY: JWKS rotation not implemented - tokens won't refresh keys
return nil, fmt.Errorf("JWKS support not yet implemented")
}
}
} }
// Start session cleanup // Start session cleanup
@ -276,8 +231,6 @@ func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Sessio
var err error var err error
switch a.config.Type { switch a.config.Type {
case "basic":
session, err = a.authenticateBasic(authHeader, remoteAddr)
case "bearer": case "bearer":
session, err = a.authenticateBearer(authHeader, remoteAddr) session, err = a.authenticateBearer(authHeader, remoteAddr)
default: default:
@ -322,24 +275,6 @@ func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string)
session, err = a.validateToken(credentials, remoteAddr) session, err = a.validateToken(credentials, remoteAddr)
} }
case "basic":
if a.config.Type != "basic" {
err = fmt.Errorf("basic auth not configured")
} else {
// Expect base64(username:password)
decoded, decErr := base64.StdEncoding.DecodeString(credentials)
if decErr != nil {
err = fmt.Errorf("invalid credentials encoding")
} else {
parts := strings.SplitN(string(decoded), ":", 2)
if len(parts) != 2 {
err = fmt.Errorf("invalid credentials format")
} else {
session, err = a.validateBasicAuth(parts[0], parts[1], remoteAddr)
}
}
}
default: default:
err = fmt.Errorf("unsupported auth method: %s", method) err = fmt.Errorf("unsupported auth method: %s", method)
} }
@ -355,91 +290,6 @@ func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string)
return session, nil return session, nil
} }
func (a *Authenticator) authenticateBasic(authHeader, remoteAddr string) (*Session, error) {
if !strings.HasPrefix(authHeader, "Basic ") {
return nil, fmt.Errorf("invalid basic auth header")
}
payload, err := base64.StdEncoding.DecodeString(authHeader[6:])
if err != nil {
return nil, fmt.Errorf("invalid base64 encoding")
}
parts := strings.SplitN(string(payload), ":", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid credentials format")
}
return a.validateBasicAuth(parts[0], parts[1], remoteAddr)
}
func (a *Authenticator) validateBasicAuth(username, password, remoteAddr string) (*Session, error) {
a.mu.RLock()
expectedHash, exists := a.basicUsers[username]
a.mu.RUnlock()
if !exists {
// Perform argon2 anyway to prevent timing attacks
dummySalt := make([]byte, 16)
argon2.IDKey([]byte(password), dummySalt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)
return nil, fmt.Errorf("invalid credentials")
}
// Parse PHC format hash
if !verifyArgon2idHash(password, expectedHash) {
return nil, fmt.Errorf("invalid credentials")
}
session := &Session{
ID: generateSessionID(),
Username: username,
Method: "basic",
RemoteAddr: remoteAddr,
CreatedAt: time.Now(),
LastActivity: time.Now(),
}
a.storeSession(session)
return session, nil
}
// Verify Argon2id hashes
func verifyArgon2idHash(password, hash string) bool {
// Parse PHC format: $argon2id$v=19$m=65536,t=3,p=4$salt$hash
parts := strings.Split(hash, "$")
if len(parts) != 6 || parts[1] != "argon2id" {
return false
}
// Parse version
var version int
fmt.Sscanf(parts[2], "v=%d", &version)
if version != argon2.Version {
return false
}
// Parse parameters
var memory, time, threads uint32
fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
// Decode salt and hash
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
return false
}
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
return false
}
// Compute hash
computedHash := argon2.IDKey([]byte(password), salt, time, memory, uint8(threads), uint32(len(expectedHash)))
// Constant time comparison
return subtle.ConstantTimeCompare(computedHash, expectedHash) == 1
}
func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) { func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) {
if !strings.HasPrefix(authHeader, "Bearer ") { if !strings.HasPrefix(authHeader, "Bearer ") {
return nil, fmt.Errorf("invalid bearer auth header") return nil, fmt.Errorf("invalid bearer auth header")
@ -452,97 +302,22 @@ func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Sess
func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error) { func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error) {
// Check static tokens first // Check static tokens first
a.mu.RLock() a.mu.RLock()
isStatic := a.bearerTokens[token] isValid := a.bearerTokens[token]
a.mu.RUnlock() a.mu.RUnlock()
if isStatic { if !isValid {
return nil, fmt.Errorf("invalid token")
}
session := &Session{ session := &Session{
ID: generateSessionID(), ID: generateSessionID(),
Method: "bearer", Method: "bearer",
RemoteAddr: remoteAddr, RemoteAddr: remoteAddr,
CreatedAt: time.Now(), CreatedAt: time.Now(),
LastActivity: time.Now(), LastActivity: time.Now(),
Metadata: map[string]any{"token_type": "static"},
} }
a.storeSession(session) a.storeSession(session)
return session, nil return session, nil
}
// Try JWT validation if configured
if a.jwtParser != nil && a.jwtKeyFunc != nil {
claims := jwt.MapClaims{}
parsedToken, err := a.jwtParser.ParseWithClaims(token, claims, a.jwtKeyFunc)
if err != nil {
return nil, fmt.Errorf("JWT validation failed: %w", err)
}
if !parsedToken.Valid {
return nil, fmt.Errorf("invalid JWT token")
}
// Explicit expiration check
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
return nil, fmt.Errorf("token expired")
}
} else {
// Reject tokens without expiration
return nil, fmt.Errorf("token missing expiration claim")
}
// Check not-before claim
if nbf, ok := claims["nbf"].(float64); ok {
if time.Now().Unix() < int64(nbf) {
return nil, fmt.Errorf("token not yet valid")
}
}
// Check issuer if configured
if a.config.BearerAuth.JWT.Issuer != "" {
if iss, ok := claims["iss"].(string); !ok || iss != a.config.BearerAuth.JWT.Issuer {
return nil, fmt.Errorf("invalid token issuer")
}
}
// Check audience if configured
if a.config.BearerAuth.JWT.Audience != "" {
// Handle both string and []string audience formats
audValid := false
switch aud := claims["aud"].(type) {
case string:
audValid = aud == a.config.BearerAuth.JWT.Audience
case []any:
for _, aa := range aud {
if audStr, ok := aa.(string); ok && audStr == a.config.BearerAuth.JWT.Audience {
audValid = true
break
}
}
}
if !audValid {
return nil, fmt.Errorf("invalid token audience")
}
}
username := ""
if sub, ok := claims["sub"].(string); ok {
username = sub
}
session := &Session{
ID: generateSessionID(),
Username: username,
Method: "jwt",
RemoteAddr: remoteAddr,
CreatedAt: time.Now(),
LastActivity: time.Now(),
Metadata: map[string]any{"claims": claims},
}
a.storeSession(session)
return session, nil
}
return nil, fmt.Errorf("invalid token")
} }
func (a *Authenticator) storeSession(session *Session) { func (a *Authenticator) storeSession(session *Session) {
@ -598,49 +373,6 @@ func (a *Authenticator) authAttemptCleanup() {
} }
} }
func (a *Authenticator) loadUsersFile(path string) error {
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("could not open users file: %w", err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
lineNumber := 0
for scanner.Scan() {
lineNumber++
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue // Skip empty lines and comments
}
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
a.logger.Warn("msg", "Skipping malformed line in users file",
"component", "auth",
"path", path,
"line_number", lineNumber)
continue
}
username, hash := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
if username != "" && hash != "" {
// File-based users can overwrite inline users if names conflict
a.basicUsers[username] = hash
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("error reading users file: %w", err)
}
a.logger.Info("msg", "Loaded users from file",
"component", "auth",
"path", path,
"user_count", len(a.basicUsers))
return nil
}
func generateSessionID() string { func generateSessionID() string {
b := make([]byte, 32) b := make([]byte, 32)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(b); err != nil {
@ -686,7 +418,6 @@ func (a *Authenticator) GetStats() map[string]any {
"enabled": true, "enabled": true,
"type": a.config.Type, "type": a.config.Type,
"active_sessions": sessionCount, "active_sessions": sessionCount,
"basic_users": len(a.basicUsers),
"static_tokens": len(a.bearerTokens), "static_tokens": len(a.bearerTokens),
} }
} }

View File

@ -10,6 +10,8 @@ import (
"os" "os"
"syscall" "syscall"
"logwisp/src/internal/scram"
"golang.org/x/crypto/argon2" "golang.org/x/crypto/argon2"
"golang.org/x/term" "golang.org/x/term"
) )
@ -23,40 +25,45 @@ const (
argon2KeyLen = 32 argon2KeyLen = 32
) )
type GeneratorCommand struct { type AuthGeneratorCommand struct {
output io.Writer output io.Writer
errOut io.Writer errOut io.Writer
} }
func NewGeneratorCommand() *GeneratorCommand { func NewAuthGeneratorCommand() *AuthGeneratorCommand {
return &GeneratorCommand{ return &AuthGeneratorCommand{
output: os.Stdout, output: os.Stdout,
errOut: os.Stderr, errOut: os.Stderr,
} }
} }
func (g *GeneratorCommand) Execute(args []string) error { func (ag *AuthGeneratorCommand) Execute(args []string) error {
cmd := flag.NewFlagSet("auth", flag.ContinueOnError) cmd := flag.NewFlagSet("auth", flag.ContinueOnError)
cmd.SetOutput(g.errOut) cmd.SetOutput(ag.errOut)
var ( var (
username = cmd.String("u", "", "Username for basic auth") username = cmd.String("u", "", "Username")
password = cmd.String("p", "", "Password to hash (will prompt if not provided)") password = cmd.String("p", "", "Password (will prompt if not provided)")
authType = cmd.String("type", "basic", "Auth type: basic (HTTP) or scram (TCP)")
genToken = cmd.Bool("t", false, "Generate random bearer token") genToken = cmd.Bool("t", false, "Generate random bearer token")
tokenLen = cmd.Int("l", 32, "Token length in bytes") tokenLen = cmd.Int("l", 32, "Token length in bytes (min 16, max 512)")
) )
cmd.Usage = func() { cmd.Usage = func() {
fmt.Fprintln(g.errOut, "Generate authentication credentials for LogWisp") fmt.Fprintln(ag.errOut, "Generate authentication credentials for LogWisp")
fmt.Fprintln(g.errOut, "\nUsage: logwisp auth [options]") fmt.Fprintln(ag.errOut, "\nUsage: logwisp auth [options]")
fmt.Fprintln(g.errOut, "\nExamples:") fmt.Fprintln(ag.errOut, "\nExamples:")
fmt.Fprintln(g.errOut, " # Generate Argon2id hash for user") fmt.Fprintln(ag.errOut, " # Generate basic auth hash for HTTP sources/sinks")
fmt.Fprintln(g.errOut, " logwisp auth -u admin") fmt.Fprintln(ag.errOut, " logwisp auth -u admin -type basic")
fmt.Fprintln(g.errOut, " ") fmt.Fprintln(ag.errOut, " ")
fmt.Fprintln(g.errOut, " # Generate 64-byte bearer token") fmt.Fprintln(ag.errOut, " # Generate SCRAM credentials for TCP sources/sinks")
fmt.Fprintln(g.errOut, " logwisp auth -t -l 64") fmt.Fprintln(ag.errOut, " logwisp auth -u admin -type scram")
fmt.Fprintln(g.errOut, "\nOptions:") fmt.Fprintln(ag.errOut, " ")
fmt.Fprintln(ag.errOut, " # Generate 64-byte bearer token")
fmt.Fprintln(ag.errOut, " logwisp auth -t -l 64")
fmt.Fprintln(ag.errOut, "\nOptions:")
cmd.PrintDefaults() cmd.PrintDefaults()
fmt.Fprintln(ag.errOut)
} }
if err := cmd.Parse(args); err != nil { if err := cmd.Parse(args); err != nil {
@ -64,22 +71,29 @@ func (g *GeneratorCommand) Execute(args []string) error {
} }
if *genToken { if *genToken {
return g.generateToken(*tokenLen) return ag.generateToken(*tokenLen)
} }
if *username == "" { if *username == "" {
cmd.Usage() cmd.Usage()
return fmt.Errorf("username required for password hash generation") return fmt.Errorf("username required for credential generation")
} }
return g.generatePasswordHash(*username, *password) switch *authType {
case "basic":
return ag.generateBasicAuth(*username, *password)
case "scram":
return ag.generateScramAuth(*username, *password)
default:
return fmt.Errorf("invalid auth type: %s (use 'basic' or 'scram')", *authType)
}
} }
func (g *GeneratorCommand) generatePasswordHash(username, password string) error { func (ag *AuthGeneratorCommand) generateBasicAuth(username, password string) error {
// Get password if not provided // Get password if not provided
if password == "" { if password == "" {
pass1 := g.promptPassword("Enter password: ") pass1 := ag.promptPassword("Enter password: ")
pass2 := g.promptPassword("Confirm password: ") pass2 := ag.promptPassword("Confirm password: ")
if pass1 != pass2 { if pass1 != pass2 {
return fmt.Errorf("passwords don't match") return fmt.Errorf("passwords don't match")
} }
@ -102,20 +116,61 @@ func (g *GeneratorCommand) generatePasswordHash(username, password string) error
argon2.Version, argon2Memory, argon2Time, argon2Threads, saltB64, hashB64) argon2.Version, argon2Memory, argon2Time, argon2Threads, saltB64, hashB64)
// Output configuration snippets // Output configuration snippets
fmt.Fprintln(g.output, "\n# TOML Configuration (add to logwisp.toml):") fmt.Fprintln(ag.output, "\n# Basic Auth Configuration (HTTP sources/sinks)")
fmt.Fprintln(g.output, "[[pipelines.auth.basic_auth.users]]") fmt.Fprintln(ag.output, "# REQUIRES HTTPS/TLS for security")
fmt.Fprintf(g.output, "username = %q\n", username) fmt.Fprintln(ag.output, "# Add to logwisp.toml under [[pipelines]]:")
fmt.Fprintf(g.output, "password_hash = %q\n\n", phcHash) fmt.Fprintln(ag.output, "")
fmt.Fprintln(ag.output, "[pipelines.auth]")
fmt.Fprintln(g.output, "# Users File Format (for external auth file):") fmt.Fprintln(ag.output, `type = "basic"`)
fmt.Fprintf(g.output, "%s:%s\n", username, phcHash) fmt.Fprintln(ag.output, "")
fmt.Fprintln(ag.output, "[[pipelines.auth.basic_auth.users]]")
fmt.Fprintf(ag.output, "username = %q\n", username)
fmt.Fprintf(ag.output, "password_hash = %q\n\n", phcHash)
return nil return nil
} }
func (g *GeneratorCommand) generateToken(length int) error { func (ag *AuthGeneratorCommand) generateScramAuth(username, password string) error {
// Get password if not provided
if password == "" {
pass1 := ag.promptPassword("Enter password: ")
pass2 := ag.promptPassword("Confirm password: ")
if pass1 != pass2 {
return fmt.Errorf("passwords don't match")
}
password = pass1
}
// Generate salt
salt := make([]byte, 16)
if _, err := rand.Read(salt); err != nil {
return fmt.Errorf("failed to generate salt: %w", err)
}
// Derive SCRAM credential
cred, err := scram.DeriveCredential(username, password, salt, 3, 65536, 4)
if err != nil {
return fmt.Errorf("failed to derive SCRAM credential: %w", err)
}
// Output SCRAM configuration
fmt.Fprintln(ag.output, "\n# SCRAM Auth Configuration (for TCP sources/sinks)")
fmt.Fprintln(ag.output, "# Add to logwisp.toml:")
fmt.Fprintln(ag.output, "[[pipelines.auth.scram_auth.users]]")
fmt.Fprintf(ag.output, "username = %q\n", username)
fmt.Fprintf(ag.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey))
fmt.Fprintf(ag.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey))
fmt.Fprintf(ag.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt))
fmt.Fprintf(ag.output, "argon_time = %d\n", cred.ArgonTime)
fmt.Fprintf(ag.output, "argon_memory = %d\n", cred.ArgonMemory)
fmt.Fprintf(ag.output, "argon_threads = %d\n\n", cred.ArgonThreads)
return nil
}
func (ag *AuthGeneratorCommand) generateToken(length int) error {
if length < 16 { if length < 16 {
fmt.Fprintln(g.errOut, "Warning: tokens < 16 bytes are cryptographically weak") fmt.Fprintln(ag.errOut, "Warning: tokens < 16 bytes are cryptographically weak")
} }
if length > 512 { if length > 512 {
return fmt.Errorf("token length exceeds maximum (512 bytes)") return fmt.Errorf("token length exceeds maximum (512 bytes)")
@ -129,22 +184,23 @@ func (g *GeneratorCommand) generateToken(length int) error {
b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token) b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token)
hex := fmt.Sprintf("%x", token) hex := fmt.Sprintf("%x", token)
fmt.Fprintln(g.output, "\n# TOML Configuration (add to logwisp.toml):") fmt.Fprintln(ag.output, "\n# Bearer Token Configuration")
fmt.Fprintf(g.output, "tokens = [%q]\n\n", b64) fmt.Fprintln(ag.output, "# Add to logwisp.toml:")
fmt.Fprintf(ag.output, "tokens = [%q]\n\n", b64)
fmt.Fprintln(g.output, "# Generated Token:") fmt.Fprintln(ag.output, "# Generated Token:")
fmt.Fprintf(g.output, "Base64: %s\n", b64) fmt.Fprintf(ag.output, "Base64: %s\n", b64)
fmt.Fprintf(g.output, "Hex: %s\n", hex) fmt.Fprintf(ag.output, "Hex: %s\n", hex)
return nil return nil
} }
func (g *GeneratorCommand) promptPassword(prompt string) string { func (ag *AuthGeneratorCommand) promptPassword(prompt string) string {
fmt.Fprint(g.errOut, prompt) fmt.Fprint(ag.errOut, prompt)
password, err := term.ReadPassword(int(syscall.Stdin)) password, err := term.ReadPassword(syscall.Stdin)
fmt.Fprintln(g.errOut) fmt.Fprintln(ag.errOut)
if err != nil { if err != nil {
fmt.Fprintf(g.errOut, "Failed to read password: %v\n", err) fmt.Fprintf(ag.errOut, "Failed to read password: %v\n", err)
os.Exit(1) os.Exit(1)
} }
return string(password) return string(password)

View File

@ -6,61 +6,61 @@ import (
) )
type AuthConfig struct { type AuthConfig struct {
// Authentication type: "none", "basic", "bearer", "mtls" // Authentication type: "none", "basic", "scram", "bearer", "mtls"
Type string `toml:"type"` Type string `toml:"type"`
// Basic auth
BasicAuth *BasicAuthConfig `toml:"basic_auth"` BasicAuth *BasicAuthConfig `toml:"basic_auth"`
ScramAuth *ScramAuthConfig `toml:"scram_auth"`
// Bearer token auth
BearerAuth *BearerAuthConfig `toml:"bearer_auth"` BearerAuth *BearerAuthConfig `toml:"bearer_auth"`
} }
type BasicAuthConfig struct { type BasicAuthConfig struct {
// Static users (for simple deployments)
Users []BasicAuthUser `toml:"users"` Users []BasicAuthUser `toml:"users"`
// External auth file
UsersFile string `toml:"users_file"`
// Realm for WWW-Authenticate header
Realm string `toml:"realm"` Realm string `toml:"realm"`
} }
type BasicAuthUser struct { type BasicAuthUser struct {
Username string `toml:"username"` Username string `toml:"username"`
// Password hash (Argon2id) PasswordHash string `toml:"password_hash"` // Argon2
PasswordHash string `toml:"password_hash"` }
type ScramAuthConfig struct {
Users []ScramUser `toml:"users"`
}
type ScramUser struct {
Username string `toml:"username"`
StoredKey string `toml:"stored_key"` // base64
ServerKey string `toml:"server_key"` // base64
Salt string `toml:"salt"` // base64
ArgonTime uint32 `toml:"argon_time"`
ArgonMemory uint32 `toml:"argon_memory"`
ArgonThreads uint8 `toml:"argon_threads"`
} }
type BearerAuthConfig struct { type BearerAuthConfig struct {
// Static tokens // Static tokens
Tokens []string `toml:"tokens"` Tokens []string `toml:"tokens"`
// JWT validation // TODO: Maybe future development
JWT *JWTConfig `toml:"jwt"` // // JWT validation
// JWT *JWTConfig `toml:"jwt"`
} }
type JWTConfig struct { // TODO: Maybe future development
// JWKS URL for key discovery // type JWTConfig struct {
JWKSURL string `toml:"jwks_url"` // JWKSURL string `toml:"jwks_url"`
// SigningKey string `toml:"signing_key"`
// Static signing key (if not using JWKS) // Issuer string `toml:"issuer"`
SigningKey string `toml:"signing_key"` // Audience string `toml:"audience"`
// }
// Expected issuer
Issuer string `toml:"issuer"`
// Expected audience
Audience string `toml:"audience"`
}
func validateAuth(pipelineName string, auth *AuthConfig) error { func validateAuth(pipelineName string, auth *AuthConfig) error {
if auth == nil { if auth == nil {
return nil return nil
} }
validTypes := map[string]bool{"none": true, "basic": true, "bearer": true, "mtls": true} validTypes := map[string]bool{"none": true, "basic": true, "scram": true, "bearer": true, "mtls": true}
if !validTypes[auth.Type] { if !validTypes[auth.Type] {
return fmt.Errorf("pipeline '%s': invalid auth type: %s", pipelineName, auth.Type) return fmt.Errorf("pipeline '%s': invalid auth type: %s", pipelineName, auth.Type)
} }
@ -69,6 +69,10 @@ func validateAuth(pipelineName string, auth *AuthConfig) error {
return fmt.Errorf("pipeline '%s': basic auth type specified but config missing", pipelineName) return fmt.Errorf("pipeline '%s': basic auth type specified but config missing", pipelineName)
} }
if auth.Type == "scram" && auth.ScramAuth == nil {
return fmt.Errorf("pipeline '%s': scram auth type specified but config missing", pipelineName)
}
if auth.Type == "bearer" && auth.BearerAuth == nil { if auth.Type == "bearer" && auth.BearerAuth == nil {
return fmt.Errorf("pipeline '%s': bearer auth type specified but config missing", pipelineName) return fmt.Errorf("pipeline '%s': bearer auth type specified but config missing", pipelineName)
} }

View File

@ -180,10 +180,10 @@ func applyConsoleTargetOverrides(cfg *Config) error {
return fmt.Errorf("invalid LOGWISP_CONSOLE_TARGET value: %s", consoleTarget) return fmt.Errorf("invalid LOGWISP_CONSOLE_TARGET value: %s", consoleTarget)
} }
// Apply to all console sinks // Apply to console sinks
for i, pipeline := range cfg.Pipelines { for i, pipeline := range cfg.Pipelines {
for j, sink := range pipeline.Sinks { for j, sink := range pipeline.Sinks {
if sink.Type == "stdout" || sink.Type == "stderr" { if sink.Type == "console" {
if sink.Options == nil { if sink.Options == nil {
cfg.Pipelines[i].Sinks[j].Options = make(map[string]any) cfg.Pipelines[i].Sinks[j].Options = make(map[string]any)
} }
@ -193,10 +193,5 @@ func applyConsoleTargetOverrides(cfg *Config) error {
} }
} }
// Also update logging console target if applicable
if cfg.Logging.Console != nil && consoleTarget == "split" {
cfg.Logging.Console.Target = "split"
}
return nil return nil
} }

View File

@ -36,7 +36,7 @@ type PipelineConfig struct {
// Represents an input data source // Represents an input data source
type SourceConfig struct { type SourceConfig struct {
// Source type: "directory", "stdin", "tcp", "http" // Source type
Type string `toml:"type"` Type string `toml:"type"`
// Type-specific configuration options // Type-specific configuration options
@ -45,7 +45,7 @@ type SourceConfig struct {
// Represents an output destination // Represents an output destination
type SinkConfig struct { type SinkConfig struct {
// Sink type: "http", "tcp", "file", "stdout", "stderr" // Sink type
Type string `toml:"type"` Type string `toml:"type"`
// Type-specific configuration options // Type-specific configuration options
@ -404,7 +404,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts
} }
} }
case "stdout", "stderr": case "console":
// No specific validation needed for console sinks // No specific validation needed for console sinks
default: default:

View File

@ -0,0 +1,80 @@
// FILE: logwisp/src/internal/filter/chain_test.go
package filter
import (
"testing"
"logwisp/src/internal/config"
"logwisp/src/internal/core"
"github.com/stretchr/testify/assert"
)
func TestNewChain(t *testing.T) {
logger := newTestLogger()
t.Run("Success", func(t *testing.T) {
configs := []config.FilterConfig{
{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
{Type: config.FilterTypeExclude, Patterns: []string{"banana"}},
}
chain, err := NewChain(configs, logger)
assert.NoError(t, err)
assert.NotNil(t, chain)
assert.Len(t, chain.filters, 2)
})
t.Run("ErrorInvalidRegexInChain", func(t *testing.T) {
configs := []config.FilterConfig{
{Patterns: []string{"apple"}},
{Patterns: []string{"["}},
}
chain, err := NewChain(configs, logger)
assert.Error(t, err)
assert.Nil(t, chain)
assert.Contains(t, err.Error(), "filter[1]")
})
}
func TestChain_Apply(t *testing.T) {
logger := newTestLogger()
entry := core.LogEntry{Message: "an apple a day"}
t.Run("EmptyChain", func(t *testing.T) {
chain, err := NewChain([]config.FilterConfig{}, logger)
assert.NoError(t, err)
assert.True(t, chain.Apply(entry))
})
t.Run("AllFiltersPass", func(t *testing.T) {
configs := []config.FilterConfig{
{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
{Type: config.FilterTypeInclude, Patterns: []string{"day"}},
{Type: config.FilterTypeExclude, Patterns: []string{"banana"}},
}
chain, err := NewChain(configs, logger)
assert.NoError(t, err)
assert.True(t, chain.Apply(entry))
})
t.Run("OneFilterFails", func(t *testing.T) {
configs := []config.FilterConfig{
{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
{Type: config.FilterTypeExclude, Patterns: []string{"day"}}, // This one will fail
{Type: config.FilterTypeInclude, Patterns: []string{"a"}},
}
chain, err := NewChain(configs, logger)
assert.NoError(t, err)
assert.False(t, chain.Apply(entry))
})
t.Run("FirstFilterFails", func(t *testing.T) {
configs := []config.FilterConfig{
{Type: config.FilterTypeInclude, Patterns: []string{"banana"}}, // This one will fail
{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
}
chain, err := NewChain(configs, logger)
assert.NoError(t, err)
assert.False(t, chain.Apply(entry))
})
}

View File

@ -66,6 +66,9 @@ func (f *Filter) Apply(entry core.LogEntry) bool {
// No patterns means pass everything // No patterns means pass everything
if len(f.patterns) == 0 { if len(f.patterns) == 0 {
f.logger.Debug("msg", "No patterns configured, passing entry",
"component", "filter",
"type", f.config.Type)
return true return true
} }
@ -78,10 +81,32 @@ func (f *Filter) Apply(entry core.LogEntry) bool {
text = entry.Source + " " + text text = entry.Source + " " + text
} }
f.logger.Debug("msg", "Filter checking entry",
"component", "filter",
"type", f.config.Type,
"logic", f.config.Logic,
"entry_level", entry.Level,
"entry_source", entry.Source,
"entry_message", entry.Message[:min(100, len(entry.Message))], // First 100 chars
"text_to_match", text[:min(150, len(text))], // First 150 chars
"patterns", f.config.Patterns)
for i, pattern := range f.config.Patterns {
isMatch := f.patterns[i].MatchString(text)
f.logger.Debug("msg", "Pattern match result",
"component", "filter",
"pattern_index", i,
"pattern", pattern,
"matched", isMatch)
}
matched := f.matches(text) matched := f.matches(text)
if matched { if matched {
f.totalMatched.Add(1) f.totalMatched.Add(1)
} }
f.logger.Debug("msg", "Filter final match result",
"component", "filter",
"matched", matched)
// Determine if we should pass or drop // Determine if we should pass or drop
shouldPass := false shouldPass := false
@ -92,6 +117,12 @@ func (f *Filter) Apply(entry core.LogEntry) bool {
shouldPass = !matched shouldPass = !matched
} }
f.logger.Debug("msg", "Filter decision",
"component", "filter",
"type", f.config.Type,
"matched", matched,
"should_pass", shouldPass)
if !shouldPass { if !shouldPass {
f.totalDropped.Add(1) f.totalDropped.Add(1)
} }

View File

@ -0,0 +1,159 @@
// FILE: logwisp/src/internal/filter/filter_test.go
package filter
import (
"logwisp/src/internal/config"
"logwisp/src/internal/core"
"testing"
"github.com/lixenwraith/log"
"github.com/stretchr/testify/assert"
)
func newTestLogger() *log.Logger {
return log.NewLogger()
}
func TestNewFilter(t *testing.T) {
logger := newTestLogger()
t.Run("SuccessWithDefaults", func(t *testing.T) {
cfg := config.FilterConfig{Patterns: []string{"test"}}
f, err := NewFilter(cfg, logger)
assert.NoError(t, err)
assert.NotNil(t, f)
assert.Equal(t, config.FilterTypeInclude, f.config.Type)
assert.Equal(t, config.FilterLogicOr, f.config.Logic)
})
t.Run("SuccessWithCustomConfig", func(t *testing.T) {
cfg := config.FilterConfig{
Type: config.FilterTypeExclude,
Logic: config.FilterLogicAnd,
Patterns: []string{"test", "pattern"},
}
f, err := NewFilter(cfg, logger)
assert.NoError(t, err)
assert.NotNil(t, f)
assert.Equal(t, config.FilterTypeExclude, f.config.Type)
assert.Equal(t, config.FilterLogicAnd, f.config.Logic)
assert.Len(t, f.patterns, 2)
})
t.Run("ErrorInvalidRegex", func(t *testing.T) {
cfg := config.FilterConfig{Patterns: []string{"["}}
f, err := NewFilter(cfg, logger)
assert.Error(t, err)
assert.Nil(t, f)
assert.Contains(t, err.Error(), "invalid regex pattern")
})
}
func TestFilter_Apply(t *testing.T) {
logger := newTestLogger()
testCases := []struct {
name string
cfg config.FilterConfig
entry core.LogEntry
expected bool
}{
// Include OR logic
{
name: "IncludeOR_MatchOne",
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicOr, Patterns: []string{"apple", "banana"}},
entry: core.LogEntry{Message: "this is an apple"},
expected: true,
},
{
name: "IncludeOR_NoMatch",
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicOr, Patterns: []string{"apple", "banana"}},
entry: core.LogEntry{Message: "this is a pear"},
expected: false,
},
// Include AND logic
{
name: "IncludeAND_MatchAll",
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicAnd, Patterns: []string{"apple", "doctor"}},
entry: core.LogEntry{Message: "an apple keeps the doctor away"},
expected: true,
},
{
name: "IncludeAND_MatchOne",
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicAnd, Patterns: []string{"apple", "doctor"}},
entry: core.LogEntry{Message: "this is an apple"},
expected: false,
},
// Exclude OR logic
{
name: "ExcludeOR_MatchOne",
cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicOr, Patterns: []string{"error", "fatal"}},
entry: core.LogEntry{Message: "this is an error"},
expected: false,
},
{
name: "ExcludeOR_NoMatch",
cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicOr, Patterns: []string{"error", "fatal"}},
entry: core.LogEntry{Message: "this is a warning"},
expected: true,
},
// Exclude AND logic
{
name: "ExcludeAND_MatchAll",
cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicAnd, Patterns: []string{"critical", "database"}},
entry: core.LogEntry{Message: "critical error in database"},
expected: false,
},
{
name: "ExcludeAND_MatchOne",
cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicAnd, Patterns: []string{"critical", "database"}},
entry: core.LogEntry{Message: "critical error in app"},
expected: true,
},
// Edge Cases
{
name: "NoPatterns",
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{}},
entry: core.LogEntry{Message: "any message"},
expected: true,
},
{
name: "EmptyEntry_NoPatterns",
cfg: config.FilterConfig{Patterns: []string{}},
entry: core.LogEntry{},
expected: true,
},
{
name: "EmptyEntry_DoesNotMatchSpace",
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{" "}},
entry: core.LogEntry{Level: "", Source: "", Message: ""},
expected: false, // CORRECTED: An empty entry results in an empty string, which doesn't match a space.
},
{
name: "MatchOnLevel",
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"ERROR"}},
entry: core.LogEntry{Level: "ERROR", Message: "A message"},
expected: true,
},
{
name: "MatchOnSource",
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"database"}},
entry: core.LogEntry{Source: "database", Message: "A message"},
expected: true,
},
{
name: "MatchOnCombinedFields",
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"^app ERROR"}},
entry: core.LogEntry{Source: "app", Level: "ERROR", Message: "A message"},
expected: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
f, err := NewFilter(tc.cfg, logger)
assert.NoError(t, err)
assert.Equal(t, tc.expected, f.Apply(tc.entry))
})
}
}

View File

@ -28,7 +28,7 @@ func NewFormatter(name string, options map[string]any, logger *log.Logger) (Form
switch name { switch name {
case "json": case "json":
return NewJSONFormatter(options, logger) return NewJSONFormatter(options, logger)
case "text": case "txt":
return NewTextFormatter(options, logger) return NewTextFormatter(options, logger)
case "raw": case "raw":
return NewRawFormatter(options, logger) return NewRawFormatter(options, logger)

View File

@ -0,0 +1,65 @@
// FILE: logwisp/src/internal/format/format_test.go
package format
import (
"testing"
"github.com/lixenwraith/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestLogger() *log.Logger {
return log.NewLogger()
}
func TestNewFormatter(t *testing.T) {
logger := newTestLogger()
testCases := []struct {
name string
formatName string
expected string
expectError bool
}{
{
name: "JSONFormatter",
formatName: "json",
expected: "json",
},
{
name: "TextFormatter",
formatName: "txt",
expected: "txt",
},
{
name: "RawFormatter",
formatName: "raw",
expected: "raw",
},
{
name: "DefaultToRaw",
formatName: "",
expected: "raw",
},
{
name: "UnknownFormatter",
formatName: "xml",
expectError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
formatter, err := NewFormatter(tc.formatName, nil, logger)
if tc.expectError {
assert.Error(t, err)
assert.Nil(t, formatter)
} else {
require.NoError(t, err)
require.NotNil(t, formatter)
assert.Equal(t, tc.expected, formatter.Name())
}
})
}
}

View File

@ -0,0 +1,129 @@
// FILE: logwisp/src/internal/format/json_test.go
package format
import (
"encoding/json"
"strings"
"testing"
"time"
"logwisp/src/internal/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestJSONFormatter_Format(t *testing.T) {
logger := newTestLogger()
testTime := time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC)
entry := core.LogEntry{
Time: testTime,
Source: "test-app",
Level: "INFO",
Message: "this is a test",
}
t.Run("BasicFormatting", func(t *testing.T) {
formatter, err := NewJSONFormatter(nil, logger)
require.NoError(t, err)
output, err := formatter.Format(entry)
require.NoError(t, err)
var result map[string]interface{}
err = json.Unmarshal(output, &result)
require.NoError(t, err, "Output should be valid JSON")
assert.Equal(t, testTime.Format(time.RFC3339Nano), result["timestamp"])
assert.Equal(t, "INFO", result["level"])
assert.Equal(t, "test-app", result["source"])
assert.Equal(t, "this is a test", result["message"])
assert.True(t, strings.HasSuffix(string(output), "\n"), "Output should end with a newline")
})
t.Run("PrettyFormatting", func(t *testing.T) {
formatter, err := NewJSONFormatter(map[string]any{"pretty": true}, logger)
require.NoError(t, err)
output, err := formatter.Format(entry)
require.NoError(t, err)
assert.Contains(t, string(output), ` "level": "INFO"`)
assert.True(t, strings.HasSuffix(string(output), "\n"))
})
t.Run("MessageIsJSON", func(t *testing.T) {
jsonMessageEntry := entry
jsonMessageEntry.Message = `{"user":"test","request_id":"abc-123"}`
formatter, err := NewJSONFormatter(nil, logger)
require.NoError(t, err)
output, err := formatter.Format(jsonMessageEntry)
require.NoError(t, err)
var result map[string]interface{}
err = json.Unmarshal(output, &result)
require.NoError(t, err)
assert.Equal(t, "test", result["user"])
assert.Equal(t, "abc-123", result["request_id"])
_, messageExists := result["message"]
assert.False(t, messageExists, "message field should not exist when message is merged JSON")
})
t.Run("MessageIsJSONWithConflicts", func(t *testing.T) {
jsonMessageEntry := entry
jsonMessageEntry.Level = "INFO" // top-level
jsonMessageEntry.Message = `{"level":"DEBUG","msg":"hello"}`
formatter, err := NewJSONFormatter(nil, logger)
require.NoError(t, err)
output, err := formatter.Format(jsonMessageEntry)
require.NoError(t, err)
var result map[string]interface{}
err = json.Unmarshal(output, &result)
require.NoError(t, err)
assert.Equal(t, "INFO", result["level"], "Top-level LogEntry field should take precedence")
})
t.Run("CustomFieldNames", func(t *testing.T) {
options := map[string]any{"timestamp_field": "@timestamp"}
formatter, err := NewJSONFormatter(options, logger)
require.NoError(t, err)
output, err := formatter.Format(entry)
require.NoError(t, err)
var result map[string]interface{}
err = json.Unmarshal(output, &result)
require.NoError(t, err)
_, defaultExists := result["timestamp"]
assert.False(t, defaultExists)
assert.Equal(t, testTime.Format(time.RFC3339Nano), result["@timestamp"])
})
}
func TestJSONFormatter_FormatBatch(t *testing.T) {
logger := newTestLogger()
formatter, err := NewJSONFormatter(nil, logger)
require.NoError(t, err)
entries := []core.LogEntry{
{Time: time.Now(), Level: "INFO", Message: "First message"},
{Time: time.Now(), Level: "WARN", Message: "Second message"},
}
output, err := formatter.FormatBatch(entries)
require.NoError(t, err)
var result []map[string]interface{}
err = json.Unmarshal(output, &result)
require.NoError(t, err, "Batch output should be a valid JSON array")
require.Len(t, result, 2)
assert.Equal(t, "First message", result[0]["message"])
assert.Equal(t, "WARN", result[1]["level"])
}

View File

@ -0,0 +1,29 @@
// FILE: logwisp/src/internal/format/raw_test.go
package format
import (
"testing"
"time"
"logwisp/src/internal/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRawFormatter_Format(t *testing.T) {
logger := newTestLogger()
formatter, err := NewRawFormatter(nil, logger)
require.NoError(t, err)
entry := core.LogEntry{
Time: time.Now(),
Message: "This is a raw log line.",
}
output, err := formatter.Format(entry)
require.NoError(t, err)
expected := "This is a raw log line.\n"
assert.Equal(t, expected, string(output))
}

View File

@ -104,5 +104,5 @@ func (f *TextFormatter) Format(entry core.LogEntry) ([]byte, error) {
// Returns the formatter name // Returns the formatter name
func (f *TextFormatter) Name() string { func (f *TextFormatter) Name() string {
return "text" return "txt"
} }

View File

@ -0,0 +1,81 @@
// FILE: logwisp/src/internal/format/text_test.go
package format
import (
"fmt"
"strings"
"testing"
"time"
"logwisp/src/internal/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewTextFormatter(t *testing.T) {
logger := newTestLogger()
t.Run("InvalidTemplate", func(t *testing.T) {
options := map[string]any{"template": "{{ .Timestamp | InvalidFunc }}"}
_, err := NewTextFormatter(options, logger)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid template")
})
}
// TestTextFormatter_Format exercises the text formatter against a fixed
// entry: the default template, a custom template, a custom timestamp
// layout, and the empty-level fallback to INFO.
func TestTextFormatter_Format(t *testing.T) {
	logger := newTestLogger()
	// Fixed timestamp so expected strings are deterministic.
	testTime := time.Date(2023, 10, 27, 10, 30, 0, 0, time.UTC)
	entry := core.LogEntry{
		Time:    testTime,
		Source:  "api",
		Level:   "WARN",
		Message: "rate limit exceeded",
	}
	// Default template renders "[RFC3339 time] [LEVEL] source - message".
	t.Run("DefaultTemplate", func(t *testing.T) {
		formatter, err := NewTextFormatter(nil, logger)
		require.NoError(t, err)
		output, err := formatter.Format(entry)
		require.NoError(t, err)
		expected := fmt.Sprintf("[%s] [WARN] api - rate limit exceeded\n", testTime.Format(time.RFC3339))
		assert.Equal(t, expected, string(output))
	})
	// A user-supplied template replaces the default layout entirely;
	// the formatter still appends the trailing newline.
	t.Run("CustomTemplate", func(t *testing.T) {
		options := map[string]any{"template": "{{.Level}}:{{.Source}}:{{.Message}}"}
		formatter, err := NewTextFormatter(options, logger)
		require.NoError(t, err)
		output, err := formatter.Format(entry)
		require.NoError(t, err)
		expected := "WARN:api:rate limit exceeded\n"
		assert.Equal(t, expected, string(output))
	})
	// timestamp_format uses Go reference-time layout syntax.
	t.Run("CustomTimestampFormat", func(t *testing.T) {
		options := map[string]any{"timestamp_format": "2006-01-02"}
		formatter, err := NewTextFormatter(options, logger)
		require.NoError(t, err)
		output, err := formatter.Format(entry)
		require.NoError(t, err)
		assert.True(t, strings.HasPrefix(string(output), "[2023-10-27]"))
	})
	// An entry with no level is rendered as INFO.
	t.Run("EmptyLevelDefaultsToInfo", func(t *testing.T) {
		emptyLevelEntry := entry
		emptyLevelEntry.Level = ""
		formatter, err := NewTextFormatter(nil, logger)
		require.NoError(t, err)
		output, err := formatter.Format(emptyLevelEntry)
		require.NoError(t, err)
		assert.Contains(t, string(output), "[INFO]")
	})
}

View File

@ -0,0 +1,106 @@
// FILE: src/internal/scram/client.go
package scram
import (
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"fmt"
"golang.org/x/crypto/argon2"
)
// Client handles SCRAM client-side authentication (SCRAM-style mutual
// auth using Argon2id for the salted-password derivation).
type Client struct {
	Username string // account name sent in ClientFirst
	Password string // plaintext; used only locally for key derivation

	// Handshake state
	clientNonce string       // base64 nonce minted by StartAuthentication
	serverFirst *ServerFirst // server challenge captured by ProcessServerFirst
	authMessage string       // canonical auth message, reused to verify the server signature
	serverKey   []byte       // HMAC key for verifying ServerFinal (mutual auth)
}
// NewClient returns a SCRAM client primed with the user's credentials.
// Handshake state is populated lazily by StartAuthentication and
// ProcessServerFirst.
func NewClient(username, password string) *Client {
	c := &Client{}
	c.Username = username
	c.Password = password
	return c
}
// StartAuthentication begins the handshake: it mints a fresh 32-byte
// client nonce and wraps it, together with the username, in a
// ClientFirst message.
func (c *Client) StartAuthentication() (*ClientFirst, error) {
	raw := make([]byte, 32)
	_, err := rand.Read(raw)
	if err != nil {
		return nil, fmt.Errorf("failed to generate nonce: %w", err)
	}
	c.clientNonce = base64.StdEncoding.EncodeToString(raw)

	msg := &ClientFirst{Username: c.Username, ClientNonce: c.clientNonce}
	return msg, nil
}
// ProcessServerFirst handles the server challenge: it derives the
// Argon2id-salted password, computes the client proof, and stores the
// server key for the later mutual-authentication check.
//
// BUG FIX: the combined nonce was previously accepted unchecked. Per
// the SCRAM exchange (RFC 5802 §3) the client MUST verify that the
// server's nonce begins with the nonce this client sent; anything else
// indicates a tampered or replayed challenge.
// NOTE(review): msg.ArgonMemory/ArgonTime come from the server and are
// fed to argon2.IDKey unchecked — a hostile server could request an
// enormous derivation; consider an upper bound.
func (c *Client) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) {
	if c.clientNonce == "" {
		return nil, fmt.Errorf("invalid handshake state: StartAuthentication not called")
	}
	// ☢ SECURITY: the server must echo our nonce as a prefix
	if len(msg.FullNonce) <= len(c.clientNonce) ||
		msg.FullNonce[:len(c.clientNonce)] != c.clientNonce {
		return nil, fmt.Errorf("server nonce does not extend client nonce")
	}
	c.serverFirst = msg

	// Decode salt
	salt, err := base64.StdEncoding.DecodeString(msg.Salt)
	if err != nil {
		return nil, fmt.Errorf("invalid salt encoding: %w", err)
	}

	// Derive keys using Argon2id
	saltedPassword := argon2.IDKey([]byte(c.Password), salt,
		msg.ArgonTime, msg.ArgonMemory, msg.ArgonThreads, 32)
	clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
	serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
	storedKey := sha256.Sum256(clientKey)

	// Build auth message; must match the server's reconstruction exactly
	clientFirstBare := fmt.Sprintf("u=%s,n=%s", c.Username, c.clientNonce)
	clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce)
	c.authMessage = clientFirstBare + "," + msg.Marshal() + "," + clientFinalBare

	// Compute client proof: ClientKey XOR HMAC(StoredKey, AuthMessage)
	clientSignature := computeHMAC(storedKey[:], []byte(c.authMessage))
	clientProof := xorBytes(clientKey, clientSignature)

	// Store server key for verification of ServerFinal
	c.serverKey = serverKey

	return &ClientFinal{
		FullNonce:   msg.FullNonce,
		ClientProof: base64.StdEncoding.EncodeToString(clientProof),
	}, nil
}
// VerifyServerFinal checks the server's signature over the stored auth
// message, completing mutual authentication. The comparison is
// constant-time.
func (c *Client) VerifyServerFinal(msg *ServerFinal) error {
	if c.serverKey == nil || c.authMessage == "" {
		return fmt.Errorf("invalid handshake state")
	}

	// Expected signature: HMAC(ServerKey, AuthMessage)
	want := computeHMAC(c.serverKey, []byte(c.authMessage))

	got, err := base64.StdEncoding.DecodeString(msg.ServerSignature)
	if err != nil {
		return fmt.Errorf("invalid signature encoding: %w", err)
	}

	// ☢ SECURITY: Constant-time comparison
	if subtle.ConstantTimeCompare(want, got) != 1 {
		return fmt.Errorf("server authentication failed")
	}
	return nil
}

View File

@ -0,0 +1,99 @@
// FILE: src/internal/scram/credential.go
package scram
import (
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"fmt"
"strings"
"golang.org/x/crypto/argon2"
)
// Credential stores the server-side SCRAM verifier for one user.
// Only derived keys are kept; the plaintext password is never stored.
type Credential struct {
	Username     string
	Salt         []byte // random salt, 16+ bytes (enforced by DeriveCredential)
	ArgonTime    uint32 // Argon2id time cost (passes), e.g., 3
	ArgonMemory  uint32 // Argon2id memory cost in KiB, e.g., 64*1024
	ArgonThreads uint8  // Argon2id parallelism, e.g., 4
	StoredKey    []byte // SHA256(ClientKey); used to verify client proofs
	ServerKey    []byte // HMAC key used to sign ServerFinal (mutual auth)
	PHCHash      string // NOTE(review): not populated by DeriveCredential or MigrateFromPHC in this file — confirm intended use
}
// DeriveCredential builds a server-side SCRAM credential from a
// plaintext password using Argon2id with the supplied cost parameters.
// The salt must be at least 16 bytes.
func DeriveCredential(username, password string, salt []byte, time, memory uint32, threads uint8) (*Credential, error) {
	if len(salt) < 16 {
		return nil, fmt.Errorf("salt must be at least 16 bytes")
	}

	// Salted password via Argon2id, then the standard SCRAM key family.
	salted := argon2.IDKey([]byte(password), salt, time, memory, threads, 32)
	ck := computeHMAC(salted, []byte("Client Key"))
	sk := computeHMAC(salted, []byte("Server Key"))
	stored := sha256.Sum256(ck)

	cred := &Credential{
		Username:     username,
		Salt:         salt,
		ArgonTime:    time,
		ArgonMemory:  memory,
		ArgonThreads: threads,
		StoredKey:    stored[:],
		ServerKey:    sk,
	}
	return cred, nil
}
// MigrateFromPHC converts an existing Argon2id PHC hash into a SCRAM
// credential, verifying the supplied password against the hash first.
//
// Expected PHC layout: $argon2id$v=19$m=65536,t=3,p=4$<salt>$<hash>
// with salt and hash in raw (unpadded) standard base64.
//
// BUG FIX: the fmt.Sscanf result was previously ignored; a malformed
// parameter field left all costs at zero, and argon2.IDKey panics when
// time is zero. Malformed parameters now return an error instead.
func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
	parts := strings.Split(phcHash, "$")
	if len(parts) != 6 || parts[1] != "argon2id" {
		return nil, fmt.Errorf("invalid PHC format")
	}

	var memory, time uint32
	var threads uint8
	if n, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads); err != nil || n != 3 {
		return nil, fmt.Errorf("invalid PHC parameters: %q", parts[3])
	}
	if time == 0 || memory == 0 || threads == 0 {
		return nil, fmt.Errorf("invalid PHC parameters: costs must be positive")
	}

	salt, err := base64.RawStdEncoding.DecodeString(parts[4])
	if err != nil {
		return nil, fmt.Errorf("invalid salt encoding: %w", err)
	}
	expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
	if err != nil {
		return nil, fmt.Errorf("invalid hash encoding: %w", err)
	}

	// Verify password matches (constant-time compare on the digests)
	computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
	if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
		return nil, fmt.Errorf("password verification failed")
	}

	// Re-derive the SCRAM key family with the same parameters
	return DeriveCredential(username, password, salt, time, memory, threads)
}
func computeHMAC(key, message []byte) []byte {
mac := hmac.New(sha256.New, key)
mac.Write(message)
return mac.Sum(nil)
}
// xorBytes returns the element-wise XOR of two equal-length slices.
// Callers must guarantee equal lengths; a mismatch is a programming
// error and panics.
func xorBytes(a, b []byte) []byte {
	if len(a) != len(b) {
		panic("xor length mismatch")
	}
	out := make([]byte, len(a))
	for i, av := range a {
		out[i] = av ^ b[i]
	}
	return out
}

View File

@ -0,0 +1,117 @@
// FILE: src/internal/scram/integration.go
package scram
import (
"crypto/rand"
"fmt"
"sync"
"time"
"golang.org/x/time/rate"
)
// ScramManager provides high-level SCRAM operations with rate limiting
type ScramManager struct {
	server   *Server
	sessions map[string]*SessionInfo  // keyed by session ID; NOTE(review): never written in the visible code — confirm who populates it
	limiter  map[string]*rate.Limiter // per-remote-address auth limiters
	mu       sync.RWMutex             // guards sessions and limiter
}

// SessionInfo tracks authenticated sessions
type SessionInfo struct {
	Username     string
	RemoteAddr   string
	SessionID    string
	CreatedAt    time.Time
	LastActivity time.Time // sessions idle > 30m are reaped by cleanupLoop
	Method       string    // "scram-sha-256"
}
// NewScramManager creates SCRAM manager with an empty server, session
// table, and rate-limiter table, and starts the background session
// reaper.
// NOTE(review): cleanupLoop runs for the life of the process — there is
// no shutdown signal, so each manager leaks one goroutine; confirm
// managers are process-singletons.
func NewScramManager() *ScramManager {
	m := &ScramManager{
		server:   NewServer(),
		sessions: make(map[string]*SessionInfo),
		limiter:  make(map[string]*rate.Limiter),
	}
	// Start cleanup goroutine
	go m.cleanupLoop()
	return m
}
// RegisterUser derives and stores a credential for a new user, using
// the server's default Argon2id cost parameters and a fresh 16-byte
// random salt.
func (sm *ScramManager) RegisterUser(username, password string) error {
	salt := make([]byte, 16)
	if _, err := rand.Read(salt); err != nil {
		return fmt.Errorf("salt generation failed: %w", err)
	}

	srv := sm.server
	cred, err := DeriveCredential(username, password, salt,
		srv.DefaultTime, srv.DefaultMemory, srv.DefaultThreads)
	if err != nil {
		return err
	}

	srv.AddCredential(cred)
	return nil
}
// GetRateLimiter returns the authentication rate limiter for a remote
// address, creating one on first use (10 attempts/minute, burst of 3).
//
// BUG FIX: eviction previously ran after inserting the new limiter and,
// because map iteration order is random, could delete that very entry —
// silently resetting the address's limit on the next call. Eviction now
// spares the address being served.
func (sm *ScramManager) GetRateLimiter(remoteAddr string) *rate.Limiter {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	if limiter, exists := sm.limiter[remoteAddr]; exists {
		return limiter
	}

	// 10 attempts per minute, burst of 3
	limiter := rate.NewLimiter(rate.Every(6*time.Second), 3)
	sm.limiter[remoteAddr] = limiter

	// Bound memory by dropping arbitrary entries once the table is big.
	// NOTE(review): victims are random, so active addresses may lose
	// their counters; an LRU would be strictly better here.
	if len(sm.limiter) > 10000 {
		for addr := range sm.limiter {
			if addr == remoteAddr {
				continue // never evict the limiter we are handing out
			}
			delete(sm.limiter, addr)
			if len(sm.limiter) < 8000 {
				break
			}
		}
	}
	return limiter
}
// cleanupLoop reaps sessions idle for more than 30 minutes, checking
// every 5 minutes. It never returns: the manager has no shutdown
// signal, so this goroutine lives for the remainder of the process.
func (sm *ScramManager) cleanupLoop() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		sm.mu.Lock()
		cutoff := time.Now().Add(-30 * time.Minute)
		for sid, session := range sm.sessions {
			if session.LastActivity.Before(cutoff) {
				delete(sm.sessions, sid)
			}
		}
		sm.mu.Unlock()
	}
}
// HandleClientFirst wraps the underlying server's HandleClientFirst.
func (sm *ScramManager) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
	return sm.server.HandleClientFirst(msg)
}

// HandleClientFinal wraps the underlying server's HandleClientFinal.
func (sm *ScramManager) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
	return sm.server.HandleClientFinal(msg)
}

// AddCredential wraps the underlying server's AddCredential.
func (sm *ScramManager) AddCredential(cred *Credential) {
	sm.server.AddCredential(cred)
}

View File

@ -0,0 +1,101 @@
// FILE: src/internal/scram/message.go
package scram
import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
)
// ClientFirst initiates authentication
type ClientFirst struct {
	Username    string `json:"u"`
	ClientNonce string `json:"n"` // base64, minted by the client
}

// ServerFirst contains server challenge
type ServerFirst struct {
	FullNonce    string `json:"r"` // client_nonce + server_nonce
	Salt         string `json:"s"` // base64
	ArgonTime    uint32 `json:"t"` // Argon2id time cost
	ArgonMemory  uint32 `json:"m"` // Argon2id memory cost (KiB)
	ArgonThreads uint8  `json:"p"` // Argon2id parallelism
}

// ClientFinal contains client proof
type ClientFinal struct {
	FullNonce   string `json:"r"` // must match the challenge's nonce
	ClientProof string `json:"p"` // base64
}

// ServerFinal contains server signature for mutual auth
type ServerFinal struct {
	ServerSignature string `json:"v"` // base64
	SessionID       string `json:"sid,omitempty"`
}
// Marshal/Unmarshal helpers for TCP protocol (line-based)

// Marshal renders the message as "u=<user>,n=<nonce>".
// NOTE(review): values are not escaped — a username containing ',' or
// '=' corrupts the wire format and the signed auth message; confirm
// usernames are validated upstream.
func (cf *ClientFirst) Marshal() string {
	return fmt.Sprintf("u=%s,n=%s", cf.Username, cf.ClientNonce)
}
// ParseClientFirst decodes a "u=...,n=..." line into a ClientFirst.
// Unknown or malformed fields are skipped; both u and n are required.
func ParseClientFirst(data string) (*ClientFirst, error) {
	var msg ClientFirst
	for _, field := range strings.Split(data, ",") {
		kv := strings.SplitN(field, "=", 2)
		if len(kv) != 2 {
			continue
		}
		switch kv[0] {
		case "u":
			msg.Username = kv[1]
		case "n":
			msg.ClientNonce = kv[1]
		}
	}
	if msg.Username == "" || msg.ClientNonce == "" {
		return nil, fmt.Errorf("missing required fields")
	}
	return &msg, nil
}
// Marshal renders the challenge as "r=...,s=...,t=...,m=...,p=...".
// Field order matters: this exact string is embedded in the auth
// message both sides sign.
func (sf *ServerFirst) Marshal() string {
	return "r=" + sf.FullNonce +
		",s=" + sf.Salt +
		fmt.Sprintf(",t=%d,m=%d,p=%d", sf.ArgonTime, sf.ArgonMemory, sf.ArgonThreads)
}
// ParseServerFirst decodes a "r=...,s=...,t=...,m=...,p=..." challenge.
//
// BUG FIX: required fields and numeric parse results were previously
// unchecked, so arbitrary garbage decoded into a zero-valued challenge
// (and zero Argon2 costs later panic inside argon2.IDKey). Malformed
// input now returns an error, mirroring ParseClientFirst.
func ParseServerFirst(data string) (*ServerFirst, error) {
	parts := strings.Split(data, ",")
	msg := &ServerFirst{}
	for _, part := range parts {
		kv := strings.SplitN(part, "=", 2)
		if len(kv) != 2 {
			continue
		}
		switch kv[0] {
		case "r":
			msg.FullNonce = kv[1]
		case "s":
			msg.Salt = kv[1]
		case "t":
			if _, err := fmt.Sscanf(kv[1], "%d", &msg.ArgonTime); err != nil {
				return nil, fmt.Errorf("invalid t parameter: %w", err)
			}
		case "m":
			if _, err := fmt.Sscanf(kv[1], "%d", &msg.ArgonMemory); err != nil {
				return nil, fmt.Errorf("invalid m parameter: %w", err)
			}
		case "p":
			if _, err := fmt.Sscanf(kv[1], "%d", &msg.ArgonThreads); err != nil {
				return nil, fmt.Errorf("invalid p parameter: %w", err)
			}
		}
	}
	if msg.FullNonce == "" || msg.Salt == "" ||
		msg.ArgonTime == 0 || msg.ArgonMemory == 0 || msg.ArgonThreads == 0 {
		return nil, fmt.Errorf("missing required fields")
	}
	return msg, nil
}
// JSON variants for HTTP

// MarshalJSON encodes ClientFirst using its compact field tags.
func (cf *ClientFirst) MarshalJSON() ([]byte, error) {
	return json.Marshal(*cf)
}

// MarshalJSON encodes ServerFirst for the HTTP transport.
//
// BUG FIX: this previously re-encoded Salt with base64 and wrote the
// result back into the receiver. Salt is already base64 when the
// server builds the challenge (see Server.HandleClientFirst), so the
// wire value was double-encoded — and each additional marshal of the
// same value re-encoded it again. The salt is now validated instead of
// mutated.
func (sf *ServerFirst) MarshalJSON() ([]byte, error) {
	if _, err := base64.StdEncoding.DecodeString(sf.Salt); err != nil {
		return nil, fmt.Errorf("salt is not valid base64: %w", err)
	}
	return json.Marshal(*sf)
}

View File

@ -0,0 +1,228 @@
// FILE: src/internal/scram/scram_test.go
package scram
import (
"crypto/rand"
"encoding/base64"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/argon2"
)
// TestCredentialDerivation checks that DeriveCredential echoes back the
// Argon2id cost parameters and yields 32-byte StoredKey/ServerKey.
func TestCredentialDerivation(t *testing.T) {
	salt := make([]byte, 16)
	_, err := rand.Read(salt)
	require.NoError(t, err)

	const (
		wantTime    = uint32(3)
		wantMemory  = uint32(64 * 1024)
		wantThreads = uint8(4)
	)
	cred, err := DeriveCredential("testuser", "testpass123", salt, wantTime, wantMemory, wantThreads)
	require.NoError(t, err)

	assert.Equal(t, "testuser", cred.Username)
	assert.Equal(t, wantTime, cred.ArgonTime)
	assert.Equal(t, wantMemory, cred.ArgonMemory)
	assert.Equal(t, wantThreads, cred.ArgonThreads)
	assert.Len(t, cred.StoredKey, 32)
	assert.Len(t, cred.ServerKey, 32)
}
// TestFullHandshake walks the complete four-message exchange end to end
// and asserts the invariants at each step: the server nonce extends the
// client nonce, the advertised salt matches the credential, and both
// directions of the mutual authentication succeed.
func TestFullHandshake(t *testing.T) {
	// Setup server with credential
	server := NewServer()
	salt := make([]byte, 16)
	_, err := rand.Read(salt)
	require.NoError(t, err)
	cred, err := DeriveCredential("alice", "secret123", salt, 3, 64*1024, 4)
	require.NoError(t, err)
	server.AddCredential(cred)
	// Setup client
	client := NewClient("alice", "secret123")
	// Step 1: Client starts
	clientFirst, err := client.StartAuthentication()
	require.NoError(t, err)
	assert.Equal(t, "alice", clientFirst.Username)
	assert.NotEmpty(t, clientFirst.ClientNonce)
	// Step 2: Server responds
	serverFirst, err := server.HandleClientFirst(clientFirst)
	require.NoError(t, err)
	assert.Contains(t, serverFirst.FullNonce, clientFirst.ClientNonce)
	assert.Equal(t, base64.StdEncoding.EncodeToString(salt), serverFirst.Salt)
	// Step 3: Client proves
	clientFinal, err := client.ProcessServerFirst(serverFirst)
	require.NoError(t, err)
	assert.Equal(t, serverFirst.FullNonce, clientFinal.FullNonce)
	assert.NotEmpty(t, clientFinal.ClientProof)
	// Step 4: Server verifies and signs
	serverFinal, err := server.HandleClientFinal(clientFinal)
	require.NoError(t, err)
	assert.NotEmpty(t, serverFinal.ServerSignature)
	assert.NotEmpty(t, serverFinal.SessionID)
	// Step 5: Client verifies server
	err = client.VerifyServerFinal(serverFinal)
	assert.NoError(t, err)
}
// TestInvalidPassword ensures a client holding the wrong password is
// rejected by the server at the proof-verification step.
//
// BUG FIX: the original called ProcessServerFirst twice in a row,
// discarding the first result. The redundant call is removed and the
// intermediate handshake errors are now checked.
func TestInvalidPassword(t *testing.T) {
	server := NewServer()
	salt := make([]byte, 16)
	_, err := rand.Read(salt)
	require.NoError(t, err)
	cred, err := DeriveCredential("bob", "correct", salt, 3, 64*1024, 4)
	require.NoError(t, err)
	server.AddCredential(cred)

	// Client with wrong password
	client := NewClient("bob", "wrong")
	clientFirst, err := client.StartAuthentication()
	require.NoError(t, err)
	serverFirst, err := server.HandleClientFirst(clientFirst)
	require.NoError(t, err)
	clientFinal, err := client.ProcessServerFirst(serverFirst)
	require.NoError(t, err)

	// Server should reject
	_, err = server.HandleClientFinal(clientFinal)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "authentication failed")
}
// TestHandshakeTimeout checks that a client proof arriving more than
// 60 seconds after the challenge is rejected. The test reaches into the
// server's internals (handshakes map under mu) to back-date the
// handshake instead of sleeping.
func TestHandshakeTimeout(t *testing.T) {
	server := NewServer()
	salt := make([]byte, 16)
	rand.Read(salt)
	cred, _ := DeriveCredential("charlie", "pass", salt, 3, 64*1024, 4)
	server.AddCredential(cred)
	client := NewClient("charlie", "pass")
	clientFirst, _ := client.StartAuthentication()
	serverFirst, _ := server.HandleClientFirst(clientFirst)
	// Manipulate handshake timestamp
	server.mu.Lock()
	if state, exists := server.handshakes[serverFirst.FullNonce]; exists {
		state.CreatedAt = time.Now().Add(-61 * time.Second)
	}
	server.mu.Unlock()
	clientFinal, _ := client.ProcessServerFirst(serverFirst)
	_, err := server.HandleClientFinal(clientFinal)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "timeout")
}
// TestMessageParsing exercises ClientFirst decoding and the
// marshal/parse round-trip.
func TestMessageParsing(t *testing.T) {
	cf, err := ParseClientFirst("u=alice,n=abc123")
	require.NoError(t, err)
	assert.Equal(t, "alice", cf.Username)
	assert.Equal(t, "abc123", cf.ClientNonce)

	// Marshal then re-parse; the round-trip must be lossless.
	roundTripped, err := ParseClientFirst(cf.Marshal())
	require.NoError(t, err)
	assert.Equal(t, cf.Username, roundTripped.Username)
	assert.Equal(t, cf.ClientNonce, roundTripped.ClientNonce)
}
// TestPHCMigration verifies that a PHC-formatted Argon2id hash can be
// converted to a SCRAM credential when the password matches.
//
// BUG FIX: the original first built a phcHash with a dummy digest and
// immediately overwrote it; the dead assignment is removed.
func TestPHCMigration(t *testing.T) {
	password := "testpass"
	salt := []byte("saltsaltsaltsalt")
	saltB64 := base64.RawStdEncoding.EncodeToString(salt)

	// Generate the digest the existing auth system would have stored.
	hash := argon2.IDKey([]byte(password), salt, 3, 65536, 4, 32)
	hashB64 := base64.RawStdEncoding.EncodeToString(hash)
	phcHash := "$argon2id$v=19$m=65536,t=3,p=4$" + saltB64 + "$" + hashB64

	cred, err := MigrateFromPHC("alice", password, phcHash)
	require.NoError(t, err)
	assert.Equal(t, "alice", cred.Username)
	assert.Equal(t, salt, cred.Salt)
}
// TestRateLimiting verifies that a per-IP limiter admits its burst of
// three attempts and rejects the fourth immediate attempt.
func TestRateLimiting(t *testing.T) {
	manager := NewScramManager()
	limiter := manager.GetRateLimiter("192.168.1.1")

	for attempt := 1; attempt <= 3; attempt++ {
		assert.True(t, limiter.Allow(), "attempt %d should be within burst", attempt)
	}
	assert.False(t, limiter.Allow(), "burst exhausted; next attempt must be limited")
}
// TestNonceUniqueness generates many nonces and requires no collisions.
func TestNonceUniqueness(t *testing.T) {
	seen := make(map[string]bool, 1000)
	for i := 0; i < 1000; i++ {
		n := generateNonce()
		assert.False(t, seen[n], "Duplicate nonce generated")
		seen[n] = true
	}
}
// TestConstantTimeComparison verifies the anti-enumeration behavior of
// HandleClientFirst: an unknown user still receives a structurally
// valid challenge, with the failure reported only via the internal
// error return.
// NOTE(review): despite the name, this does not measure timing — the
// constant-time proof comparison is not directly exercised here.
func TestConstantTimeComparison(t *testing.T) {
	server := NewServer()
	salt := make([]byte, 16)
	rand.Read(salt)
	// Add real user
	cred, _ := DeriveCredential("realuser", "realpass", salt, 3, 64*1024, 4)
	server.AddCredential(cred)
	// Test timing independence for invalid user
	fakeClient := NewClient("fakeuser", "fakepass")
	clientFirst, _ := fakeClient.StartAuthentication()
	// Should still return ServerFirst (no user enumeration)
	serverFirst, err := server.HandleClientFirst(clientFirst)
	assert.NotNil(t, serverFirst)
	assert.Error(t, err) // Internal error, not exposed to client
}
// BenchmarkArgon2Derivation measures credential derivation at the
// default cost parameters (t=3, m=64*1024 KiB, p=4); the time is
// dominated by argon2.IDKey.
func BenchmarkArgon2Derivation(b *testing.B) {
	salt := make([]byte, 16)
	rand.Read(salt)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		DeriveCredential("user", "password", salt, 3, 64*1024, 4)
	}
}

// BenchmarkFullHandshake measures one complete four-message exchange
// per iteration, including the client-side Argon2 derivation and both
// HMAC verifications. Errors are intentionally ignored in the hot loop.
func BenchmarkFullHandshake(b *testing.B) {
	server := NewServer()
	salt := make([]byte, 16)
	rand.Read(salt)
	cred, _ := DeriveCredential("bench", "pass", salt, 3, 64*1024, 4)
	server.AddCredential(cred)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		client := NewClient("bench", "pass")
		clientFirst, _ := client.StartAuthentication()
		serverFirst, _ := server.HandleClientFirst(clientFirst)
		clientFinal, _ := client.ProcessServerFirst(serverFirst)
		serverFinal, _ := server.HandleClientFinal(clientFinal)
		client.VerifyServerFinal(serverFinal)
	}
}

View File

@ -0,0 +1,179 @@
// FILE: src/internal/scram/server.go
package scram
import (
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"fmt"
"sync"
"time"
)
// Server handles SCRAM authentication
type Server struct {
	credentials map[string]*Credential     // registered users, keyed by username
	handshakes  map[string]*HandshakeState // in-flight handshakes, keyed by full nonce
	mu          sync.RWMutex               // guards both maps

	// Default Argon2 params for new registrations
	DefaultTime    uint32
	DefaultMemory  uint32 // KiB
	DefaultThreads uint8
}

// HandshakeState tracks ongoing authentication
type HandshakeState struct {
	Username    string
	ClientNonce string
	ServerNonce string
	FullNonce   string // ClientNonce + ServerNonce; the handshake key
	AuthMessage string // NOTE(review): never written in the visible code — confirm it is still needed
	Credential  *Credential
	CreatedAt   time.Time // handshakes older than 60s are rejected and reaped
	ClientProof []byte    // NOTE(review): never written in the visible code — confirm it is still needed
}
// NewServer returns a SCRAM server with empty credential and handshake
// tables and the default Argon2id costs (t=3, m=64*1024 KiB, p=4).
func NewServer() *Server {
	s := &Server{
		credentials: map[string]*Credential{},
		handshakes:  map[string]*HandshakeState{},
	}
	s.DefaultTime = 3
	s.DefaultMemory = 64 * 1024
	s.DefaultThreads = 4
	return s
}
// AddCredential registers (or replaces) the credential for a user,
// keyed by Credential.Username.
func (s *Server) AddCredential(cred *Credential) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.credentials[cred.Username] = cred
}
// HandleClientFirst processes the initial auth request and returns the
// server challenge. For an unknown user it still returns a plausible
// challenge (random decoy salt, default costs) alongside an error, so
// callers can respond without leaking which usernames exist.
//
// BUG FIX: the decoy salt's rand.Read error was previously ignored; an
// all-zero salt would make fake challenges distinguishable. Also,
// structurally empty requests are now rejected up front.
func (s *Server) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
	if msg.Username == "" || msg.ClientNonce == "" {
		return nil, fmt.Errorf("missing required fields")
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Check if user exists
	cred, exists := s.credentials[msg.Username]
	if !exists {
		// Prevent user enumeration - still generate response
		salt := make([]byte, 16)
		if _, err := rand.Read(salt); err != nil {
			return nil, fmt.Errorf("salt generation failed: %w", err)
		}
		serverNonce := generateNonce()
		return &ServerFirst{
			FullNonce:    msg.ClientNonce + serverNonce,
			Salt:         base64.StdEncoding.EncodeToString(salt),
			ArgonTime:    s.DefaultTime,
			ArgonMemory:  s.DefaultMemory,
			ArgonThreads: s.DefaultThreads,
		}, fmt.Errorf("invalid credentials")
	}

	// Generate server nonce
	serverNonce := generateNonce()
	fullNonce := msg.ClientNonce + serverNonce

	// Store handshake state
	state := &HandshakeState{
		Username:    msg.Username,
		ClientNonce: msg.ClientNonce,
		ServerNonce: serverNonce,
		FullNonce:   fullNonce,
		Credential:  cred,
		CreatedAt:   time.Now(),
	}
	s.handshakes[fullNonce] = state

	// Cleanup old handshakes
	s.cleanupHandshakes()

	return &ServerFirst{
		FullNonce:    fullNonce,
		Salt:         base64.StdEncoding.EncodeToString(cred.Salt),
		ArgonTime:    cred.ArgonTime,
		ArgonMemory:  cred.ArgonMemory,
		ArgonThreads: cred.ArgonThreads,
	}, nil
}
// HandleClientFinal verifies the client proof and, on success, returns
// the server signature for mutual authentication plus a session ID.
// Each handshake is single-use: its state is deleted on any outcome.
//
// BUG FIX: a decoded proof of any length other than 32 bytes previously
// reached xorBytes, which panics on length mismatch — a trivial remote
// denial of service. Wrong-sized proofs are now rejected up front.
func (s *Server) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	state, exists := s.handshakes[msg.FullNonce]
	if !exists {
		return nil, fmt.Errorf("invalid nonce or expired handshake")
	}
	defer delete(s.handshakes, msg.FullNonce)

	// Check timeout
	if time.Since(state.CreatedAt) > 60*time.Second {
		return nil, fmt.Errorf("handshake timeout")
	}

	// Decode client proof
	clientProof, err := base64.StdEncoding.DecodeString(msg.ClientProof)
	if err != nil {
		return nil, fmt.Errorf("invalid proof encoding")
	}
	if len(clientProof) != sha256.Size {
		return nil, fmt.Errorf("invalid proof length")
	}

	// Rebuild the auth message exactly as the client did
	clientFirstBare := fmt.Sprintf("u=%s,n=%s", state.Username, state.ClientNonce)
	serverFirst := &ServerFirst{
		FullNonce:    state.FullNonce,
		Salt:         base64.StdEncoding.EncodeToString(state.Credential.Salt),
		ArgonTime:    state.Credential.ArgonTime,
		ArgonMemory:  state.Credential.ArgonMemory,
		ArgonThreads: state.Credential.ArgonThreads,
	}
	clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce)
	authMessage := clientFirstBare + "," + serverFirst.Marshal() + "," + clientFinalBare

	// Compute client signature
	clientSignature := computeHMAC(state.Credential.StoredKey, []byte(authMessage))
	// Recover ClientKey = proof XOR signature
	clientKey := xorBytes(clientProof, clientSignature)
	// Verify by recomputing StoredKey; comparison is constant-time
	computedStoredKey := sha256.Sum256(clientKey)
	if subtle.ConstantTimeCompare(computedStoredKey[:], state.Credential.StoredKey) != 1 {
		return nil, fmt.Errorf("authentication failed")
	}

	// Generate server signature for mutual auth
	serverSignature := computeHMAC(state.Credential.ServerKey, []byte(authMessage))
	return &ServerFinal{
		ServerSignature: base64.StdEncoding.EncodeToString(serverSignature),
		SessionID:       generateSessionID(),
	}, nil
}
// cleanupHandshakes drops in-flight handshakes older than 60 seconds.
// Callers must hold s.mu.
func (s *Server) cleanupHandshakes() {
	deadline := time.Now().Add(-60 * time.Second)
	for nonce, hs := range s.handshakes {
		if hs.CreatedAt.Before(deadline) {
			delete(s.handshakes, nonce)
		}
	}
}
// generateNonce returns 32 bytes of CSPRNG output, standard-base64
// encoded.
//
// BUG FIX: rand.Read's error was previously ignored; on CSPRNG failure
// the nonce would silently be all zeros. Such a failure is not
// recoverable, so fail hard.
func generateNonce() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("crypto/rand failure: %v", err))
	}
	return base64.StdEncoding.EncodeToString(b)
}

// generateSessionID returns 24 bytes of CSPRNG output, URL-safe base64
// encoded (safe for headers and query strings). Fails hard on CSPRNG
// failure, as above.
func generateSessionID() string {
	b := make([]byte, 24)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("crypto/rand failure: %v", err))
	}
	return base64.URLEncoding.EncodeToString(b)
}

View File

@ -281,10 +281,8 @@ func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter)
return sink.NewTCPClientSink(cfg.Options, s.logger, formatter) return sink.NewTCPClientSink(cfg.Options, s.logger, formatter)
case "file": case "file":
return sink.NewFileSink(cfg.Options, s.logger, formatter) return sink.NewFileSink(cfg.Options, s.logger, formatter)
case "stdout": case "console":
return sink.NewStdoutSink(cfg.Options, s.logger, formatter) return sink.NewConsoleSink(cfg.Options, s.logger, formatter)
case "stderr":
return sink.NewStderrSink(cfg.Options, s.logger, formatter)
default: default:
return nil, fmt.Errorf("unknown sink type: %s", cfg.Type) return nil, fmt.Errorf("unknown sink type: %s", cfg.Type)
} }

View File

@ -2,9 +2,9 @@
package sink package sink
import ( import (
"bytes"
"context" "context"
"io" "fmt"
"os"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
@ -16,20 +16,13 @@ import (
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
) )
// Holds common configuration for console sinks // ConsoleSink writes log entries to the console (stdout/stderr) using an dedicated logger instance
type ConsoleConfig struct { type ConsoleSink struct {
Target string // "stdout", "stderr", or "split"
BufferSize int64
}
// Writes log entries to stdout
type StdoutSink struct {
input chan core.LogEntry input chan core.LogEntry
config ConsoleConfig writer *log.Logger // Dedicated internal logger instance for console writing
output io.Writer
done chan struct{} done chan struct{}
startTime time.Time startTime time.Time
logger *log.Logger logger *log.Logger // Application logger for app logs
formatter format.Formatter formatter format.Formatter
// Statistics // Statistics
@ -37,29 +30,38 @@ type StdoutSink struct {
lastProcessed atomic.Value // time.Time lastProcessed atomic.Value // time.Time
} }
// Creates a new stdout sink // Creates a new console sink
func NewStdoutSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*StdoutSink, error) { func NewConsoleSink(options map[string]any, appLogger *log.Logger, formatter format.Formatter) (*ConsoleSink, error) {
config := ConsoleConfig{ target := "stdout"
Target: "stdout", if t, ok := options["target"].(string); ok {
BufferSize: 1000, target = t
} }
// Check for split mode configuration bufferSize := int64(1000)
if target, ok := options["target"].(string); ok { if buf, ok := options["buffer_size"].(int64); ok && buf > 0 {
config.Target = target bufferSize = buf
} }
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { // Dedicated logger instance as console writer
config.BufferSize = bufSize writer, err := log.NewBuilder().
EnableFile(false).
EnableConsole(true).
ConsoleTarget(target).
Format("raw"). // Passthrough pre-formatted messages
ShowTimestamp(false). // Disable writer's own timestamp
ShowLevel(false). // Disable writer's own level prefix
Build()
if err != nil {
return nil, fmt.Errorf("failed to create console writer: %w", err)
} }
s := &StdoutSink{ s := &ConsoleSink{
input: make(chan core.LogEntry, config.BufferSize), input: make(chan core.LogEntry, bufferSize),
config: config, writer: writer,
output: os.Stdout,
done: make(chan struct{}), done: make(chan struct{}),
startTime: time.Now(), startTime: time.Now(),
logger: logger, logger: appLogger,
formatter: formatter, formatter: formatter,
} }
s.lastProcessed.Store(time.Time{}) s.lastProcessed.Store(time.Time{})
@ -67,39 +69,52 @@ func NewStdoutSink(options map[string]any, logger *log.Logger, formatter format.
return s, nil return s, nil
} }
func (s *StdoutSink) Input() chan<- core.LogEntry { func (s *ConsoleSink) Input() chan<- core.LogEntry {
return s.input return s.input
} }
func (s *StdoutSink) Start(ctx context.Context) error { func (s *ConsoleSink) Start(ctx context.Context) error {
// Start the internal writer's processing goroutine.
if err := s.writer.Start(); err != nil {
return fmt.Errorf("failed to start console writer: %w", err)
}
go s.processLoop(ctx) go s.processLoop(ctx)
s.logger.Info("msg", "Stdout sink started", s.logger.Info("msg", "Console sink started",
"component", "stdout_sink", "component", "console_sink",
"target", s.config.Target) "target", s.writer.GetConfig().ConsoleTarget)
return nil return nil
} }
func (s *StdoutSink) Stop() { func (s *ConsoleSink) Stop() {
s.logger.Info("msg", "Stopping stdout sink") target := s.writer.GetConfig().ConsoleTarget
s.logger.Info("msg", "Stopping console sink", "target", target)
close(s.done) close(s.done)
s.logger.Info("msg", "Stdout sink stopped")
// Shutdown the internal writer with a timeout.
if err := s.writer.Shutdown(2 * time.Second); err != nil {
s.logger.Error("msg", "Error shutting down console writer",
"component", "console_sink",
"error", err)
}
s.logger.Info("msg", "Console sink stopped", "target", target)
} }
func (s *StdoutSink) GetStats() SinkStats { func (s *ConsoleSink) GetStats() SinkStats {
lastProc, _ := s.lastProcessed.Load().(time.Time) lastProc, _ := s.lastProcessed.Load().(time.Time)
return SinkStats{ return SinkStats{
Type: "stdout", Type: "console",
TotalProcessed: s.totalProcessed.Load(), TotalProcessed: s.totalProcessed.Load(),
StartTime: s.startTime, StartTime: s.startTime,
LastProcessed: lastProc, LastProcessed: lastProc,
Details: map[string]any{ Details: map[string]any{
"target": s.config.Target, "target": s.writer.GetConfig().ConsoleTarget,
}, },
} }
} }
func (s *StdoutSink) processLoop(ctx context.Context) { // processLoop reads entries, formats them, and passes them to the internal writer.
func (s *ConsoleSink) processLoop(ctx context.Context) {
for { for {
select { select {
case entry, ok := <-s.input: case entry, ok := <-s.input:
@ -110,24 +125,30 @@ func (s *StdoutSink) processLoop(ctx context.Context) {
s.totalProcessed.Add(1) s.totalProcessed.Add(1)
s.lastProcessed.Store(time.Now()) s.lastProcessed.Store(time.Now())
// Handle split mode - only process INFO/DEBUG for stdout // Format the entry using the pipeline's configured formatter.
if s.config.Target == "split" {
upperLevel := strings.ToUpper(entry.Level)
if upperLevel == "ERROR" || upperLevel == "WARN" || upperLevel == "WARNING" {
// Skip ERROR/WARN levels in stdout when in split mode
continue
}
}
// Format and write
formatted, err := s.formatter.Format(entry) formatted, err := s.formatter.Format(entry)
if err != nil { if err != nil {
s.logger.Error("msg", "Failed to format log entry for stdout", s.logger.Error("msg", "Failed to format log entry for console",
"component", "stdout_sink", "component", "console_sink",
"error", err) "error", err)
continue continue
} }
s.output.Write(formatted)
// Convert to string to prevent hex encoding of []byte by log package
// Strip new line, writer adds it
message := string(bytes.TrimSuffix(formatted, []byte{'\n'}))
switch strings.ToUpper(entry.Level) {
case "DEBUG":
s.writer.Debug(message)
case "INFO":
s.writer.Info(message)
case "WARN", "WARNING":
s.writer.Warn(message)
case "ERROR", "FATAL":
s.writer.Error(message)
default:
s.writer.Message(message)
}
case <-ctx.Done(): case <-ctx.Done():
return return
@ -137,125 +158,6 @@ func (s *StdoutSink) processLoop(ctx context.Context) {
} }
} }
// Writes log entries to stderr func (s *ConsoleSink) SetAuth(auth *config.AuthConfig) {
type StderrSink struct { // Authentication does not apply to the console sink.
input chan core.LogEntry
config ConsoleConfig
output io.Writer
done chan struct{}
startTime time.Time
logger *log.Logger
formatter format.Formatter
// Statistics
totalProcessed atomic.Uint64
lastProcessed atomic.Value // time.Time
}
// Creates a new stderr sink
func NewStderrSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*StderrSink, error) {
config := ConsoleConfig{
Target: "stderr",
BufferSize: 1000,
}
// Check for split mode configuration
if target, ok := options["target"].(string); ok {
config.Target = target
}
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
config.BufferSize = bufSize
}
s := &StderrSink{
input: make(chan core.LogEntry, config.BufferSize),
config: config,
output: os.Stderr,
done: make(chan struct{}),
startTime: time.Now(),
logger: logger,
formatter: formatter,
}
s.lastProcessed.Store(time.Time{})
return s, nil
}
func (s *StderrSink) Input() chan<- core.LogEntry {
return s.input
}
func (s *StderrSink) Start(ctx context.Context) error {
go s.processLoop(ctx)
s.logger.Info("msg", "Stderr sink started",
"component", "stderr_sink",
"target", s.config.Target)
return nil
}
func (s *StderrSink) Stop() {
s.logger.Info("msg", "Stopping stderr sink")
close(s.done)
s.logger.Info("msg", "Stderr sink stopped")
}
func (s *StderrSink) GetStats() SinkStats {
lastProc, _ := s.lastProcessed.Load().(time.Time)
return SinkStats{
Type: "stderr",
TotalProcessed: s.totalProcessed.Load(),
StartTime: s.startTime,
LastProcessed: lastProc,
Details: map[string]any{
"target": s.config.Target,
},
}
}
// processLoop drains the input channel until ctx is cancelled, the done
// channel closes, or the input channel itself is closed.
//
// When the target is "split", only ERROR/WARN/WARNING levels are written to
// stderr; other levels are skipped here (they are handled by the stdout
// side). Statistics are updated before the split filter, so skipped entries
// still count toward totalProcessed.
func (s *StderrSink) processLoop(ctx context.Context) {
	for {
		// Case order in a select is irrelevant: Go chooses uniformly among
		// the ready cases.
		select {
		case <-ctx.Done():
			return
		case <-s.done:
			return
		case entry, ok := <-s.input:
			if !ok {
				return
			}
			s.totalProcessed.Add(1)
			s.lastProcessed.Store(time.Now())

			if s.config.Target == "split" {
				switch strings.ToUpper(entry.Level) {
				case "ERROR", "WARN", "WARNING":
					// stderr owns these levels in split mode
				default:
					continue
				}
			}

			formatted, err := s.formatter.Format(entry)
			if err != nil {
				s.logger.Error("msg", "Failed to format log entry for stderr",
					"component", "stderr_sink",
					"error", err)
				continue
			}
			// Write errors are intentionally ignored: there is no useful
			// recovery path for a failed write to stderr.
			s.output.Write(formatted)
		}
	}
}
// SetAuth is a deliberate no-op: authentication has no meaning for a local
// stdout sink. The method exists only to satisfy the sink interface that
// network sinks implement.
func (s *StdoutSink) SetAuth(auth *config.AuthConfig) {
	// Authentication does not apply to stdout sink
}
func (s *StderrSink) SetAuth(auth *config.AuthConfig) {
// Authentication does not apply to stderr sink
} }

View File

@ -2,6 +2,7 @@
package sink package sink
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"logwisp/src/internal/config" "logwisp/src/internal/config"
@ -32,12 +33,14 @@ type FileSink struct {
func NewFileSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*FileSink, error) { func NewFileSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*FileSink, error) {
directory, ok := options["directory"].(string) directory, ok := options["directory"].(string)
if !ok || directory == "" { if !ok || directory == "" {
return nil, fmt.Errorf("file sink requires 'directory' option") directory = "./"
logger.Warn("No directory or invalid directory provided, current directory will be used")
} }
name, ok := options["name"].(string) name, ok := options["name"].(string)
if !ok || name == "" { if !ok || name == "" {
return nil, fmt.Errorf("file sink requires 'name' option") name = "logwisp.output"
logger.Warn(fmt.Sprintf("No filename provided, %s will be used", name))
} }
// Create configuration for the internal log writer // Create configuration for the internal log writer
@ -77,7 +80,7 @@ func NewFileSink(options map[string]any, logger *log.Logger, formatter format.Fo
} }
// Buffer size for input channel // Buffer size for input channel
// TODO: Make this configurable // TODO: Centralized constant file in core package
bufferSize := int64(1000) bufferSize := int64(1000)
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
bufferSize = bufSize bufferSize = bufSize
@ -152,11 +155,9 @@ func (fs *FileSink) processLoop(ctx context.Context) {
continue continue
} }
// Write formatted bytes (strip newline as writer adds it) // Convert to string to prevent hex encoding of []byte by log package
message := string(formatted) // Strip new line, writer adds it
if len(message) > 0 && message[len(message)-1] == '\n' { message := string(bytes.TrimSuffix(formatted, []byte{'\n'}))
message = message[:len(message)-1]
}
fs.writer.Message(message) fs.writer.Message(message)
case <-ctx.Done(): case <-ctx.Done():

View File

@ -471,6 +471,21 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
} }
} }
// Enforce TLS for authentication
if h.authenticator != nil && h.authConfig.Type != "none" {
isTLS := ctx.IsTLS() || h.tlsManager != nil
if !isTLS {
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{
"error": "TLS required for authentication",
"hint": "Use HTTPS for authenticated connections",
})
return
}
}
path := string(ctx.Path()) path := string(ctx.Path())
// Status endpoint doesn't require auth // Status endpoint doesn't require auth
@ -811,7 +826,7 @@ func (h *HTTPSink) SetAuth(authCfg *config.AuthConfig) {
} }
h.authConfig = authCfg h.authConfig = authCfg
authenticator, err := auth.New(authCfg, h.logger) authenticator, err := auth.NewAuthenticator(authCfg, h.logger)
if err != nil { if err != nil {
h.logger.Error("msg", "Failed to initialize authenticator for HTTP sink", h.logger.Error("msg", "Failed to initialize authenticator for HTTP sink",
"component", "http_sink", "component", "http_sink",

View File

@ -52,27 +52,29 @@ type HTTPClientSink struct {
// TODO: missing toml tags // TODO: missing toml tags
type HTTPClientConfig struct { type HTTPClientConfig struct {
// Config // Config
URL string URL string `toml:"url"`
BufferSize int64 BufferSize int64 `toml:"buffer_size"`
BatchSize int64 BatchSize int64 `toml:"batch_size"`
BatchDelay time.Duration BatchDelay time.Duration `toml:"batch_delay_ms"`
Timeout time.Duration Timeout time.Duration `toml:"timeout_seconds"`
Headers map[string]string Headers map[string]string `toml:"headers"`
// Retry configuration // Retry configuration
MaxRetries int64 MaxRetries int64 `toml:"max_retries"`
RetryDelay time.Duration RetryDelay time.Duration `toml:"retry_delay"`
RetryBackoff float64 // Multiplier for exponential backoff RetryBackoff float64 `toml:"retry_backoff"` // Multiplier for exponential backoff
// Security // Security
Username string AuthType string `toml:"auth_type"` // "none", "basic", "bearer", "mtls"
Password string Username string `toml:"username"` // For basic auth
Password string `toml:"password"` // For basic auth
BearerToken string `toml:"bearer_token"` // For bearer auth
// TLS configuration // TLS configuration
InsecureSkipVerify bool InsecureSkipVerify bool `toml:"insecure_skip_verify"`
CAFile string CAFile string `toml:"ca_file"`
CertFile string CertFile string `toml:"cert_file"`
KeyFile string KeyFile string `toml:"key_file"`
} }
// Creates a new HTTP client sink // Creates a new HTTP client sink
@ -129,12 +131,64 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
if insecure, ok := options["insecure_skip_verify"].(bool); ok { if insecure, ok := options["insecure_skip_verify"].(bool); ok {
cfg.InsecureSkipVerify = insecure cfg.InsecureSkipVerify = insecure
} }
if authType, ok := options["auth_type"].(string); ok {
switch authType {
case "none", "basic", "bearer", "mtls":
cfg.AuthType = authType
default:
return nil, fmt.Errorf("http_client sink: invalid auth_type '%s'", authType)
}
} else {
cfg.AuthType = "none"
}
if username, ok := options["username"].(string); ok { if username, ok := options["username"].(string); ok {
cfg.Username = username cfg.Username = username
} }
if password, ok := options["password"].(string); ok { if password, ok := options["password"].(string); ok {
cfg.Password = password // TODO: change to Argon2 hashed password cfg.Password = password // TODO: change to Argon2 hashed password
} }
if token, ok := options["bearer_token"].(string); ok {
cfg.BearerToken = token
}
// Validate auth configuration and TLS enforcement
isHTTPS := strings.HasPrefix(cfg.URL, "https://")
switch cfg.AuthType {
case "basic":
if cfg.Username == "" || cfg.Password == "" {
return nil, fmt.Errorf("http_client sink: username and password required for basic auth")
}
if !isHTTPS {
return nil, fmt.Errorf("http_client sink: basic auth requires HTTPS (security: credentials would be sent in plaintext)")
}
case "bearer":
if cfg.BearerToken == "" {
return nil, fmt.Errorf("http_client sink: bearer_token required for bearer auth")
}
if !isHTTPS {
return nil, fmt.Errorf("http_client sink: bearer auth requires HTTPS (security: token would be sent in plaintext)")
}
case "mtls":
if !isHTTPS {
return nil, fmt.Errorf("http_client sink: mTLS requires HTTPS")
}
if cfg.CertFile == "" || cfg.KeyFile == "" {
return nil, fmt.Errorf("http_client sink: cert_file and key_file required for mTLS")
}
case "none":
// Clear any credentials if auth is "none"
if cfg.Username != "" || cfg.Password != "" || cfg.BearerToken != "" {
logger.Warn("msg", "Credentials provided but auth_type is 'none', ignoring",
"component", "http_client_sink")
cfg.Username = ""
cfg.Password = ""
cfg.BearerToken = ""
}
}
// Extract headers // Extract headers
if headers, ok := options["headers"].(map[string]any); ok { if headers, ok := options["headers"].(map[string]any); ok {
@ -416,6 +470,7 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
var lastErr error var lastErr error
retryDelay := h.config.RetryDelay retryDelay := h.config.RetryDelay
// TODO: verify retry loop placement is correct or should it be after acquiring resources (req :=....)
for attempt := int64(0); attempt <= h.config.MaxRetries; attempt++ { for attempt := int64(0); attempt <= h.config.MaxRetries; attempt++ {
if attempt > 0 { if attempt > 0 {
// Wait before retry // Wait before retry
@ -444,11 +499,22 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
req.Header.Set("User-Agent", fmt.Sprintf("LogWisp/%s", version.Short())) req.Header.Set("User-Agent", fmt.Sprintf("LogWisp/%s", version.Short()))
// Add Basic Auth header if credentials configured // Add authentication based on auth type
if h.config.Username != "" && h.config.Password != "" { switch h.config.AuthType {
case "basic":
creds := h.config.Username + ":" + h.config.Password creds := h.config.Username + ":" + h.config.Password
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds)) encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
req.Header.Set("Authorization", "Basic "+encodedCreds) req.Header.Set("Authorization", "Basic "+encodedCreds)
case "bearer":
req.Header.Set("Authorization", "Bearer "+h.config.BearerToken)
case "mtls":
// mTLS auth is handled at TLS layer via client certificates
// No Authorization header needed
case "none":
// No authentication
} }
// Set headers // Set headers

View File

@ -602,7 +602,7 @@ func (t *TCPSink) SetAuth(authCfg *config.AuthConfig) {
return return
} }
authenticator, err := auth.New(authCfg, t.logger) authenticator, err := auth.NewAuthenticator(authCfg, t.logger)
if err != nil { if err != nil {
t.logger.Error("msg", "Failed to initialize authenticator for TCP sink", t.logger.Error("msg", "Failed to initialize authenticator for TCP sink",
"component", "tcp_sink", "component", "tcp_sink",

View File

@ -4,7 +4,7 @@ package sink
import ( import (
"bufio" "bufio"
"context" "context"
"encoding/base64" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -13,10 +13,10 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"logwisp/src/internal/auth"
"logwisp/src/internal/config" "logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"logwisp/src/internal/format" "logwisp/src/internal/format"
"logwisp/src/internal/scram"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
) )
@ -32,7 +32,6 @@ type TCPClientSink struct {
startTime time.Time startTime time.Time
logger *log.Logger logger *log.Logger
formatter format.Formatter formatter format.Formatter
authenticator *auth.Authenticator
// Reconnection state // Reconnection state
reconnecting atomic.Bool reconnecting atomic.Bool
@ -49,24 +48,22 @@ type TCPClientSink struct {
// Holds TCP client sink configuration // Holds TCP client sink configuration
type TCPClientConfig struct { type TCPClientConfig struct {
Address string Address string `toml:"address"`
BufferSize int64 BufferSize int64 `toml:"buffer_size"`
DialTimeout time.Duration DialTimeout time.Duration `toml:"dial_timeout_seconds"`
WriteTimeout time.Duration WriteTimeout time.Duration `toml:"write_timeout_seconds"`
ReadTimeout time.Duration ReadTimeout time.Duration `toml:"read_timeout_seconds"`
KeepAlive time.Duration KeepAlive time.Duration `toml:"keep_alive_seconds"`
// Security // Security
Username string AuthType string `toml:"auth_type"`
Password string Username string `toml:"username"`
Password string `toml:"password"`
// Reconnection settings // Reconnection settings
ReconnectDelay time.Duration ReconnectDelay time.Duration `toml:"reconnect_delay_ms"`
MaxReconnectDelay time.Duration MaxReconnectDelay time.Duration `toml:"max_reconnect_delay_seconds"`
ReconnectBackoff float64 ReconnectBackoff float64 `toml:"reconnect_backoff"`
// TLS config
TLS *config.TLSConfig
} }
// Creates a new TCP client sink // Creates a new TCP client sink
@ -120,11 +117,25 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form
if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 { if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 {
cfg.ReconnectBackoff = backoff cfg.ReconnectBackoff = backoff
} }
if username, ok := options["username"].(string); ok { if authType, ok := options["auth_type"].(string); ok {
switch authType {
case "none":
cfg.AuthType = authType
case "scram":
cfg.AuthType = authType
if username, ok := options["username"].(string); ok && username != "" {
cfg.Username = username cfg.Username = username
} else {
return nil, fmt.Errorf("invalid scram username")
} }
if password, ok := options["password"].(string); ok { if password, ok := options["password"].(string); ok && password != "" {
cfg.Password = password cfg.Password = password
} else {
return nil, fmt.Errorf("invalid scram password")
}
default:
return nil, fmt.Errorf("tcp_client sink: invalid auth_type '%s' (must be 'none' or 'scram')", authType)
}
} }
t := &TCPClientSink{ t := &TCPClientSink{
@ -304,49 +315,115 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive) tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
} }
// Handle authentication if credentials configured // SCRAM authentication if credentials configured
if t.config.Username != "" && t.config.Password != "" { if t.config.AuthType == "scram" {
// Read auth challenge if err := t.performSCRAMAuth(conn); err != nil {
reader := bufio.NewReader(conn)
challenge, err := reader.ReadString('\n')
if err != nil {
conn.Close() conn.Close()
return nil, fmt.Errorf("failed to read auth challenge: %w", err) return nil, fmt.Errorf("SCRAM authentication failed: %w", err)
} }
t.logger.Debug("msg", "SCRAM authentication completed",
if strings.TrimSpace(challenge) == "AUTH_REQUIRED" {
// Send credentials
creds := t.config.Username + ":" + t.config.Password
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
authCmd := fmt.Sprintf("AUTH basic %s\n", encodedCreds)
if _, err := conn.Write([]byte(authCmd)); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to send auth: %w", err)
}
// Read response
response, err := reader.ReadString('\n')
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to read auth response: %w", err)
}
if strings.TrimSpace(response) != "AUTH_OK" {
conn.Close()
return nil, fmt.Errorf("authentication failed: %s", response)
}
t.logger.Debug("msg", "TCP authentication successful",
"component", "tcp_client_sink", "component", "tcp_client_sink",
"address", t.config.Address, "address", t.config.Address)
"username", t.config.Username)
}
} }
return conn, nil return conn, nil
} }
func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
reader := bufio.NewReader(conn)
// Create SCRAM client
scramClient := scram.NewClient(t.config.Username, t.config.Password)
// Step 1: Send ClientFirst
clientFirst, err := scramClient.StartAuthentication()
if err != nil {
return fmt.Errorf("failed to start SCRAM: %w", err)
}
clientFirstJSON, _ := json.Marshal(clientFirst)
msg := fmt.Sprintf("SCRAM-FIRST %s\n", clientFirstJSON)
if _, err := conn.Write([]byte(msg)); err != nil {
return fmt.Errorf("failed to send SCRAM-FIRST: %w", err)
}
// Step 2: Receive ServerFirst challenge
response, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read SCRAM challenge: %w", err)
}
parts := strings.Fields(strings.TrimSpace(response))
if len(parts) != 2 || parts[0] != "SCRAM-CHALLENGE" {
return fmt.Errorf("unexpected server response: %s", response)
}
var serverFirst scram.ServerFirst
if err := json.Unmarshal([]byte(parts[1]), &serverFirst); err != nil {
return fmt.Errorf("failed to parse server challenge: %w", err)
}
// Step 3: Process challenge and send proof
clientFinal, err := scramClient.ProcessServerFirst(&serverFirst)
if err != nil {
return fmt.Errorf("failed to process challenge: %w", err)
}
clientFinalJSON, _ := json.Marshal(clientFinal)
msg = fmt.Sprintf("SCRAM-PROOF %s\n", clientFinalJSON)
if _, err := conn.Write([]byte(msg)); err != nil {
return fmt.Errorf("failed to send SCRAM-PROOF: %w", err)
}
// Step 4: Receive ServerFinal
response, err = reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read SCRAM result: %w", err)
}
parts = strings.Fields(strings.TrimSpace(response))
if len(parts) < 1 {
return fmt.Errorf("empty server response")
}
switch parts[0] {
case "SCRAM-OK":
if len(parts) != 2 {
return fmt.Errorf("invalid SCRAM-OK response")
}
var serverFinal scram.ServerFinal
if err := json.Unmarshal([]byte(parts[1]), &serverFinal); err != nil {
return fmt.Errorf("failed to parse server signature: %w", err)
}
// Verify server signature
if err := scramClient.VerifyServerFinal(&serverFinal); err != nil {
return fmt.Errorf("server signature verification failed: %w", err)
}
t.logger.Info("msg", "SCRAM authentication successful",
"component", "tcp_client_sink",
"address", t.config.Address,
"username", t.config.Username,
"session_id", serverFinal.SessionID)
return nil
case "SCRAM-FAIL":
reason := "unknown"
if len(parts) > 1 {
reason = strings.Join(parts[1:], " ")
}
return fmt.Errorf("authentication failed: %s", reason)
default:
return fmt.Errorf("unexpected response: %s", response)
}
}
func (t *TCPClientSink) monitorConnection(conn net.Conn) { func (t *TCPClientSink) monitorConnection(conn net.Conn) {
// Simple connection monitoring by periodic zero-byte reads // Simple connection monitoring by periodic zero-byte reads
ticker := time.NewTicker(5 * time.Second) ticker := time.NewTicker(5 * time.Second)

View File

@ -117,12 +117,6 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok { if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
cfg.MaxConnectionsPerIP = maxPerIP cfg.MaxConnectionsPerIP = maxPerIP
} }
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
cfg.MaxConnectionsPerUser = maxPerUser
}
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
cfg.MaxConnectionsPerToken = maxPerToken
}
if maxTotal, ok := nl["max_connections_total"].(int64); ok { if maxTotal, ok := nl["max_connections_total"].(int64); ok {
cfg.MaxConnectionsTotal = maxTotal cfg.MaxConnectionsTotal = maxTotal
} }
@ -313,6 +307,27 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
} }
} }
// 2.5. Check TLS requirement for auth (early reject)
if h.authenticator != nil && h.authConfig.Type != "none" {
// Check if connection is TLS
isTLS := ctx.IsTLS() || h.tlsManager != nil
if !isTLS {
h.logger.Error("msg", "Authentication configured but connection is not TLS",
"component", "http_source",
"remote_addr", remoteAddr,
"auth_type", h.authConfig.Type)
ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{
"error": "TLS required for authentication",
"hint": "Use HTTPS to submit authenticated requests",
})
return
}
}
// 3. Path check (only process ingest path) // 3. Path check (only process ingest path)
path := string(ctx.Path()) path := string(ctx.Path())
if path != h.path { if path != h.path {
@ -543,7 +558,7 @@ func (h *HTTPSource) SetAuth(authCfg *config.AuthConfig) {
} }
h.authConfig = authCfg h.authConfig = authCfg
authenticator, err := auth.New(authCfg, h.logger) authenticator, err := auth.NewAuthenticator(authCfg, h.logger)
if err != nil { if err != nil {
h.logger.Error("msg", "Failed to initialize authenticator for HTTP source", h.logger.Error("msg", "Failed to initialize authenticator for HTTP source",
"component", "http_source", "component", "http_source",

View File

@ -4,6 +4,7 @@ package source
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
@ -16,6 +17,7 @@ import (
"logwisp/src/internal/config" "logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"logwisp/src/internal/limit" "logwisp/src/internal/limit"
"logwisp/src/internal/scram"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
"github.com/lixenwraith/log/compat" "github.com/lixenwraith/log/compat"
@ -41,7 +43,7 @@ type TCPSource struct {
wg sync.WaitGroup wg sync.WaitGroup
netLimiter *limit.NetLimiter netLimiter *limit.NetLimiter
logger *log.Logger logger *log.Logger
authenticator *auth.Authenticator scramManager *scram.ScramManager
// Statistics // Statistics
totalEntries atomic.Uint64 totalEntries atomic.Uint64
@ -255,12 +257,13 @@ func (t *TCPSource) publish(entry core.LogEntry) bool {
// Represents a connected TCP client // Represents a connected TCP client
type tcpClient struct { type tcpClient struct {
conn gnet.Conn conn gnet.Conn
buffer bytes.Buffer buffer *bytes.Buffer
authenticated bool authenticated bool
authTimeout time.Time authTimeout time.Time
session *auth.Session session *auth.Session
maxBufferSeen int maxBufferSeen int
cumulativeEncrypted int64 cumulativeEncrypted int64
scramState *scram.HandshakeState
} }
// Handles gnet events // Handles gnet events
@ -314,11 +317,9 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
// Create client state // Create client state
client := &tcpClient{ client := &tcpClient{
conn: c, conn: c,
authenticated: s.source.authenticator == nil, buffer: bytes.NewBuffer(nil),
} authTimeout: time.Now().Add(30 * time.Second),
authenticated: s.source.scramManager == nil,
if s.source.authenticator != nil {
client.authTimeout = time.Now().Add(30 * time.Second)
} }
s.mu.Lock() s.mu.Lock()
@ -330,12 +331,7 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
"component", "tcp_source", "component", "tcp_source",
"remote_addr", remoteAddr, "remote_addr", remoteAddr,
"active_connections", newCount, "active_connections", newCount,
"requires_auth", s.source.authenticator != nil) "requires_auth", s.source.scramManager != nil)
// Send auth challenge if required
if s.source.authenticator != nil {
return []byte("AUTH_REQUIRED\n"), gnet.None
}
return nil, gnet.None return nil, gnet.None
} }
@ -380,52 +376,107 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
return gnet.Close return gnet.Close
} }
// Authentication phase // SCRAM Authentication phase
if !client.authenticated { if !client.authenticated && s.source.scramManager != nil {
if time.Now().After(client.authTimeout) { // Check auth timeout
if !client.authTimeout.IsZero() && time.Now().After(client.authTimeout) {
s.source.logger.Warn("msg", "Authentication timeout", s.source.logger.Warn("msg", "Authentication timeout",
"component", "tcp_source", "component", "tcp_source",
"remote_addr", c.RemoteAddr().String()) "remote_addr", c.RemoteAddr().String())
return gnet.Close return gnet.Close
} }
if len(data) == 0 {
return gnet.None
}
client.buffer.Write(data) client.buffer.Write(data)
// Look for auth line // Look for complete line
if idx := bytes.IndexByte(client.buffer.Bytes(), '\n'); idx >= 0 { for {
idx := bytes.IndexByte(client.buffer.Bytes(), '\n')
if idx < 0 {
break
}
line := client.buffer.Bytes()[:idx] line := client.buffer.Bytes()[:idx]
client.buffer.Next(idx + 1) client.buffer.Next(idx + 1)
parts := strings.SplitN(string(line), " ", 3) // Parse SCRAM messages
if len(parts) != 3 || parts[0] != "AUTH" { parts := strings.Fields(string(line))
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil) if len(parts) < 2 {
c.AsyncWrite([]byte("SCRAM-FAIL Invalid message format\n"), nil)
return gnet.Close return gnet.Close
} }
session, err := s.source.authenticator.AuthenticateTCP(parts[1], parts[2], c.RemoteAddr().String()) switch parts[0] {
case "SCRAM-FIRST":
// Parse ClientFirst JSON
var clientFirst scram.ClientFirst
if err := json.Unmarshal([]byte(parts[1]), &clientFirst); err != nil {
c.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
return gnet.Close
}
// Process with SCRAM server
serverFirst, err := s.source.scramManager.HandleClientFirst(&clientFirst)
if err != nil { if err != nil {
s.source.authFailures.Add(1) // Still send challenge to prevent user enumeration
s.source.logger.Warn("msg", "Authentication failed", response, _ := json.Marshal(serverFirst)
c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
return gnet.Close
}
// Send ServerFirst challenge
response, _ := json.Marshal(serverFirst)
c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
case "SCRAM-PROOF":
// Parse ClientFinal JSON
var clientFinal scram.ClientFinal
if err := json.Unmarshal([]byte(parts[1]), &clientFinal); err != nil {
c.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
return gnet.Close
}
// Verify proof
serverFinal, err := s.source.scramManager.HandleClientFinal(&clientFinal)
if err != nil {
s.source.logger.Warn("msg", "SCRAM authentication failed",
"component", "tcp_source", "component", "tcp_source",
"remote_addr", c.RemoteAddr().String(), "remote_addr", c.RemoteAddr().String(),
"error", err) "error", err)
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil) c.AsyncWrite([]byte("SCRAM-FAIL Authentication failed\n"), nil)
return gnet.Close return gnet.Close
} }
s.source.authSuccesses.Add(1) // Authentication successful
s.mu.Lock() s.mu.Lock()
client.authenticated = true client.authenticated = true
client.session = session client.session = &auth.Session{
ID: serverFinal.SessionID,
Method: "scram-sha-256",
RemoteAddr: c.RemoteAddr().String(),
CreatedAt: time.Now(),
}
s.mu.Unlock() s.mu.Unlock()
s.source.logger.Info("msg", "TCP client authenticated", // Send ServerFinal with signature
response, _ := json.Marshal(serverFinal)
c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-OK %s\n", response)), nil)
s.source.logger.Info("msg", "Client authenticated via SCRAM",
"component", "tcp_source", "component", "tcp_source",
"remote_addr", c.RemoteAddr().String(), "remote_addr", c.RemoteAddr().String(),
"username", session.Username) "session_id", serverFinal.SessionID)
c.AsyncWrite([]byte("AUTH_OK\n"), nil) // Clear auth buffer
client.buffer.Reset() client.buffer.Reset()
default:
c.AsyncWrite([]byte("SCRAM-FAIL Unknown command\n"), nil)
return gnet.Close
}
} }
return gnet.None return gnet.None
} }
@ -522,22 +573,46 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
return gnet.None return gnet.None
} }
func (t *TCPSource) InitSCRAMManager(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type != "scram" || authCfg.ScramAuth == nil {
return
}
t.scramManager = scram.NewScramManager()
// Load users from SCRAM config
for _, user := range authCfg.ScramAuth.Users {
storedKey, _ := base64.StdEncoding.DecodeString(user.StoredKey)
serverKey, _ := base64.StdEncoding.DecodeString(user.ServerKey)
salt, _ := base64.StdEncoding.DecodeString(user.Salt)
cred := &scram.Credential{
Username: user.Username,
StoredKey: storedKey,
ServerKey: serverKey,
Salt: salt,
ArgonTime: user.ArgonTime,
ArgonMemory: user.ArgonMemory,
ArgonThreads: user.ArgonThreads,
}
t.scramManager.AddCredential(cred)
}
t.logger.Info("msg", "SCRAM authentication configured",
"component", "tcp_source",
"users", len(authCfg.ScramAuth.Users))
}
// Configure TCP source auth // Configure TCP source auth
func (t *TCPSource) SetAuth(authCfg *config.AuthConfig) { func (t *TCPSource) SetAuth(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" { if authCfg == nil || authCfg.Type == "none" {
return return
} }
authenticator, err := auth.New(authCfg, t.logger) // Initialize SCRAM manager
if err != nil { if authCfg.Type == "scram" {
t.logger.Error("msg", "Failed to initialize authenticator for TCP source", t.InitSCRAMManager(authCfg)
"component", "tcp_source", t.logger.Info("msg", "SCRAM authentication configured for TCP source",
"error", err) "component", "tcp_source")
return
} }
t.authenticator = authenticator
t.logger.Info("msg", "Authentication configured for TCP source",
"component", "tcp_source",
"auth_type", authCfg.Type)
} }

View File

@ -9,6 +9,7 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io"
"math/big" "math/big"
"net" "net"
"os" "os"
@ -16,14 +17,21 @@ import (
"time" "time"
) )
type CertGeneratorCommand struct{} type CertGeneratorCommand struct {
output io.Writer
func NewCertGeneratorCommand() *CertGeneratorCommand { errOut io.Writer
return &CertGeneratorCommand{}
} }
func (c *CertGeneratorCommand) Execute(args []string) error { func NewCertGeneratorCommand() *CertGeneratorCommand {
return &CertGeneratorCommand{
output: os.Stdout,
errOut: os.Stderr,
}
}
func (cg *CertGeneratorCommand) Execute(args []string) error {
cmd := flag.NewFlagSet("tls", flag.ContinueOnError) cmd := flag.NewFlagSet("tls", flag.ContinueOnError)
cmd.SetOutput(cg.errOut)
// Subcommands // Subcommands
var ( var (
@ -50,20 +58,21 @@ func (c *CertGeneratorCommand) Execute(args []string) error {
) )
cmd.Usage = func() { cmd.Usage = func() {
fmt.Fprintln(os.Stderr, "Generate TLS certificates for LogWisp") fmt.Fprintln(cg.errOut, "Generate TLS certificates for LogWisp")
fmt.Fprintln(os.Stderr, "\nUsage: logwisp tls [options]") fmt.Fprintln(cg.errOut, "\nUsage: logwisp tls [options]")
fmt.Fprintln(os.Stderr, "\nExamples:") fmt.Fprintln(cg.errOut, "\nExamples:")
fmt.Fprintln(os.Stderr, " # Generate self-signed certificate") fmt.Fprintln(cg.errOut, " # Generate self-signed certificate")
fmt.Fprintln(os.Stderr, " logwisp tls --self-signed --cn localhost --hosts localhost,127.0.0.1") fmt.Fprintln(cg.errOut, " logwisp tls --self-signed --cn localhost --hosts localhost,127.0.0.1")
fmt.Fprintln(os.Stderr, " ") fmt.Fprintln(cg.errOut, " ")
fmt.Fprintln(os.Stderr, " # Generate CA certificate") fmt.Fprintln(cg.errOut, " # Generate CA certificate")
fmt.Fprintln(os.Stderr, " logwisp tls --ca --cn \"LogWisp CA\" --cert-out ca.crt --key-out ca.key") fmt.Fprintln(cg.errOut, " logwisp tls --ca --cn \"LogWisp CA\" --cert-out ca.crt --key-out ca.key")
fmt.Fprintln(os.Stderr, " ") fmt.Fprintln(cg.errOut, " ")
fmt.Fprintln(os.Stderr, " # Generate server certificate signed by CA") fmt.Fprintln(cg.errOut, " # Generate server certificate signed by CA")
fmt.Fprintln(os.Stderr, " logwisp tls --server --cn server.example.com --hosts server.example.com \\") fmt.Fprintln(cg.errOut, " logwisp tls --server --cn server.example.com --hosts server.example.com \\")
fmt.Fprintln(os.Stderr, " --ca-cert ca.crt --ca-key ca.key") fmt.Fprintln(cg.errOut, " --ca-cert ca.crt --ca-key ca.key")
fmt.Fprintln(os.Stderr, "\nOptions:") fmt.Fprintln(cg.errOut, "\nOptions:")
cmd.PrintDefaults() cmd.PrintDefaults()
fmt.Fprintln(cg.errOut)
} }
if err := cmd.Parse(args); err != nil { if err := cmd.Parse(args); err != nil {
@ -79,13 +88,13 @@ func (c *CertGeneratorCommand) Execute(args []string) error {
// Route to appropriate generator // Route to appropriate generator
switch { switch {
case *genCA: case *genCA:
return c.generateCA(*commonName, *org, *country, *validDays, *keySize, *certOut, *keyOut) return cg.generateCA(*commonName, *org, *country, *validDays, *keySize, *certOut, *keyOut)
case *selfSign: case *selfSign:
return c.generateSelfSigned(*commonName, *org, *country, *hosts, *validDays, *keySize, *certOut, *keyOut) return cg.generateSelfSigned(*commonName, *org, *country, *hosts, *validDays, *keySize, *certOut, *keyOut)
case *genServer: case *genServer:
return c.generateServerCert(*commonName, *org, *country, *hosts, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut) return cg.generateServerCert(*commonName, *org, *country, *hosts, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut)
case *genClient: case *genClient:
return c.generateClientCert(*commonName, *org, *country, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut) return cg.generateClientCert(*commonName, *org, *country, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut)
default: default:
cmd.Usage() cmd.Usage()
return fmt.Errorf("specify certificate type: --ca, --self-signed, --server, or --client") return fmt.Errorf("specify certificate type: --ca, --self-signed, --server, or --client")
@ -93,7 +102,7 @@ func (c *CertGeneratorCommand) Execute(args []string) error {
} }
// Create and manage private CA // Create and manage private CA
func (c *CertGeneratorCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error { func (cg *CertGeneratorCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error {
// Generate RSA key // Generate RSA key
priv, err := rsa.GenerateKey(rand.Reader, bits) priv, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil { if err != nil {
@ -169,7 +178,7 @@ func parseHosts(hostList string) ([]string, []net.IP) {
} }
// Generate self-signed certificate // Generate self-signed certificate
func (c *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error { func (cg *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error {
// 1. Generate an RSA private key with the specified bit size // 1. Generate an RSA private key with the specified bit size
priv, err := rsa.GenerateKey(rand.Reader, bits) priv, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil { if err != nil {
@ -236,7 +245,7 @@ func (c *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string
} }
// Generate server cert with CA // Generate server cert with CA
func (c *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { func (cg *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error {
caCert, caKey, err := loadCA(caFile, caKeyFile) caCert, caKey, err := loadCA(caFile, caKeyFile)
if err != nil { if err != nil {
return err return err
@ -299,7 +308,7 @@ func (c *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFil
} }
// Generate client cert with CA // Generate client cert with CA
func (c *CertGeneratorCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { func (cg *CertGeneratorCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error {
caCert, caKey, err := loadCA(caFile, caKeyFile) caCert, caKey, err := loadCA(caFile, caKeyFile)
if err != nil { if err != nil {
return err return err
@ -356,66 +365,105 @@ func (c *CertGeneratorCommand) generateClientCert(cn, org, country, caFile, caKe
} }
// Load cert with CA // Load cert with CA
func loadCA(caFile, caKeyFile string) (*x509.Certificate, *rsa.PrivateKey, error) { func loadCA(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) {
if caFile == "" || caKeyFile == "" { // Load CA certificate
return nil, nil, fmt.Errorf("--ca-cert and --ca-key are required for signing") certPEM, err := os.ReadFile(certFile)
}
caCertPEM, err := os.ReadFile(caFile)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to read CA certificate: %w", err) return nil, nil, fmt.Errorf("failed to read CA certificate: %w", err)
} }
caCertBlock, _ := pem.Decode(caCertPEM)
if caCertBlock == nil { certBlock, _ := pem.Decode(certPEM)
return nil, nil, fmt.Errorf("failed to decode CA certificate PEM") if certBlock == nil || certBlock.Type != "CERTIFICATE" {
return nil, nil, fmt.Errorf("invalid CA certificate format")
} }
caCert, err := x509.ParseCertificate(caCertBlock.Bytes)
caCert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to parse CA certificate: %w", err) return nil, nil, fmt.Errorf("failed to parse CA certificate: %w", err)
} }
if !caCert.IsCA { // Load CA private key
return nil, nil, fmt.Errorf("provided certificate is not a valid CA") keyPEM, err := os.ReadFile(keyFile)
}
caKeyPEM, err := os.ReadFile(caKeyFile)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to read CA key: %w", err) return nil, nil, fmt.Errorf("failed to read CA key: %w", err)
} }
caKeyBlock, _ := pem.Decode(caKeyPEM)
if caKeyBlock == nil { keyBlock, _ := pem.Decode(keyPEM)
return nil, nil, fmt.Errorf("failed to decode CA key PEM") if keyBlock == nil {
return nil, nil, fmt.Errorf("invalid CA key format")
} }
caKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes)
var caKey *rsa.PrivateKey
switch keyBlock.Type {
case "RSA PRIVATE KEY":
caKey, err = x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
case "PRIVATE KEY":
parsedKey, err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse CA key: %w", err)
}
var ok bool
caKey, ok = parsedKey.(*rsa.PrivateKey)
if !ok {
return nil, nil, fmt.Errorf("CA key is not RSA")
}
default:
return nil, nil, fmt.Errorf("unsupported CA key type: %s", keyBlock.Type)
}
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to parse CA private key: %w", err) return nil, nil, fmt.Errorf("failed to parse CA private key: %w", err)
} }
// Verify key matches certificate // Verify CA certificate is actually a CA
if caCert.PublicKey.(*rsa.PublicKey).N.Cmp(caKey.N) != 0 { if !caCert.IsCA {
return nil, nil, fmt.Errorf("CA private key does not match CA certificate") return nil, nil, fmt.Errorf("certificate is not a CA certificate")
} }
return caCert, caKey, nil return caCert, caKey, nil
} }
func saveCert(filename string, derBytes []byte) error { func saveCert(filename string, certDER []byte) error {
certOut, err := os.Create(filename) certFile, err := os.Create(filename)
if err != nil { if err != nil {
return fmt.Errorf("failed to create cert file %s: %w", filename, err) return fmt.Errorf("failed to create certificate file: %w", err)
} }
defer certOut.Close() defer certFile.Close()
return pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
if err := pem.Encode(certFile, &pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
}); err != nil {
return fmt.Errorf("failed to write certificate: %w", err)
}
// Set readable permissions
if err := os.Chmod(filename, 0644); err != nil {
return fmt.Errorf("failed to set certificate permissions: %w", err)
}
return nil
} }
func saveKey(filename string, key *rsa.PrivateKey) error { func saveKey(filename string, key *rsa.PrivateKey) error {
keyOut, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) keyFile, err := os.Create(filename)
if err != nil { if err != nil {
return fmt.Errorf("failed to create key file %s: %w", filename, err) return fmt.Errorf("failed to create key file: %w", err)
} }
defer keyOut.Close() defer keyFile.Close()
return pem.Encode(keyOut, &pem.Block{
privKeyDER := x509.MarshalPKCS1PrivateKey(key)
if err := pem.Encode(keyFile, &pem.Block{
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key), Bytes: privKeyDER,
}) }); err != nil {
return fmt.Errorf("failed to write private key: %w", err)
}
// Set restricted permissions for private key
if err := os.Chmod(filename, 0600); err != nil {
return fmt.Errorf("failed to set key permissions: %w", err)
}
return nil
} }

View File

@ -6,7 +6,7 @@
### Notes: ### Notes:
- The tests create configuration files and log files. Debug tests do not clean up these files. - The tests create configuration files and log files. Most tests set logging at debug level and don't clean up their temp files that are created in the current execution directory.
- Some tests may require to be run on different hosts (containers can be used). - Some tests may need to be run on different hosts (containers can be used).

View File

@ -1,5 +1,5 @@
#!/usr/bin/env bash #!/usr/bin/env bash
# FILE: test-logwisp-auth-debug.sh # FILE: test-basic-auth.sh
# Creates test directories and starts network services # Creates test directories and starts network services
set -e set -e
@ -32,7 +32,7 @@ fi
cat > test-auth.toml << EOF cat > test-auth.toml << EOF
# General LogWisp settings # General LogWisp settings
log_dir = "test-logs" log_dir = "test-logs"
log_level = "debug" # CHANGED: Set to debug log_level = "debug"
data_dir = "test-data" data_dir = "test-data"
# Logging configuration for troubleshooting # Logging configuration for troubleshooting
@ -41,7 +41,7 @@ target = "all"
level = "debug" level = "debug"
[logging.console] [logging.console]
enabled = true enabled = true
target = "stdout" # CHANGED: Log to stdout for visibility target = "stdout"
format = "txt" format = "txt"
[[pipelines]] [[pipelines]]
@ -59,7 +59,9 @@ port = 5514
host = "127.0.0.1" host = "127.0.0.1"
[[pipelines.sinks]] [[pipelines.sinks]]
type = "stdout" type = "console"
[pipelines.sinks.options]
target = "stdout"
# Second pipeline for HTTP # Second pipeline for HTTP
[[pipelines]] [[pipelines]]
@ -78,7 +80,10 @@ host = "127.0.0.1"
path = "/ingest" path = "/ingest"
[[pipelines.sinks]] [[pipelines.sinks]]
type = "stdout" # CHANGED: Simplify to stdout for debugging type = "console"
[pipelines.sinks.options]
target = "stdout"
EOF EOF
# Start LogWisp with visible debug output # Start LogWisp with visible debug output