From 490fb777ab13a70b4442916cc73865bc219c62393c15bba25945cc0893617d39 Mon Sep 17 00:00:00 2001 From: Lixen Wraith Date: Thu, 2 Oct 2025 17:16:43 -0400 Subject: [PATCH] v0.6.0 auth restructuring, scram auth added, more tests added --- config/logwisp.toml.defaults | 200 ++++++------ go.mod | 3 +- go.sum | 4 - src/cmd/logwisp/commands.go | 2 +- src/cmd/logwisp/status.go | 12 +- src/internal/auth/authenticator.go | 301 +----------------- src/internal/auth/generator.go | 140 +++++--- src/internal/config/auth.go | 66 ++-- src/internal/config/loader.go | 9 +- src/internal/config/pipeline.go | 6 +- src/internal/filter/chain_test.go | 80 +++++ src/internal/filter/filter.go | 31 ++ src/internal/filter/filter_test.go | 159 +++++++++ src/internal/format/format.go | 2 +- src/internal/format/format_test.go | 65 ++++ src/internal/format/json_test.go | 129 ++++++++ src/internal/format/raw_test.go | 29 ++ src/internal/format/text.go | 2 +- src/internal/format/text_test.go | 81 +++++ src/internal/scram/client.go | 106 ++++++ src/internal/scram/credential.go | 99 ++++++ src/internal/scram/integration.go | 117 +++++++ src/internal/scram/message.go | 101 ++++++ src/internal/scram/scram_test.go | 228 +++++++++++++ src/internal/scram/server.go | 179 +++++++++++ src/internal/service/service.go | 6 +- src/internal/sink/console.go | 250 +++++---------- src/internal/sink/file.go | 17 +- src/internal/sink/http.go | 17 +- src/internal/sink/http_client.go | 100 +++++- src/internal/sink/tcp.go | 2 +- src/internal/sink/tcp_client.go | 211 ++++++++---- src/internal/source/http.go | 29 +- src/internal/source/tcp.go | 203 ++++++++---- src/internal/tls/generator.go | 166 ++++++---- test/README.md | 4 +- ...gwisp-auth-debug.sh => test-basic-auth.sh} | 15 +- 37 files changed, 2283 insertions(+), 888 deletions(-) create mode 100644 src/internal/filter/chain_test.go create mode 100644 src/internal/filter/filter_test.go create mode 100644 src/internal/format/format_test.go create mode 100644 
src/internal/format/json_test.go create mode 100644 src/internal/format/raw_test.go create mode 100644 src/internal/format/text_test.go create mode 100644 src/internal/scram/client.go create mode 100644 src/internal/scram/credential.go create mode 100644 src/internal/scram/integration.go create mode 100644 src/internal/scram/message.go create mode 100644 src/internal/scram/scram_test.go create mode 100644 src/internal/scram/server.go rename test/{test-logwisp-auth-debug.sh => test-basic-auth.sh} (94%) diff --git a/config/logwisp.toml.defaults b/config/logwisp.toml.defaults index 8fde11f..c2e4dbd 100644 --- a/config/logwisp.toml.defaults +++ b/config/logwisp.toml.defaults @@ -3,16 +3,16 @@ ### Configuration Precedence: CLI flags > Environment > File > Defaults ### Default values shown - uncommented lines represent active configuration -### Global settings +### Global Settings background = false # Run as daemon quiet = false # Suppress console output -disable_status_reporter = false # Status logging -config_auto_reload = false # File change detection +disable_status_reporter = false # Disable status logging +config_auto_reload = false # Reload config on file change config_save_on_exit = false # Persist runtime changes ### Logging Configuration [logging] -output = "stdout" # file|stdout|stderr|split|all|none +output = "stdout" # file|stdout|stderr|both|none level = "info" # debug|info|warn|error [logging.file] @@ -20,7 +20,7 @@ directory = "./log" # Log directory path name = "logwisp" # Base filename max_size_mb = 100 # Rotation threshold max_total_size_mb = 1000 # Total size limit -retention_hours = 168.0 # Delete logs older than +retention_hours = 168.0 # Delete logs older than (7 days) [logging.console] target = "stdout" # stdout|stderr|split @@ -30,38 +30,56 @@ format = "txt" # txt|json [[pipelines]] name = "default" # Pipeline identifier -### Directory Sources +### Rate Limiting (Pipeline-level) +# [pipelines.rate_limit] +# rate = 0.0 # Entries per second 
(0=disabled) +# burst = 0.0 # Burst capacity (defaults to rate) +# policy = "pass" # pass|drop +# max_entry_size_bytes = 0 # Max entry size (0=unlimited) + +### Filters +# [[pipelines.filters]] +# type = "include" # include|exclude +# logic = "or" # or|and +# patterns = [".*ERROR.*", ".*WARN.*"] # Regex patterns + +### Sources + +### Directory Source [[pipelines.sources]] type = "directory" + [pipelines.sources.options] path = "./" # Directory to monitor pattern = "*.log" # Glob pattern -check_interval_ms = 100 # Scan interval +check_interval_ms = 100 # Scan interval (min: 10ms) -### Console Sources +### Stdin Source # [[pipelines.sources]] # type = "stdin" + # [pipelines.sources.options] # buffer_size = 1000 # Input buffer size -### HTTP Sources +### HTTP Source # [[pipelines.sources]] # type = "http" # [pipelines.sources.options] # host = "0.0.0.0" # Listen address # port = 8081 # Listen port -# path = "/ingest" # Ingest endpoint -# max_body_size = 1048576 # Max request size +# ingest_path = "/ingest" # Ingest endpoint +# buffer_size = 1000 # Input buffer size +# max_body_size = 1048576 # Max request size bytes # [pipelines.sources.options.tls] # enabled = false # Enable TLS -# cert_file = "" # TLS certificate -# key_file = "" # TLS key +# cert_file = "" # Server certificate +# key_file = "" # Server key # client_auth = false # Require client certs # client_ca_file = "" # Client CA cert # verify_client_cert = false # Verify client certs -# insecure_skip_verify = false # Skip verification +# insecure_skip_verify = false # Skip verification (server-side) # ca_file = "" # Custom CA file # min_version = "TLS1.2" # Min TLS version # max_version = "TLS1.3" # Max TLS version @@ -69,8 +87,8 @@ check_interval_ms = 100 # Scan interval # [pipelines.sources.options.net_limit] # enabled = false # Enable rate limiting -# ip_whitelist = [] # Allowed IPs/CIDRs -# ip_blacklist = [] # Blocked IPs/CIDRs +# ip_whitelist = [] # Allowed IPs/CIDRs (IPv4 only) +# ip_blacklist = [] # 
Blocked IPs/CIDRs (IPv4 only) # requests_per_second = 100.0 # Rate limit per client # burst_size = 100 # Burst capacity # response_code = 429 # HTTP status when limited @@ -78,59 +96,32 @@ check_interval_ms = 100 # Scan interval # max_connections_per_ip = 10 # Max concurrent per IP # max_connections_total = 1000 # Max total connections -### TCP Sources +### TCP Source # [[pipelines.sources]] # type = "tcp" # [pipelines.sources.options] # host = "0.0.0.0" # Listen address # port = 9091 # Listen port +# buffer_size = 1000 # Input buffer size # [pipelines.sources.options.net_limit] # enabled = false # Enable rate limiting -# ip_whitelist = [] # Allowed IPs/CIDRs -# ip_blacklist = [] # Blocked IPs/CIDRs +# ip_whitelist = [] # Allowed IPs/CIDRs (IPv4 only) +# ip_blacklist = [] # Blocked IPs/CIDRs (IPv4 only) # requests_per_second = 100.0 # Rate limit per client # burst_size = 100 # Burst capacity -# response_code = 429 # Response code when limited +# response_code = 429 # TCP rejection # response_message = "Rate limit exceeded" # max_connections_per_ip = 10 # Max concurrent per IP # max_connections_per_user = 10 # Max concurrent per user # max_connections_per_token = 10 # Max concurrent per token # max_connections_total = 1000 # Max total connections -### Rate limiting -# [pipelines.rate_limit] -# rate = 0.0 # Entries/second (0=unlimited) -# burst = 0.0 # Burst capacity -# policy = "drop" # pass|drop -# max_entry_size_bytes = 0 # Entry size limit +### Format Configuration -### Filters -# [[pipelines.filters]] -# type = "include" # include|exclude -# logic = "or" # or|and -# patterns = [] # Regex patterns - -## Examples of filter patterns: -## Include only error or fatal logs containing "database": -## type = "include" -## logic = "and" -## patterns = ["(?i)(error|fatal)", "database"] -## -## Exclude debug logs from test environment: -## type = "exclude" -## logic = "or" -## patterns = ["(?i)debug", "test-env"] -## -## Include only JSON formatted logs: -## type = 
"include" -## patterns = ["^\\{.*\\}$"] - -### Format - -### Raw formatter (default) -# format = "raw" # raw|json|text +### Raw formatter (default - passes through unchanged) +# format = "raw" ### No options for raw formatter ### JSON formatter @@ -143,12 +134,14 @@ check_interval_ms = 100 # Scan interval # source_field = "source" # Source field name ### Text formatter -# format = "text" +# format = "txt" # [pipelines.format_options] # template = "[{{.Timestamp | FmtTime}}] [{{.Level | ToUpper}}] {{.Source}} - {{.Message}}{{ if .Fields }} {{.Fields}}{{ end }}" # timestamp_format = "2006-01-02T15:04:05Z07:00" # Go time format -### HTTP Sinks +### Sinks + +### HTTP Sink (SSE Server) [[pipelines.sinks]] type = "http" @@ -162,14 +155,14 @@ status_path = "/status" # Status endpoint [pipelines.sinks.options.heartbeat] enabled = true # Send heartbeats interval_seconds = 30 # Heartbeat interval -include_timestamp = true # Include time +include_timestamp = true # Include timestamp include_stats = false # Include statistics format = "comment" # comment|message # [pipelines.sinks.options.tls] # enabled = false # Enable TLS -# cert_file = "" # TLS certificate -# key_file = "" # TLS key +# cert_file = "" # Server certificate +# key_file = "" # Server key # client_auth = false # Require client certs # client_ca_file = "" # Client CA cert # verify_client_cert = false # Verify client certs @@ -181,8 +174,8 @@ format = "comment" # comment|message # [pipelines.sinks.options.net_limit] # enabled = false # Enable rate limiting -# ip_whitelist = [] # Allowed IPs/CIDRs -# ip_blacklist = [] # Blocked IPs/CIDRs +# ip_whitelist = [] # Allowed IPs/CIDRs (IPv4 only) +# ip_blacklist = [] # Blocked IPs/CIDRs (IPv4 only) # requests_per_second = 100.0 # Rate limit per client # burst_size = 100 # Burst capacity # response_code = 429 # HTTP status when limited @@ -190,7 +183,7 @@ format = "comment" # comment|message # max_connections_per_ip = 10 # Max concurrent per IP # max_connections_total = 
1000 # Max total connections -### TCP Sinks +### TCP Sink (TCP Server) # [[pipelines.sinks]] # type = "tcp" @@ -198,28 +191,33 @@ format = "comment" # comment|message # host = "0.0.0.0" # Listen address # port = 9090 # Server port # buffer_size = 1000 # Buffer size +# auth_type = "none" # none|scram # [pipelines.sinks.options.heartbeat] # enabled = false # Send heartbeats # interval_seconds = 30 # Heartbeat interval -# include_timestamp = false # Include time +# include_timestamp = false # Include timestamp # include_stats = false # Include statistics # format = "comment" # comment|message # [pipelines.sinks.options.net_limit] # enabled = false # Enable rate limiting -# ip_whitelist = [] # Allowed IPs/CIDRs -# ip_blacklist = [] # Blocked IPs/CIDRs +# ip_whitelist = [] # Allowed IPs/CIDRs (IPv4 only) +# ip_blacklist = [] # Blocked IPs/CIDRs (IPv4 only) # requests_per_second = 100.0 # Rate limit per client # burst_size = 100 # Burst capacity -# response_code = 429 # HTTP status when limited +# response_code = 429 # TCP rejection code # response_message = "Rate limit exceeded" # max_connections_per_ip = 10 # Max concurrent per IP # max_connections_per_user = 10 # Max concurrent per user # max_connections_per_token = 10 # Max concurrent per token # max_connections_total = 1000 # Max total connections -### HTTP Client Sinks +# [pipelines.sinks.options.scram] +# username = "" # SCRAM auth username +# password = "" # SCRAM auth password + +### HTTP Client Sink (Forward to remote HTTP endpoint) # [[pipelines.sinks]] # type = "http_client" @@ -231,16 +229,31 @@ format = "comment" # comment|message # timeout_seconds = 30 # Request timeout # max_retries = 3 # Retry attempts # retry_delay_ms = 1000 # Initial retry delay -# retry_backoff = 2.0 # Exponential backoff +# retry_backoff = 2.0 # Exponential backoff multiplier # insecure_skip_verify = false # Skip TLS verification -# ca_file = "" # Custom CA certificate -# headers = {} # Custom HTTP headers +# auth_type = "none" # 
none|basic|bearer|mtls + +# [pipelines.sinks.options.basic] +# username = "" # Basic auth username +# password_hash = "" # Argon2 password hash + +# [pipelines.sinks.options.bearer] +# token = "" # Bearer token + +## Not needed by default: +## Custom HTTP headers +# [pipelines.sinks.options.headers] +# Content-Type = "application/json" +# Authorization = "Bearer token" + +## Client certificate for mTLS # [pipelines.sinks.options.tls] +# ca_file = "" # Custom CA certificate # cert_file = "" # Client certificate # key_file = "" # Client key -### TCP Client Sinks +### TCP Client Sink (Forward to remote TCP endpoint) # [[pipelines.sinks]] # type = "tcp_client" @@ -253,23 +266,21 @@ format = "comment" # comment|message # keep_alive_seconds = 30 # TCP keepalive # reconnect_delay_ms = 1000 # Initial reconnect delay # max_reconnect_delay_seconds = 30 # Max reconnect delay -# reconnect_backoff = 1.5 # Exponential backoff +# reconnect_backoff = 1.5 # Exponential backoff multiplier -# [pipelines.sinks.options.tls] -# enabled = false # Enable TLS -# insecure_skip_verify = false # Skip verification -# ca_file = "" # Custom CA certificate -# cert_file = "" # Client certificate -# key_file = "" # Client key -### File Sink +# [pipelines.sinks.options.scram] +# username = "" # Auth username +# password_hash = "" # Argon2 password hash + +### File Sink # [[pipelines.sinks]] # type = "file" # [pipelines.sinks.options] -# directory = "" # Output dir (required) -# name = "" # Base name (required) -# buffer_size = 1000 # Input channel buffer size +# directory = "./" # Output dir +# name = "logwisp.output" # Base name +# buffer_size = 1000 # Input channel buffer # max_size_mb = 100 # Rotation size # max_total_size_mb = 0 # Total limit (0=unlimited) # retention_hours = 0.0 # Retention (0=disabled) @@ -277,39 +288,32 @@ format = "comment" # comment|message ### Console Sinks # [[pipelines.sinks]] -# type = "stdout" +# type = "console" # [pipelines.sinks.options] +# target = "stdout" # 
stdout|stderr|split # buffer_size = 1000 # Buffer size -# target = "stdout" # Override for split mode -# [[pipelines.sinks]] -# type = "stderr" - -# [pipelines.sinks.options] -# buffer_size = 1000 # Buffer size -# target = "stderr" # Override for split mode - -### Authentication +### Authentication Configuration # [pipelines.auth] # type = "none" # none|basic|bearer|mtls -### Basic authentication +### Basic Authentication # [pipelines.auth.basic_auth] # realm = "LogWisp" # WWW-Authenticate realm -# users_file = "" # External users file +# users_file = "" # External users file path # [[pipelines.auth.basic_auth.users]] # username = "" # Username -# password_hash = "" # bcrypt hash +# password_hash = "" # Argon2 password hash -### Bearer authentication +### Bearer Token Authentication # [pipelines.auth.bearer_auth] # tokens = [] # Static bearer tokens -### JWT validation +### JWT Validation # [pipelines.auth.bearer_auth.jwt] -# jwks_url = "" # JWKS endpoint -# signing_key = "" # Static signing key -# issuer = "" # Expected issuer -# audience = "" # Expected audience \ No newline at end of file +# jwks_url = "" # JWKS endpoint for key discovery +# signing_key = "" # Static signing key (if not using JWKS) +# issuer = "" # Expected issuer claim +# audience = "" # Expected audience claim \ No newline at end of file diff --git a/go.mod b/go.mod index 5353093..25600e0 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,10 @@ module logwisp go 1.25.1 require ( - github.com/golang-jwt/jwt/v5 v5.3.0 github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2 github.com/panjf2000/gnet/v2 v2.9.4 + github.com/stretchr/testify v1.10.0 github.com/valyala/fasthttp v1.66.0 golang.org/x/crypto v0.42.0 golang.org/x/term v0.35.0 @@ -20,6 +20,7 @@ require ( github.com/klauspost/compress v1.18.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/panjf2000/ants/v2 v2.11.3 // indirect + 
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect diff --git a/go.sum b/go.sum index fce8071..d7220f2 100644 --- a/go.sum +++ b/go.sum @@ -6,14 +6,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGLxofobw= github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw= -github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= -github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0= github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0= -github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e h1:/XWCqFdSOiUf0/a5a63GHsvEdpglsYfn3qieNxTeyDc= -github.com/lixenwraith/log v0.0.0-20250929084748-210374d95b3e/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0= github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2 h1:9Qf+BR83sKjok2E1Nct+3Sfzoj2dLGwC/zyQDVNmmqs= github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0= github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= diff --git a/src/cmd/logwisp/commands.go b/src/cmd/logwisp/commands.go index 56ccd24..d0898e4 100644 --- a/src/cmd/logwisp/commands.go +++ b/src/cmd/logwisp/commands.go @@ -105,7 +105,7 @@ func (c *helpCommand) Description() string { 
type authCommand struct{} func (c *authCommand) Execute(args []string) error { - gen := auth.NewGeneratorCommand() + gen := auth.NewAuthGeneratorCommand() return gen.Execute(args) } diff --git a/src/cmd/logwisp/status.go b/src/cmd/logwisp/status.go index 9cde4bf..dddd136 100644 --- a/src/cmd/logwisp/status.go +++ b/src/cmd/logwisp/status.go @@ -183,11 +183,13 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) { "name", name) } - case "stdout", "stderr": - logger.Info("msg", "Console sink configured", - "pipeline", cfg.Name, - "sink_index", i, - "type", sinkCfg.Type) + case "console": + if target, ok := sinkCfg.Options["target"].(string); ok { + logger.Info("msg", "Console sink configured", + "pipeline", cfg.Name, + "sink_index", i, + "target", target) + } } } diff --git a/src/internal/auth/authenticator.go b/src/internal/auth/authenticator.go index 159792a..1815296 100644 --- a/src/internal/auth/authenticator.go +++ b/src/internal/auth/authenticator.go @@ -2,22 +2,17 @@ package auth import ( - "bufio" "crypto/rand" - "crypto/subtle" "encoding/base64" "fmt" "net" - "os" "strings" "sync" "time" "logwisp/src/internal/config" - "github.com/golang-jwt/jwt/v5" "github.com/lixenwraith/log" - "golang.org/x/crypto/argon2" "golang.org/x/time/rate" ) @@ -28,10 +23,7 @@ const maxAuthTrackedIPs = 10000 type Authenticator struct { config *config.AuthConfig logger *log.Logger - basicUsers map[string]string // username -> password hash - bearerTokens map[string]bool // token -> valid - jwtParser *jwt.Parser - jwtKeyFunc jwt.Keyfunc + bearerTokens map[string]bool // token -> valid mu sync.RWMutex // Session tracking @@ -55,69 +47,32 @@ type ipAuthState struct { type Session struct { ID string Username string - Method string // basic, bearer, jwt, mtls + Method string // basic, bearer, mtls RemoteAddr string CreatedAt time.Time LastActivity time.Time - Metadata map[string]any } // Creates a new authenticator from config -func New(cfg *config.AuthConfig, logger *log.Logger) 
(*Authenticator, error) { - if cfg == nil || cfg.Type == "none" { +func NewAuthenticator(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) { + // SCRAM is handled by ScramManager in sources + if cfg == nil || cfg.Type == "none" || cfg.Type == "scram" { return nil, nil } a := &Authenticator{ config: cfg, logger: logger, - basicUsers: make(map[string]string), bearerTokens: make(map[string]bool), sessions: make(map[string]*Session), ipAuthAttempts: make(map[string]*ipAuthState), } - // Initialize Basic Auth users - if cfg.Type == "basic" && cfg.BasicAuth != nil { - for _, user := range cfg.BasicAuth.Users { - a.basicUsers[user.Username] = user.PasswordHash - } - - // Load users from file if specified - if cfg.BasicAuth.UsersFile != "" { - if err := a.loadUsersFile(cfg.BasicAuth.UsersFile); err != nil { - return nil, fmt.Errorf("failed to load users file: %w", err) - } - } - } - // Initialize Bearer tokens if cfg.Type == "bearer" && cfg.BearerAuth != nil { for _, token := range cfg.BearerAuth.Tokens { a.bearerTokens[token] = true } - - // Setup JWT validation if configured - if cfg.BearerAuth.JWT != nil { - a.jwtParser = jwt.NewParser( - jwt.WithValidMethods([]string{"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}), - jwt.WithLeeway(5*time.Second), - jwt.WithExpirationRequired(), - ) - - // Setup key function - if cfg.BearerAuth.JWT.SigningKey != "" { - // Static key - key := []byte(cfg.BearerAuth.JWT.SigningKey) - a.jwtKeyFunc = func(token *jwt.Token) (any, error) { - return key, nil - } - } else if cfg.BearerAuth.JWT.JWKSURL != "" { - // JWKS support would require additional implementation - // ☢ SECURITY: JWKS rotation not implemented - tokens won't refresh keys - return nil, fmt.Errorf("JWKS support not yet implemented") - } - } } // Start session cleanup @@ -276,8 +231,6 @@ func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Sessio var err error switch a.config.Type { - case "basic": - 
session, err = a.authenticateBasic(authHeader, remoteAddr) case "bearer": session, err = a.authenticateBearer(authHeader, remoteAddr) default: @@ -322,24 +275,6 @@ func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string) session, err = a.validateToken(credentials, remoteAddr) } - case "basic": - if a.config.Type != "basic" { - err = fmt.Errorf("basic auth not configured") - } else { - // Expect base64(username:password) - decoded, decErr := base64.StdEncoding.DecodeString(credentials) - if decErr != nil { - err = fmt.Errorf("invalid credentials encoding") - } else { - parts := strings.SplitN(string(decoded), ":", 2) - if len(parts) != 2 { - err = fmt.Errorf("invalid credentials format") - } else { - session, err = a.validateBasicAuth(parts[0], parts[1], remoteAddr) - } - } - } - default: err = fmt.Errorf("unsupported auth method: %s", method) } @@ -355,91 +290,6 @@ func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string) return session, nil } -func (a *Authenticator) authenticateBasic(authHeader, remoteAddr string) (*Session, error) { - if !strings.HasPrefix(authHeader, "Basic ") { - return nil, fmt.Errorf("invalid basic auth header") - } - - payload, err := base64.StdEncoding.DecodeString(authHeader[6:]) - if err != nil { - return nil, fmt.Errorf("invalid base64 encoding") - } - - parts := strings.SplitN(string(payload), ":", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid credentials format") - } - - return a.validateBasicAuth(parts[0], parts[1], remoteAddr) -} - -func (a *Authenticator) validateBasicAuth(username, password, remoteAddr string) (*Session, error) { - a.mu.RLock() - expectedHash, exists := a.basicUsers[username] - a.mu.RUnlock() - - if !exists { - // Perform argon2 anyway to prevent timing attacks - dummySalt := make([]byte, 16) - argon2.IDKey([]byte(password), dummySalt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen) - return nil, fmt.Errorf("invalid credentials") - } - - // Parse 
PHC format hash - if !verifyArgon2idHash(password, expectedHash) { - return nil, fmt.Errorf("invalid credentials") - } - - session := &Session{ - ID: generateSessionID(), - Username: username, - Method: "basic", - RemoteAddr: remoteAddr, - CreatedAt: time.Now(), - LastActivity: time.Now(), - } - - a.storeSession(session) - return session, nil -} - -// Verify Argon2id hashes -func verifyArgon2idHash(password, hash string) bool { - // Parse PHC format: $argon2id$v=19$m=65536,t=3,p=4$salt$hash - parts := strings.Split(hash, "$") - if len(parts) != 6 || parts[1] != "argon2id" { - return false - } - - // Parse version - var version int - fmt.Sscanf(parts[2], "v=%d", &version) - if version != argon2.Version { - return false - } - - // Parse parameters - var memory, time, threads uint32 - fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads) - - // Decode salt and hash - salt, err := base64.RawStdEncoding.DecodeString(parts[4]) - if err != nil { - return false - } - - expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) - if err != nil { - return false - } - - // Compute hash - computedHash := argon2.IDKey([]byte(password), salt, time, memory, uint8(threads), uint32(len(expectedHash))) - - // Constant time comparison - return subtle.ConstantTimeCompare(computedHash, expectedHash) == 1 -} - func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) { if !strings.HasPrefix(authHeader, "Bearer ") { return nil, fmt.Errorf("invalid bearer auth header") @@ -452,97 +302,22 @@ func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Sess func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error) { // Check static tokens first a.mu.RLock() - isStatic := a.bearerTokens[token] + isValid := a.bearerTokens[token] a.mu.RUnlock() - if isStatic { - session := &Session{ - ID: generateSessionID(), - Method: "bearer", - RemoteAddr: remoteAddr, - CreatedAt: time.Now(), - LastActivity: 
time.Now(), - Metadata: map[string]any{"token_type": "static"}, - } - a.storeSession(session) - return session, nil + if !isValid { + return nil, fmt.Errorf("invalid token") } - // Try JWT validation if configured - if a.jwtParser != nil && a.jwtKeyFunc != nil { - claims := jwt.MapClaims{} - parsedToken, err := a.jwtParser.ParseWithClaims(token, claims, a.jwtKeyFunc) - if err != nil { - return nil, fmt.Errorf("JWT validation failed: %w", err) - } - - if !parsedToken.Valid { - return nil, fmt.Errorf("invalid JWT token") - } - - // Explicit expiration check - if exp, ok := claims["exp"].(float64); ok { - if time.Now().Unix() > int64(exp) { - return nil, fmt.Errorf("token expired") - } - } else { - // Reject tokens without expiration - return nil, fmt.Errorf("token missing expiration claim") - } - - // Check not-before claim - if nbf, ok := claims["nbf"].(float64); ok { - if time.Now().Unix() < int64(nbf) { - return nil, fmt.Errorf("token not yet valid") - } - } - - // Check issuer if configured - if a.config.BearerAuth.JWT.Issuer != "" { - if iss, ok := claims["iss"].(string); !ok || iss != a.config.BearerAuth.JWT.Issuer { - return nil, fmt.Errorf("invalid token issuer") - } - } - - // Check audience if configured - if a.config.BearerAuth.JWT.Audience != "" { - // Handle both string and []string audience formats - audValid := false - switch aud := claims["aud"].(type) { - case string: - audValid = aud == a.config.BearerAuth.JWT.Audience - case []any: - for _, aa := range aud { - if audStr, ok := aa.(string); ok && audStr == a.config.BearerAuth.JWT.Audience { - audValid = true - break - } - } - } - if !audValid { - return nil, fmt.Errorf("invalid token audience") - } - } - - username := "" - if sub, ok := claims["sub"].(string); ok { - username = sub - } - - session := &Session{ - ID: generateSessionID(), - Username: username, - Method: "jwt", - RemoteAddr: remoteAddr, - CreatedAt: time.Now(), - LastActivity: time.Now(), - Metadata: map[string]any{"claims": claims}, - 
} - a.storeSession(session) - return session, nil + session := &Session{ + ID: generateSessionID(), + Method: "bearer", + RemoteAddr: remoteAddr, + CreatedAt: time.Now(), + LastActivity: time.Now(), } - - return nil, fmt.Errorf("invalid token") + a.storeSession(session) + return session, nil } func (a *Authenticator) storeSession(session *Session) { @@ -598,49 +373,6 @@ func (a *Authenticator) authAttemptCleanup() { } } -func (a *Authenticator) loadUsersFile(path string) error { - file, err := os.Open(path) - if err != nil { - return fmt.Errorf("could not open users file: %w", err) - } - defer file.Close() - - scanner := bufio.NewScanner(file) - lineNumber := 0 - for scanner.Scan() { - lineNumber++ - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue // Skip empty lines and comments - } - - parts := strings.SplitN(line, ":", 2) - if len(parts) != 2 { - a.logger.Warn("msg", "Skipping malformed line in users file", - "component", "auth", - "path", path, - "line_number", lineNumber) - continue - } - username, hash := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) - if username != "" && hash != "" { - // File-based users can overwrite inline users if names conflict - a.basicUsers[username] = hash - } - } - - if err := scanner.Err(); err != nil { - return fmt.Errorf("error reading users file: %w", err) - } - - a.logger.Info("msg", "Loaded users from file", - "component", "auth", - "path", path, - "user_count", len(a.basicUsers)) - - return nil -} - func generateSessionID() string { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { @@ -686,7 +418,6 @@ func (a *Authenticator) GetStats() map[string]any { "enabled": true, "type": a.config.Type, "active_sessions": sessionCount, - "basic_users": len(a.basicUsers), "static_tokens": len(a.bearerTokens), } } \ No newline at end of file diff --git a/src/internal/auth/generator.go b/src/internal/auth/generator.go index 5db70a2..2daf21a 100644 --- 
a/src/internal/auth/generator.go +++ b/src/internal/auth/generator.go @@ -10,6 +10,8 @@ import ( "os" "syscall" + "logwisp/src/internal/scram" + "golang.org/x/crypto/argon2" "golang.org/x/term" ) @@ -23,40 +25,45 @@ const ( argon2KeyLen = 32 ) -type GeneratorCommand struct { +type AuthGeneratorCommand struct { output io.Writer errOut io.Writer } -func NewGeneratorCommand() *GeneratorCommand { - return &GeneratorCommand{ +func NewAuthGeneratorCommand() *AuthGeneratorCommand { + return &AuthGeneratorCommand{ output: os.Stdout, errOut: os.Stderr, } } -func (g *GeneratorCommand) Execute(args []string) error { +func (ag *AuthGeneratorCommand) Execute(args []string) error { cmd := flag.NewFlagSet("auth", flag.ContinueOnError) - cmd.SetOutput(g.errOut) + cmd.SetOutput(ag.errOut) var ( - username = cmd.String("u", "", "Username for basic auth") - password = cmd.String("p", "", "Password to hash (will prompt if not provided)") + username = cmd.String("u", "", "Username") + password = cmd.String("p", "", "Password (will prompt if not provided)") + authType = cmd.String("type", "basic", "Auth type: basic (HTTP) or scram (TCP)") genToken = cmd.Bool("t", false, "Generate random bearer token") - tokenLen = cmd.Int("l", 32, "Token length in bytes") + tokenLen = cmd.Int("l", 32, "Token length in bytes (min 16, max 512)") ) cmd.Usage = func() { - fmt.Fprintln(g.errOut, "Generate authentication credentials for LogWisp") - fmt.Fprintln(g.errOut, "\nUsage: logwisp auth [options]") - fmt.Fprintln(g.errOut, "\nExamples:") - fmt.Fprintln(g.errOut, " # Generate Argon2id hash for user") - fmt.Fprintln(g.errOut, " logwisp auth -u admin") - fmt.Fprintln(g.errOut, " ") - fmt.Fprintln(g.errOut, " # Generate 64-byte bearer token") - fmt.Fprintln(g.errOut, " logwisp auth -t -l 64") - fmt.Fprintln(g.errOut, "\nOptions:") + fmt.Fprintln(ag.errOut, "Generate authentication credentials for LogWisp") + fmt.Fprintln(ag.errOut, "\nUsage: logwisp auth [options]") + fmt.Fprintln(ag.errOut, "\nExamples:") 
+ fmt.Fprintln(ag.errOut, " # Generate basic auth hash for HTTP sources/sinks") + fmt.Fprintln(ag.errOut, " logwisp auth -u admin -type basic") + fmt.Fprintln(ag.errOut, " ") + fmt.Fprintln(ag.errOut, " # Generate SCRAM credentials for TCP sources/sinks") + fmt.Fprintln(ag.errOut, " logwisp auth -u admin -type scram") + fmt.Fprintln(ag.errOut, " ") + fmt.Fprintln(ag.errOut, " # Generate 64-byte bearer token") + fmt.Fprintln(ag.errOut, " logwisp auth -t -l 64") + fmt.Fprintln(ag.errOut, "\nOptions:") cmd.PrintDefaults() + fmt.Fprintln(ag.errOut) } if err := cmd.Parse(args); err != nil { @@ -64,22 +71,29 @@ func (g *GeneratorCommand) Execute(args []string) error { } if *genToken { - return g.generateToken(*tokenLen) + return ag.generateToken(*tokenLen) } if *username == "" { cmd.Usage() - return fmt.Errorf("username required for password hash generation") + return fmt.Errorf("username required for credential generation") } - return g.generatePasswordHash(*username, *password) + switch *authType { + case "basic": + return ag.generateBasicAuth(*username, *password) + case "scram": + return ag.generateScramAuth(*username, *password) + default: + return fmt.Errorf("invalid auth type: %s (use 'basic' or 'scram')", *authType) + } } -func (g *GeneratorCommand) generatePasswordHash(username, password string) error { +func (ag *AuthGeneratorCommand) generateBasicAuth(username, password string) error { // Get password if not provided if password == "" { - pass1 := g.promptPassword("Enter password: ") - pass2 := g.promptPassword("Confirm password: ") + pass1 := ag.promptPassword("Enter password: ") + pass2 := ag.promptPassword("Confirm password: ") if pass1 != pass2 { return fmt.Errorf("passwords don't match") } @@ -102,20 +116,61 @@ func (g *GeneratorCommand) generatePasswordHash(username, password string) error argon2.Version, argon2Memory, argon2Time, argon2Threads, saltB64, hashB64) // Output configuration snippets - fmt.Fprintln(g.output, "\n# TOML Configuration (add to 
logwisp.toml):") - fmt.Fprintln(g.output, "[[pipelines.auth.basic_auth.users]]") - fmt.Fprintf(g.output, "username = %q\n", username) - fmt.Fprintf(g.output, "password_hash = %q\n\n", phcHash) - - fmt.Fprintln(g.output, "# Users File Format (for external auth file):") - fmt.Fprintf(g.output, "%s:%s\n", username, phcHash) + fmt.Fprintln(ag.output, "\n# Basic Auth Configuration (HTTP sources/sinks)") + fmt.Fprintln(ag.output, "# REQUIRES HTTPS/TLS for security") + fmt.Fprintln(ag.output, "# Add to logwisp.toml under [[pipelines]]:") + fmt.Fprintln(ag.output, "") + fmt.Fprintln(ag.output, "[pipelines.auth]") + fmt.Fprintln(ag.output, `type = "basic"`) + fmt.Fprintln(ag.output, "") + fmt.Fprintln(ag.output, "[[pipelines.auth.basic_auth.users]]") + fmt.Fprintf(ag.output, "username = %q\n", username) + fmt.Fprintf(ag.output, "password_hash = %q\n\n", phcHash) return nil } -func (g *GeneratorCommand) generateToken(length int) error { +func (ag *AuthGeneratorCommand) generateScramAuth(username, password string) error { + // Get password if not provided + if password == "" { + pass1 := ag.promptPassword("Enter password: ") + pass2 := ag.promptPassword("Confirm password: ") + if pass1 != pass2 { + return fmt.Errorf("passwords don't match") + } + password = pass1 + } + + // Generate salt + salt := make([]byte, 16) + if _, err := rand.Read(salt); err != nil { + return fmt.Errorf("failed to generate salt: %w", err) + } + + // Derive SCRAM credential + cred, err := scram.DeriveCredential(username, password, salt, 3, 65536, 4) + if err != nil { + return fmt.Errorf("failed to derive SCRAM credential: %w", err) + } + + // Output SCRAM configuration + fmt.Fprintln(ag.output, "\n# SCRAM Auth Configuration (for TCP sources/sinks)") + fmt.Fprintln(ag.output, "# Add to logwisp.toml:") + fmt.Fprintln(ag.output, "[[pipelines.auth.scram_auth.users]]") + fmt.Fprintf(ag.output, "username = %q\n", username) + fmt.Fprintf(ag.output, "stored_key = %q\n", 
base64.StdEncoding.EncodeToString(cred.StoredKey)) + fmt.Fprintf(ag.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey)) + fmt.Fprintf(ag.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt)) + fmt.Fprintf(ag.output, "argon_time = %d\n", cred.ArgonTime) + fmt.Fprintf(ag.output, "argon_memory = %d\n", cred.ArgonMemory) + fmt.Fprintf(ag.output, "argon_threads = %d\n\n", cred.ArgonThreads) + + return nil +} + +func (ag *AuthGeneratorCommand) generateToken(length int) error { if length < 16 { - fmt.Fprintln(g.errOut, "Warning: tokens < 16 bytes are cryptographically weak") + fmt.Fprintln(ag.errOut, "Warning: tokens < 16 bytes are cryptographically weak") } if length > 512 { return fmt.Errorf("token length exceeds maximum (512 bytes)") @@ -129,22 +184,23 @@ func (g *GeneratorCommand) generateToken(length int) error { b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token) hex := fmt.Sprintf("%x", token) - fmt.Fprintln(g.output, "\n# TOML Configuration (add to logwisp.toml):") - fmt.Fprintf(g.output, "tokens = [%q]\n\n", b64) + fmt.Fprintln(ag.output, "\n# Bearer Token Configuration") + fmt.Fprintln(ag.output, "# Add to logwisp.toml:") + fmt.Fprintf(ag.output, "tokens = [%q]\n\n", b64) - fmt.Fprintln(g.output, "# Generated Token:") - fmt.Fprintf(g.output, "Base64: %s\n", b64) - fmt.Fprintf(g.output, "Hex: %s\n", hex) + fmt.Fprintln(ag.output, "# Generated Token:") + fmt.Fprintf(ag.output, "Base64: %s\n", b64) + fmt.Fprintf(ag.output, "Hex: %s\n", hex) return nil } -func (g *GeneratorCommand) promptPassword(prompt string) string { - fmt.Fprint(g.errOut, prompt) - password, err := term.ReadPassword(int(syscall.Stdin)) - fmt.Fprintln(g.errOut) +func (ag *AuthGeneratorCommand) promptPassword(prompt string) string { + fmt.Fprint(ag.errOut, prompt) + password, err := term.ReadPassword(syscall.Stdin) + fmt.Fprintln(ag.errOut) if err != nil { - fmt.Fprintf(g.errOut, "Failed to read password: %v\n", err) + 
fmt.Fprintf(ag.errOut, "Failed to read password: %v\n", err) os.Exit(1) } return string(password) diff --git a/src/internal/config/auth.go b/src/internal/config/auth.go index 2c4acd6..8b22cbc 100644 --- a/src/internal/config/auth.go +++ b/src/internal/config/auth.go @@ -6,61 +6,61 @@ import ( ) type AuthConfig struct { - // Authentication type: "none", "basic", "bearer", "mtls" + // Authentication type: "none", "basic", "scram", "bearer", "mtls" Type string `toml:"type"` - // Basic auth - BasicAuth *BasicAuthConfig `toml:"basic_auth"` - - // Bearer token auth + BasicAuth *BasicAuthConfig `toml:"basic_auth"` + ScramAuth *ScramAuthConfig `toml:"scram_auth"` BearerAuth *BearerAuthConfig `toml:"bearer_auth"` } type BasicAuthConfig struct { - // Static users (for simple deployments) Users []BasicAuthUser `toml:"users"` - - // External auth file - UsersFile string `toml:"users_file"` - - // Realm for WWW-Authenticate header - Realm string `toml:"realm"` + Realm string `toml:"realm"` } type BasicAuthUser struct { - Username string `toml:"username"` - // Password hash (Argon2id) - PasswordHash string `toml:"password_hash"` + Username string `toml:"username"` + PasswordHash string `toml:"password_hash"` // Argon2 +} + +type ScramAuthConfig struct { + Users []ScramUser `toml:"users"` +} + +type ScramUser struct { + Username string `toml:"username"` + StoredKey string `toml:"stored_key"` // base64 + ServerKey string `toml:"server_key"` // base64 + Salt string `toml:"salt"` // base64 + ArgonTime uint32 `toml:"argon_time"` + ArgonMemory uint32 `toml:"argon_memory"` + ArgonThreads uint8 `toml:"argon_threads"` } type BearerAuthConfig struct { // Static tokens Tokens []string `toml:"tokens"` - // JWT validation - JWT *JWTConfig `toml:"jwt"` + // TODO: Maybe future development + // // JWT validation + // JWT *JWTConfig `toml:"jwt"` } -type JWTConfig struct { - // JWKS URL for key discovery - JWKSURL string `toml:"jwks_url"` - - // Static signing key (if not using JWKS) - SigningKey 
string `toml:"signing_key"` - - // Expected issuer - Issuer string `toml:"issuer"` - - // Expected audience - Audience string `toml:"audience"` -} +// TODO: Maybe future development +// type JWTConfig struct { +// JWKSURL string `toml:"jwks_url"` +// SigningKey string `toml:"signing_key"` +// Issuer string `toml:"issuer"` +// Audience string `toml:"audience"` +// } func validateAuth(pipelineName string, auth *AuthConfig) error { if auth == nil { return nil } - validTypes := map[string]bool{"none": true, "basic": true, "bearer": true, "mtls": true} + validTypes := map[string]bool{"none": true, "basic": true, "scram": true, "bearer": true, "mtls": true} if !validTypes[auth.Type] { return fmt.Errorf("pipeline '%s': invalid auth type: %s", pipelineName, auth.Type) } @@ -69,6 +69,10 @@ func validateAuth(pipelineName string, auth *AuthConfig) error { return fmt.Errorf("pipeline '%s': basic auth type specified but config missing", pipelineName) } + if auth.Type == "scram" && auth.ScramAuth == nil { + return fmt.Errorf("pipeline '%s': scram auth type specified but config missing", pipelineName) + } + if auth.Type == "bearer" && auth.BearerAuth == nil { return fmt.Errorf("pipeline '%s': bearer auth type specified but config missing", pipelineName) } diff --git a/src/internal/config/loader.go b/src/internal/config/loader.go index d4922cd..e1b580b 100644 --- a/src/internal/config/loader.go +++ b/src/internal/config/loader.go @@ -180,10 +180,10 @@ func applyConsoleTargetOverrides(cfg *Config) error { return fmt.Errorf("invalid LOGWISP_CONSOLE_TARGET value: %s", consoleTarget) } - // Apply to all console sinks + // Apply to console sinks for i, pipeline := range cfg.Pipelines { for j, sink := range pipeline.Sinks { - if sink.Type == "stdout" || sink.Type == "stderr" { + if sink.Type == "console" { if sink.Options == nil { cfg.Pipelines[i].Sinks[j].Options = make(map[string]any) } @@ -193,10 +193,5 @@ func applyConsoleTargetOverrides(cfg *Config) error { } } - // Also update 
logging console target if applicable - if cfg.Logging.Console != nil && consoleTarget == "split" { - cfg.Logging.Console.Target = "split" - } - return nil } \ No newline at end of file diff --git a/src/internal/config/pipeline.go b/src/internal/config/pipeline.go index e084101..bcc2b16 100644 --- a/src/internal/config/pipeline.go +++ b/src/internal/config/pipeline.go @@ -36,7 +36,7 @@ type PipelineConfig struct { // Represents an input data source type SourceConfig struct { - // Source type: "directory", "stdin", "tcp", "http" + // Source type Type string `toml:"type"` // Type-specific configuration options @@ -45,7 +45,7 @@ type SourceConfig struct { // Represents an output destination type SinkConfig struct { - // Sink type: "http", "tcp", "file", "stdout", "stderr" + // Sink type Type string `toml:"type"` // Type-specific configuration options @@ -404,7 +404,7 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts } } - case "stdout", "stderr": + case "console": // No specific validation needed for console sinks default: diff --git a/src/internal/filter/chain_test.go b/src/internal/filter/chain_test.go new file mode 100644 index 0000000..91a67cd --- /dev/null +++ b/src/internal/filter/chain_test.go @@ -0,0 +1,80 @@ +// FILE: logwisp/src/internal/filter/chain_test.go +package filter + +import ( + "testing" + + "logwisp/src/internal/config" + "logwisp/src/internal/core" + + "github.com/stretchr/testify/assert" +) + +func TestNewChain(t *testing.T) { + logger := newTestLogger() + + t.Run("Success", func(t *testing.T) { + configs := []config.FilterConfig{ + {Type: config.FilterTypeInclude, Patterns: []string{"apple"}}, + {Type: config.FilterTypeExclude, Patterns: []string{"banana"}}, + } + chain, err := NewChain(configs, logger) + assert.NoError(t, err) + assert.NotNil(t, chain) + assert.Len(t, chain.filters, 2) + }) + + t.Run("ErrorInvalidRegexInChain", func(t *testing.T) { + configs := []config.FilterConfig{ + {Patterns: 
[]string{"apple"}}, + {Patterns: []string{"["}}, + } + chain, err := NewChain(configs, logger) + assert.Error(t, err) + assert.Nil(t, chain) + assert.Contains(t, err.Error(), "filter[1]") + }) +} + +func TestChain_Apply(t *testing.T) { + logger := newTestLogger() + entry := core.LogEntry{Message: "an apple a day"} + + t.Run("EmptyChain", func(t *testing.T) { + chain, err := NewChain([]config.FilterConfig{}, logger) + assert.NoError(t, err) + assert.True(t, chain.Apply(entry)) + }) + + t.Run("AllFiltersPass", func(t *testing.T) { + configs := []config.FilterConfig{ + {Type: config.FilterTypeInclude, Patterns: []string{"apple"}}, + {Type: config.FilterTypeInclude, Patterns: []string{"day"}}, + {Type: config.FilterTypeExclude, Patterns: []string{"banana"}}, + } + chain, err := NewChain(configs, logger) + assert.NoError(t, err) + assert.True(t, chain.Apply(entry)) + }) + + t.Run("OneFilterFails", func(t *testing.T) { + configs := []config.FilterConfig{ + {Type: config.FilterTypeInclude, Patterns: []string{"apple"}}, + {Type: config.FilterTypeExclude, Patterns: []string{"day"}}, // This one will fail + {Type: config.FilterTypeInclude, Patterns: []string{"a"}}, + } + chain, err := NewChain(configs, logger) + assert.NoError(t, err) + assert.False(t, chain.Apply(entry)) + }) + + t.Run("FirstFilterFails", func(t *testing.T) { + configs := []config.FilterConfig{ + {Type: config.FilterTypeInclude, Patterns: []string{"banana"}}, // This one will fail + {Type: config.FilterTypeInclude, Patterns: []string{"apple"}}, + } + chain, err := NewChain(configs, logger) + assert.NoError(t, err) + assert.False(t, chain.Apply(entry)) + }) +} \ No newline at end of file diff --git a/src/internal/filter/filter.go b/src/internal/filter/filter.go index a5bfc7c..6f1045c 100644 --- a/src/internal/filter/filter.go +++ b/src/internal/filter/filter.go @@ -66,6 +66,9 @@ func (f *Filter) Apply(entry core.LogEntry) bool { // No patterns means pass everything if len(f.patterns) == 0 { + 
f.logger.Debug("msg", "No patterns configured, passing entry", + "component", "filter", + "type", f.config.Type) return true } @@ -78,10 +81,32 @@ func (f *Filter) Apply(entry core.LogEntry) bool { text = entry.Source + " " + text } + f.logger.Debug("msg", "Filter checking entry", + "component", "filter", + "type", f.config.Type, + "logic", f.config.Logic, + "entry_level", entry.Level, + "entry_source", entry.Source, + "entry_message", entry.Message[:min(100, len(entry.Message))], // First 100 chars + "text_to_match", text[:min(150, len(text))], // First 150 chars + "patterns", f.config.Patterns) + + for i, pattern := range f.config.Patterns { + isMatch := f.patterns[i].MatchString(text) + f.logger.Debug("msg", "Pattern match result", + "component", "filter", + "pattern_index", i, + "pattern", pattern, + "matched", isMatch) + } + matched := f.matches(text) if matched { f.totalMatched.Add(1) } + f.logger.Debug("msg", "Filter final match result", + "component", "filter", + "matched", matched) // Determine if we should pass or drop shouldPass := false @@ -92,6 +117,12 @@ func (f *Filter) Apply(entry core.LogEntry) bool { shouldPass = !matched } + f.logger.Debug("msg", "Filter decision", + "component", "filter", + "type", f.config.Type, + "matched", matched, + "should_pass", shouldPass) + if !shouldPass { f.totalDropped.Add(1) } diff --git a/src/internal/filter/filter_test.go b/src/internal/filter/filter_test.go new file mode 100644 index 0000000..b561b15 --- /dev/null +++ b/src/internal/filter/filter_test.go @@ -0,0 +1,159 @@ +// FILE: logwisp/src/internal/filter/filter_test.go +package filter + +import ( + "logwisp/src/internal/config" + "logwisp/src/internal/core" + "testing" + + "github.com/lixenwraith/log" + "github.com/stretchr/testify/assert" +) + +func newTestLogger() *log.Logger { + return log.NewLogger() +} + +func TestNewFilter(t *testing.T) { + logger := newTestLogger() + + t.Run("SuccessWithDefaults", func(t *testing.T) { + cfg := 
config.FilterConfig{Patterns: []string{"test"}} + f, err := NewFilter(cfg, logger) + assert.NoError(t, err) + assert.NotNil(t, f) + assert.Equal(t, config.FilterTypeInclude, f.config.Type) + assert.Equal(t, config.FilterLogicOr, f.config.Logic) + }) + + t.Run("SuccessWithCustomConfig", func(t *testing.T) { + cfg := config.FilterConfig{ + Type: config.FilterTypeExclude, + Logic: config.FilterLogicAnd, + Patterns: []string{"test", "pattern"}, + } + f, err := NewFilter(cfg, logger) + assert.NoError(t, err) + assert.NotNil(t, f) + assert.Equal(t, config.FilterTypeExclude, f.config.Type) + assert.Equal(t, config.FilterLogicAnd, f.config.Logic) + assert.Len(t, f.patterns, 2) + }) + + t.Run("ErrorInvalidRegex", func(t *testing.T) { + cfg := config.FilterConfig{Patterns: []string{"["}} + f, err := NewFilter(cfg, logger) + assert.Error(t, err) + assert.Nil(t, f) + assert.Contains(t, err.Error(), "invalid regex pattern") + }) +} + +func TestFilter_Apply(t *testing.T) { + logger := newTestLogger() + + testCases := []struct { + name string + cfg config.FilterConfig + entry core.LogEntry + expected bool + }{ + // Include OR logic + { + name: "IncludeOR_MatchOne", + cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicOr, Patterns: []string{"apple", "banana"}}, + entry: core.LogEntry{Message: "this is an apple"}, + expected: true, + }, + { + name: "IncludeOR_NoMatch", + cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicOr, Patterns: []string{"apple", "banana"}}, + entry: core.LogEntry{Message: "this is a pear"}, + expected: false, + }, + // Include AND logic + { + name: "IncludeAND_MatchAll", + cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicAnd, Patterns: []string{"apple", "doctor"}}, + entry: core.LogEntry{Message: "an apple keeps the doctor away"}, + expected: true, + }, + { + name: "IncludeAND_MatchOne", + cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: 
config.FilterLogicAnd, Patterns: []string{"apple", "doctor"}}, + entry: core.LogEntry{Message: "this is an apple"}, + expected: false, + }, + // Exclude OR logic + { + name: "ExcludeOR_MatchOne", + cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicOr, Patterns: []string{"error", "fatal"}}, + entry: core.LogEntry{Message: "this is an error"}, + expected: false, + }, + { + name: "ExcludeOR_NoMatch", + cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicOr, Patterns: []string{"error", "fatal"}}, + entry: core.LogEntry{Message: "this is a warning"}, + expected: true, + }, + // Exclude AND logic + { + name: "ExcludeAND_MatchAll", + cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicAnd, Patterns: []string{"critical", "database"}}, + entry: core.LogEntry{Message: "critical error in database"}, + expected: false, + }, + { + name: "ExcludeAND_MatchOne", + cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicAnd, Patterns: []string{"critical", "database"}}, + entry: core.LogEntry{Message: "critical error in app"}, + expected: true, + }, + // Edge Cases + { + name: "NoPatterns", + cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{}}, + entry: core.LogEntry{Message: "any message"}, + expected: true, + }, + { + name: "EmptyEntry_NoPatterns", + cfg: config.FilterConfig{Patterns: []string{}}, + entry: core.LogEntry{}, + expected: true, + }, + { + name: "EmptyEntry_DoesNotMatchSpace", + cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{" "}}, + entry: core.LogEntry{Level: "", Source: "", Message: ""}, + expected: false, // CORRECTED: An empty entry results in an empty string, which doesn't match a space. 
+ }, + { + name: "MatchOnLevel", + cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"ERROR"}}, + entry: core.LogEntry{Level: "ERROR", Message: "A message"}, + expected: true, + }, + { + name: "MatchOnSource", + cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"database"}}, + entry: core.LogEntry{Source: "database", Message: "A message"}, + expected: true, + }, + { + name: "MatchOnCombinedFields", + cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"^app ERROR"}}, + entry: core.LogEntry{Source: "app", Level: "ERROR", Message: "A message"}, + expected: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + f, err := NewFilter(tc.cfg, logger) + assert.NoError(t, err) + assert.Equal(t, tc.expected, f.Apply(tc.entry)) + }) + } +} \ No newline at end of file diff --git a/src/internal/format/format.go b/src/internal/format/format.go index a909b6e..d62967f 100644 --- a/src/internal/format/format.go +++ b/src/internal/format/format.go @@ -28,7 +28,7 @@ func NewFormatter(name string, options map[string]any, logger *log.Logger) (Form switch name { case "json": return NewJSONFormatter(options, logger) - case "text": + case "txt": return NewTextFormatter(options, logger) case "raw": return NewRawFormatter(options, logger) diff --git a/src/internal/format/format_test.go b/src/internal/format/format_test.go new file mode 100644 index 0000000..2d803f0 --- /dev/null +++ b/src/internal/format/format_test.go @@ -0,0 +1,65 @@ +// FILE: logwisp/src/internal/format/format_test.go +package format + +import ( + "testing" + + "github.com/lixenwraith/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestLogger() *log.Logger { + return log.NewLogger() +} + +func TestNewFormatter(t *testing.T) { + logger := newTestLogger() + + testCases := []struct { + name string + formatName string + expected string + expectError bool + }{ + { + 
name: "JSONFormatter", + formatName: "json", + expected: "json", + }, + { + name: "TextFormatter", + formatName: "txt", + expected: "txt", + }, + { + name: "RawFormatter", + formatName: "raw", + expected: "raw", + }, + { + name: "DefaultToRaw", + formatName: "", + expected: "raw", + }, + { + name: "UnknownFormatter", + formatName: "xml", + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + formatter, err := NewFormatter(tc.formatName, nil, logger) + if tc.expectError { + assert.Error(t, err) + assert.Nil(t, formatter) + } else { + require.NoError(t, err) + require.NotNil(t, formatter) + assert.Equal(t, tc.expected, formatter.Name()) + } + }) + } +} \ No newline at end of file diff --git a/src/internal/format/json_test.go b/src/internal/format/json_test.go new file mode 100644 index 0000000..0e448b2 --- /dev/null +++ b/src/internal/format/json_test.go @@ -0,0 +1,129 @@ +// FILE: logwisp/src/internal/format/json_test.go +package format + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "logwisp/src/internal/core" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONFormatter_Format(t *testing.T) { + logger := newTestLogger() + testTime := time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC) + entry := core.LogEntry{ + Time: testTime, + Source: "test-app", + Level: "INFO", + Message: "this is a test", + } + + t.Run("BasicFormatting", func(t *testing.T) { + formatter, err := NewJSONFormatter(nil, logger) + require.NoError(t, err) + + output, err := formatter.Format(entry) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(output, &result) + require.NoError(t, err, "Output should be valid JSON") + + assert.Equal(t, testTime.Format(time.RFC3339Nano), result["timestamp"]) + assert.Equal(t, "INFO", result["level"]) + assert.Equal(t, "test-app", result["source"]) + assert.Equal(t, "this is a test", result["message"]) + assert.True(t, 
strings.HasSuffix(string(output), "\n"), "Output should end with a newline") + }) + + t.Run("PrettyFormatting", func(t *testing.T) { + formatter, err := NewJSONFormatter(map[string]any{"pretty": true}, logger) + require.NoError(t, err) + + output, err := formatter.Format(entry) + require.NoError(t, err) + + assert.Contains(t, string(output), ` "level": "INFO"`) + assert.True(t, strings.HasSuffix(string(output), "\n")) + }) + + t.Run("MessageIsJSON", func(t *testing.T) { + jsonMessageEntry := entry + jsonMessageEntry.Message = `{"user":"test","request_id":"abc-123"}` + formatter, err := NewJSONFormatter(nil, logger) + require.NoError(t, err) + + output, err := formatter.Format(jsonMessageEntry) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(output, &result) + require.NoError(t, err) + + assert.Equal(t, "test", result["user"]) + assert.Equal(t, "abc-123", result["request_id"]) + _, messageExists := result["message"] + assert.False(t, messageExists, "message field should not exist when message is merged JSON") + }) + + t.Run("MessageIsJSONWithConflicts", func(t *testing.T) { + jsonMessageEntry := entry + jsonMessageEntry.Level = "INFO" // top-level + jsonMessageEntry.Message = `{"level":"DEBUG","msg":"hello"}` + formatter, err := NewJSONFormatter(nil, logger) + require.NoError(t, err) + + output, err := formatter.Format(jsonMessageEntry) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(output, &result) + require.NoError(t, err) + + assert.Equal(t, "INFO", result["level"], "Top-level LogEntry field should take precedence") + }) + + t.Run("CustomFieldNames", func(t *testing.T) { + options := map[string]any{"timestamp_field": "@timestamp"} + formatter, err := NewJSONFormatter(options, logger) + require.NoError(t, err) + + output, err := formatter.Format(entry) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(output, &result) + require.NoError(t, err) + + _, 
defaultExists := result["timestamp"] + assert.False(t, defaultExists) + assert.Equal(t, testTime.Format(time.RFC3339Nano), result["@timestamp"]) + }) +} + +func TestJSONFormatter_FormatBatch(t *testing.T) { + logger := newTestLogger() + formatter, err := NewJSONFormatter(nil, logger) + require.NoError(t, err) + + entries := []core.LogEntry{ + {Time: time.Now(), Level: "INFO", Message: "First message"}, + {Time: time.Now(), Level: "WARN", Message: "Second message"}, + } + + output, err := formatter.FormatBatch(entries) + require.NoError(t, err) + + var result []map[string]interface{} + err = json.Unmarshal(output, &result) + require.NoError(t, err, "Batch output should be a valid JSON array") + require.Len(t, result, 2) + + assert.Equal(t, "First message", result[0]["message"]) + assert.Equal(t, "WARN", result[1]["level"]) +} \ No newline at end of file diff --git a/src/internal/format/raw_test.go b/src/internal/format/raw_test.go new file mode 100644 index 0000000..84c8b98 --- /dev/null +++ b/src/internal/format/raw_test.go @@ -0,0 +1,29 @@ +// FILE: logwisp/src/internal/format/raw_test.go +package format + +import ( + "testing" + "time" + + "logwisp/src/internal/core" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRawFormatter_Format(t *testing.T) { + logger := newTestLogger() + formatter, err := NewRawFormatter(nil, logger) + require.NoError(t, err) + + entry := core.LogEntry{ + Time: time.Now(), + Message: "This is a raw log line.", + } + + output, err := formatter.Format(entry) + require.NoError(t, err) + + expected := "This is a raw log line.\n" + assert.Equal(t, expected, string(output)) +} \ No newline at end of file diff --git a/src/internal/format/text.go b/src/internal/format/text.go index 8604c0a..a7c6d4e 100644 --- a/src/internal/format/text.go +++ b/src/internal/format/text.go @@ -104,5 +104,5 @@ func (f *TextFormatter) Format(entry core.LogEntry) ([]byte, error) { // Returns the formatter name func (f 
*TextFormatter) Name() string { - return "text" + return "txt" } \ No newline at end of file diff --git a/src/internal/format/text_test.go b/src/internal/format/text_test.go new file mode 100644 index 0000000..464441f --- /dev/null +++ b/src/internal/format/text_test.go @@ -0,0 +1,81 @@ +// FILE: logwisp/src/internal/format/text_test.go +package format + +import ( + "fmt" + "strings" + "testing" + "time" + + "logwisp/src/internal/core" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewTextFormatter(t *testing.T) { + logger := newTestLogger() + t.Run("InvalidTemplate", func(t *testing.T) { + options := map[string]any{"template": "{{ .Timestamp | InvalidFunc }}"} + _, err := NewTextFormatter(options, logger) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid template") + }) +} + +func TestTextFormatter_Format(t *testing.T) { + logger := newTestLogger() + testTime := time.Date(2023, 10, 27, 10, 30, 0, 0, time.UTC) + entry := core.LogEntry{ + Time: testTime, + Source: "api", + Level: "WARN", + Message: "rate limit exceeded", + } + + t.Run("DefaultTemplate", func(t *testing.T) { + formatter, err := NewTextFormatter(nil, logger) + require.NoError(t, err) + + output, err := formatter.Format(entry) + require.NoError(t, err) + + expected := fmt.Sprintf("[%s] [WARN] api - rate limit exceeded\n", testTime.Format(time.RFC3339)) + assert.Equal(t, expected, string(output)) + }) + + t.Run("CustomTemplate", func(t *testing.T) { + options := map[string]any{"template": "{{.Level}}:{{.Source}}:{{.Message}}"} + formatter, err := NewTextFormatter(options, logger) + require.NoError(t, err) + + output, err := formatter.Format(entry) + require.NoError(t, err) + + expected := "WARN:api:rate limit exceeded\n" + assert.Equal(t, expected, string(output)) + }) + + t.Run("CustomTimestampFormat", func(t *testing.T) { + options := map[string]any{"timestamp_format": "2006-01-02"} + formatter, err := NewTextFormatter(options, logger) + 
require.NoError(t, err) + + output, err := formatter.Format(entry) + require.NoError(t, err) + + assert.True(t, strings.HasPrefix(string(output), "[2023-10-27]")) + }) + + t.Run("EmptyLevelDefaultsToInfo", func(t *testing.T) { + emptyLevelEntry := entry + emptyLevelEntry.Level = "" + formatter, err := NewTextFormatter(nil, logger) + require.NoError(t, err) + + output, err := formatter.Format(emptyLevelEntry) + require.NoError(t, err) + + assert.Contains(t, string(output), "[INFO]") + }) +} \ No newline at end of file diff --git a/src/internal/scram/client.go b/src/internal/scram/client.go new file mode 100644 index 0000000..8e04459 --- /dev/null +++ b/src/internal/scram/client.go @@ -0,0 +1,106 @@ +// FILE: src/internal/scram/client.go +package scram + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "fmt" + + "golang.org/x/crypto/argon2" +) + +// Client handles SCRAM client-side authentication +type Client struct { + Username string + Password string + + // Handshake state + clientNonce string + serverFirst *ServerFirst + authMessage string + serverKey []byte +} + +// NewClient creates SCRAM client +func NewClient(username, password string) *Client { + return &Client{ + Username: username, + Password: password, + } +} + +// StartAuthentication generates ClientFirst message +func (c *Client) StartAuthentication() (*ClientFirst, error) { + // Generate client nonce + nonce := make([]byte, 32) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %w", err) + } + c.clientNonce = base64.StdEncoding.EncodeToString(nonce) + + return &ClientFirst{ + Username: c.Username, + ClientNonce: c.clientNonce, + }, nil +} + +// ProcessServerFirst handles server challenge +func (c *Client) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) { + c.serverFirst = msg + + // Decode salt + salt, err := base64.StdEncoding.DecodeString(msg.Salt) + if err != nil { + return nil, fmt.Errorf("invalid salt 
encoding: %w", err) + } + + // Derive keys using Argon2id + saltedPassword := argon2.IDKey([]byte(c.Password), salt, + msg.ArgonTime, msg.ArgonMemory, msg.ArgonThreads, 32) + + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + storedKey := sha256.Sum256(clientKey) + + // Build auth message + clientFirstBare := fmt.Sprintf("u=%s,n=%s", c.Username, c.clientNonce) + clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce) + c.authMessage = clientFirstBare + "," + msg.Marshal() + "," + clientFinalBare + + // Compute client proof + clientSignature := computeHMAC(storedKey[:], []byte(c.authMessage)) + clientProof := xorBytes(clientKey, clientSignature) + + // Store server key for verification + c.serverKey = serverKey + + return &ClientFinal{ + FullNonce: msg.FullNonce, + ClientProof: base64.StdEncoding.EncodeToString(clientProof), + }, nil +} + +// VerifyServerFinal validates server signature +func (c *Client) VerifyServerFinal(msg *ServerFinal) error { + if c.authMessage == "" || c.serverKey == nil { + return fmt.Errorf("invalid handshake state") + } + + // Compute expected server signature + expectedSig := computeHMAC(c.serverKey, []byte(c.authMessage)) + + // Decode received signature + receivedSig, err := base64.StdEncoding.DecodeString(msg.ServerSignature) + if err != nil { + return fmt.Errorf("invalid signature encoding: %w", err) + } + + // ☢ SECURITY: Constant-time comparison + if subtle.ConstantTimeCompare(expectedSig, receivedSig) != 1 { + return fmt.Errorf("server authentication failed") + } + + return nil +} \ No newline at end of file diff --git a/src/internal/scram/credential.go b/src/internal/scram/credential.go new file mode 100644 index 0000000..ac77650 --- /dev/null +++ b/src/internal/scram/credential.go @@ -0,0 +1,99 @@ +// FILE: src/internal/scram/credential.go +package scram + +import ( + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "fmt" + 
"strings" + + "golang.org/x/crypto/argon2" +) + +// Credential stores SCRAM authentication data +type Credential struct { + Username string + Salt []byte // 16+ bytes + ArgonTime uint32 // e.g., 3 + ArgonMemory uint32 // e.g., 64*1024 KiB + ArgonThreads uint8 // e.g., 4 + StoredKey []byte // SHA256(ClientKey) + ServerKey []byte // For server auth + PHCHash string +} + +// DeriveCredential creates SCRAM credential from password +func DeriveCredential(username, password string, salt []byte, time, memory uint32, threads uint8) (*Credential, error) { + if len(salt) < 16 { + return nil, fmt.Errorf("salt must be at least 16 bytes") + } + + // Derive salted password using Argon2id + saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, 32) + + // Derive keys + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + storedKey := sha256.Sum256(clientKey) + + return &Credential{ + Username: username, + Salt: salt, + ArgonTime: time, + ArgonMemory: memory, + ArgonThreads: threads, + StoredKey: storedKey[:], + ServerKey: serverKey, + }, nil +} + +// MigrateFromPHC converts existing Argon2 PHC hash to SCRAM credential +func MigrateFromPHC(username, password, phcHash string) (*Credential, error) { + // Parse PHC: $argon2id$v=19$m=65536,t=3,p=4$salt$hash + parts := strings.Split(phcHash, "$") + if len(parts) != 6 || parts[1] != "argon2id" { + return nil, fmt.Errorf("invalid PHC format") + } + + var memory, time uint32 + var threads uint8 + fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads) + + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return nil, fmt.Errorf("invalid salt encoding: %w", err) + } + + expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return nil, fmt.Errorf("invalid hash encoding: %w", err) + } + + // Verify password matches + computedHash := argon2.IDKey([]byte(password), salt, time, 
memory, threads, uint32(len(expectedHash))) + if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 { + return nil, fmt.Errorf("password verification failed") + } + + // Now derive SCRAM credential + return DeriveCredential(username, password, salt, time, memory, threads) +} + +func computeHMAC(key, message []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(message) + return mac.Sum(nil) +} + +func xorBytes(a, b []byte) []byte { + if len(a) != len(b) { + panic("xor length mismatch") + } + result := make([]byte, len(a)) + for i := range a { + result[i] = a[i] ^ b[i] + } + return result +} \ No newline at end of file diff --git a/src/internal/scram/integration.go b/src/internal/scram/integration.go new file mode 100644 index 0000000..be190d9 --- /dev/null +++ b/src/internal/scram/integration.go @@ -0,0 +1,117 @@ +// FILE: src/internal/scram/integration.go +package scram + +import ( + "crypto/rand" + "fmt" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// ScramManager provides high-level SCRAM operations with rate limiting +type ScramManager struct { + server *Server + sessions map[string]*SessionInfo + limiter map[string]*rate.Limiter + mu sync.RWMutex +} + +// SessionInfo tracks authenticated sessions +type SessionInfo struct { + Username string + RemoteAddr string + SessionID string + CreatedAt time.Time + LastActivity time.Time + Method string // "scram-sha-256" +} + +// NewScramManager creates SCRAM manager +func NewScramManager() *ScramManager { + m := &ScramManager{ + server: NewServer(), + sessions: make(map[string]*SessionInfo), + limiter: make(map[string]*rate.Limiter), + } + + // Start cleanup goroutine + go m.cleanupLoop() + return m +} + +// RegisterUser creates new user credential +func (sm *ScramManager) RegisterUser(username, password string) error { + salt := make([]byte, 16) + if _, err := rand.Read(salt); err != nil { + return fmt.Errorf("salt generation failed: %w", err) + } + + cred, err := DeriveCredential(username, 
password, salt, + sm.server.DefaultTime, sm.server.DefaultMemory, sm.server.DefaultThreads) + if err != nil { + return err + } + + sm.server.AddCredential(cred) + return nil +} + +// GetRateLimiter returns per-IP rate limiter +func (sm *ScramManager) GetRateLimiter(remoteAddr string) *rate.Limiter { + sm.mu.Lock() + defer sm.mu.Unlock() + + if limiter, exists := sm.limiter[remoteAddr]; exists { + return limiter + } + + // 10 attempts per minute, burst of 3 + limiter := rate.NewLimiter(rate.Every(6*time.Second), 3) + sm.limiter[remoteAddr] = limiter + + // Prevent unbounded growth + if len(sm.limiter) > 10000 { + // Remove oldest entries + for addr := range sm.limiter { + delete(sm.limiter, addr) + if len(sm.limiter) < 8000 { + break + } + } + } + + return limiter +} + +func (sm *ScramManager) cleanupLoop() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + sm.mu.Lock() + cutoff := time.Now().Add(-30 * time.Minute) + for sid, session := range sm.sessions { + if session.LastActivity.Before(cutoff) { + delete(sm.sessions, sid) + } + } + sm.mu.Unlock() + } +} + +// HandleClientFirst wraps server's HandleClientFirst +func (sm *ScramManager) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) { + return sm.server.HandleClientFirst(msg) +} + +// HandleClientFinal wraps server's HandleClientFinal +func (sm *ScramManager) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) { + return sm.server.HandleClientFinal(msg) +} + +// AddCredential wraps server's AddCredential +func (sm *ScramManager) AddCredential(cred *Credential) { + sm.server.AddCredential(cred) +} \ No newline at end of file diff --git a/src/internal/scram/message.go b/src/internal/scram/message.go new file mode 100644 index 0000000..6b97a5b --- /dev/null +++ b/src/internal/scram/message.go @@ -0,0 +1,101 @@ +// FILE: src/internal/scram/message.go +package scram + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" +) + +// ClientFirst 
initiates authentication +type ClientFirst struct { + Username string `json:"u"` + ClientNonce string `json:"n"` +} + +// ServerFirst contains server challenge +type ServerFirst struct { + FullNonce string `json:"r"` // client_nonce + server_nonce + Salt string `json:"s"` // base64 + ArgonTime uint32 `json:"t"` + ArgonMemory uint32 `json:"m"` + ArgonThreads uint8 `json:"p"` +} + +// ClientFinal contains client proof +type ClientFinal struct { + FullNonce string `json:"r"` + ClientProof string `json:"p"` // base64 +} + +// ServerFinal contains server signature for mutual auth +type ServerFinal struct { + ServerSignature string `json:"v"` // base64 + SessionID string `json:"sid,omitempty"` +} + +// Marshal/Unmarshal helpers for TCP protocol (line-based) +func (cf *ClientFirst) Marshal() string { + return fmt.Sprintf("u=%s,n=%s", cf.Username, cf.ClientNonce) +} + +func ParseClientFirst(data string) (*ClientFirst, error) { + parts := strings.Split(data, ",") + msg := &ClientFirst{} + for _, part := range parts { + kv := strings.SplitN(part, "=", 2) + if len(kv) != 2 { + continue + } + switch kv[0] { + case "u": + msg.Username = kv[1] + case "n": + msg.ClientNonce = kv[1] + } + } + if msg.Username == "" || msg.ClientNonce == "" { + return nil, fmt.Errorf("missing required fields") + } + return msg, nil +} + +func (sf *ServerFirst) Marshal() string { + return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d", + sf.FullNonce, sf.Salt, sf.ArgonTime, sf.ArgonMemory, sf.ArgonThreads) +} + +func ParseServerFirst(data string) (*ServerFirst, error) { + parts := strings.Split(data, ",") + msg := &ServerFirst{} + for _, part := range parts { + kv := strings.SplitN(part, "=", 2) + if len(kv) != 2 { + continue + } + switch kv[0] { + case "r": + msg.FullNonce = kv[1] + case "s": + msg.Salt = kv[1] + case "t": + fmt.Sscanf(kv[1], "%d", &msg.ArgonTime) + case "m": + fmt.Sscanf(kv[1], "%d", &msg.ArgonMemory) + case "p": + fmt.Sscanf(kv[1], "%d", &msg.ArgonThreads) + } + } + return msg, nil +} + 
+// JSON variants for HTTP +func (cf *ClientFirst) MarshalJSON() ([]byte, error) { + return json.Marshal(*cf) +} + +func (sf *ServerFirst) MarshalJSON() ([]byte, error) { + sf.Salt = base64.StdEncoding.EncodeToString([]byte(sf.Salt)) + return json.Marshal(*sf) +} \ No newline at end of file diff --git a/src/internal/scram/scram_test.go b/src/internal/scram/scram_test.go new file mode 100644 index 0000000..0b05fda --- /dev/null +++ b/src/internal/scram/scram_test.go @@ -0,0 +1,228 @@ +// FILE: src/internal/scram/scram_test.go +package scram + +import ( + "crypto/rand" + "encoding/base64" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/argon2" +) + +func TestCredentialDerivation(t *testing.T) { + salt := make([]byte, 16) + _, err := rand.Read(salt) + require.NoError(t, err) + + cred, err := DeriveCredential("testuser", "testpass123", salt, 3, 64*1024, 4) + require.NoError(t, err) + + assert.Equal(t, "testuser", cred.Username) + assert.Equal(t, uint32(3), cred.ArgonTime) + assert.Equal(t, uint32(64*1024), cred.ArgonMemory) + assert.Equal(t, uint8(4), cred.ArgonThreads) + assert.Len(t, cred.StoredKey, 32) + assert.Len(t, cred.ServerKey, 32) +} + +func TestFullHandshake(t *testing.T) { + // Setup server with credential + server := NewServer() + salt := make([]byte, 16) + _, err := rand.Read(salt) + require.NoError(t, err) + + cred, err := DeriveCredential("alice", "secret123", salt, 3, 64*1024, 4) + require.NoError(t, err) + server.AddCredential(cred) + + // Setup client + client := NewClient("alice", "secret123") + + // Step 1: Client starts + clientFirst, err := client.StartAuthentication() + require.NoError(t, err) + assert.Equal(t, "alice", clientFirst.Username) + assert.NotEmpty(t, clientFirst.ClientNonce) + + // Step 2: Server responds + serverFirst, err := server.HandleClientFirst(clientFirst) + require.NoError(t, err) + assert.Contains(t, serverFirst.FullNonce, clientFirst.ClientNonce) 
+ assert.Equal(t, base64.StdEncoding.EncodeToString(salt), serverFirst.Salt) + + // Step 3: Client proves + clientFinal, err := client.ProcessServerFirst(serverFirst) + require.NoError(t, err) + assert.Equal(t, serverFirst.FullNonce, clientFinal.FullNonce) + assert.NotEmpty(t, clientFinal.ClientProof) + + // Step 4: Server verifies and signs + serverFinal, err := server.HandleClientFinal(clientFinal) + require.NoError(t, err) + assert.NotEmpty(t, serverFinal.ServerSignature) + assert.NotEmpty(t, serverFinal.SessionID) + + // Step 5: Client verifies server + err = client.VerifyServerFinal(serverFinal) + assert.NoError(t, err) +} + +func TestInvalidPassword(t *testing.T) { + server := NewServer() + salt := make([]byte, 16) + rand.Read(salt) + + cred, _ := DeriveCredential("bob", "correct", salt, 3, 64*1024, 4) + server.AddCredential(cred) + + // Client with wrong password + client := NewClient("bob", "wrong") + + clientFirst, _ := client.StartAuthentication() + serverFirst, _ := server.HandleClientFirst(clientFirst) + clientFinal, _ := client.ProcessServerFirst(serverFirst) + + clientFinal, err := client.ProcessServerFirst(serverFirst) + require.NoError(t, err) // Check error to prevent panic + + // Server should reject + _, err = server.HandleClientFinal(clientFinal) + assert.Error(t, err) + assert.Contains(t, err.Error(), "authentication failed") +} + +func TestHandshakeTimeout(t *testing.T) { + server := NewServer() + salt := make([]byte, 16) + rand.Read(salt) + + cred, _ := DeriveCredential("charlie", "pass", salt, 3, 64*1024, 4) + server.AddCredential(cred) + + client := NewClient("charlie", "pass") + clientFirst, _ := client.StartAuthentication() + serverFirst, _ := server.HandleClientFirst(clientFirst) + + // Manipulate handshake timestamp + server.mu.Lock() + if state, exists := server.handshakes[serverFirst.FullNonce]; exists { + state.CreatedAt = time.Now().Add(-61 * time.Second) + } + server.mu.Unlock() + + clientFinal, _ := 
client.ProcessServerFirst(serverFirst) + _, err := server.HandleClientFinal(clientFinal) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "timeout") +} + +func TestMessageParsing(t *testing.T) { + // Test ClientFirst parsing + cf, err := ParseClientFirst("u=alice,n=abc123") + require.NoError(t, err) + assert.Equal(t, "alice", cf.Username) + assert.Equal(t, "abc123", cf.ClientNonce) + + // Test marshal round-trip + marshaled := cf.Marshal() + parsed, err := ParseClientFirst(marshaled) + require.NoError(t, err) + assert.Equal(t, cf.Username, parsed.Username) + assert.Equal(t, cf.ClientNonce, parsed.ClientNonce) +} + +func TestPHCMigration(t *testing.T) { + // Create PHC hash using known values + password := "testpass" + salt := []byte("saltsaltsaltsalt") + saltB64 := base64.RawStdEncoding.EncodeToString(salt) + + // This would be generated by the existing auth system + phcHash := "$argon2id$v=19$m=65536,t=3,p=4$" + saltB64 + "$" + "dGVzdGhhc2g" // dummy hash + + // For testing, we need a valid hash - generate it + hash := argon2.IDKey([]byte(password), salt, 3, 65536, 4, 32) + hashB64 := base64.RawStdEncoding.EncodeToString(hash) + phcHash = "$argon2id$v=19$m=65536,t=3,p=4$" + saltB64 + "$" + hashB64 + + cred, err := MigrateFromPHC("alice", password, phcHash) + require.NoError(t, err) + assert.Equal(t, "alice", cred.Username) + assert.Equal(t, salt, cred.Salt) +} + +func TestRateLimiting(t *testing.T) { + manager := NewScramManager() + + limiter := manager.GetRateLimiter("192.168.1.1") + + // Should allow burst + for i := 0; i < 3; i++ { + assert.True(t, limiter.Allow()) + } + + // 4th should be rate limited + assert.False(t, limiter.Allow()) +} + +func TestNonceUniqueness(t *testing.T) { + nonces := make(map[string]bool) + + for i := 0; i < 1000; i++ { + nonce := generateNonce() + assert.False(t, nonces[nonce], "Duplicate nonce generated") + nonces[nonce] = true + } +} + +func TestConstantTimeComparison(t *testing.T) { + server := NewServer() + salt := 
make([]byte, 16) + rand.Read(salt) + + // Add real user + cred, _ := DeriveCredential("realuser", "realpass", salt, 3, 64*1024, 4) + server.AddCredential(cred) + + // Test timing independence for invalid user + fakeClient := NewClient("fakeuser", "fakepass") + clientFirst, _ := fakeClient.StartAuthentication() + + // Should still return ServerFirst (no user enumeration) + serverFirst, err := server.HandleClientFirst(clientFirst) + assert.NotNil(t, serverFirst) + assert.Error(t, err) // Internal error, not exposed to client +} + +func BenchmarkArgon2Derivation(b *testing.B) { + salt := make([]byte, 16) + rand.Read(salt) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + DeriveCredential("user", "password", salt, 3, 64*1024, 4) + } +} + +func BenchmarkFullHandshake(b *testing.B) { + server := NewServer() + salt := make([]byte, 16) + rand.Read(salt) + cred, _ := DeriveCredential("bench", "pass", salt, 3, 64*1024, 4) + server.AddCredential(cred) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + client := NewClient("bench", "pass") + clientFirst, _ := client.StartAuthentication() + serverFirst, _ := server.HandleClientFirst(clientFirst) + clientFinal, _ := client.ProcessServerFirst(serverFirst) + serverFinal, _ := server.HandleClientFinal(clientFinal) + client.VerifyServerFinal(serverFinal) + } +} \ No newline at end of file diff --git a/src/internal/scram/server.go b/src/internal/scram/server.go new file mode 100644 index 0000000..ef61a59 --- /dev/null +++ b/src/internal/scram/server.go @@ -0,0 +1,179 @@ +// FILE: src/internal/scram/server.go +package scram + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "fmt" + "sync" + "time" +) + +// Server handles SCRAM authentication +type Server struct { + credentials map[string]*Credential + handshakes map[string]*HandshakeState + mu sync.RWMutex + + // Default Argon2 params for new registrations + DefaultTime uint32 + DefaultMemory uint32 + DefaultThreads uint8 +} + +// HandshakeState tracks 
ongoing authentication +type HandshakeState struct { + Username string + ClientNonce string + ServerNonce string + FullNonce string + AuthMessage string + Credential *Credential + CreatedAt time.Time + ClientProof []byte +} + +// NewServer creates SCRAM server +func NewServer() *Server { + return &Server{ + credentials: make(map[string]*Credential), + handshakes: make(map[string]*HandshakeState), + DefaultTime: 3, + DefaultMemory: 64 * 1024, + DefaultThreads: 4, + } +} + +// AddCredential registers user credential +func (s *Server) AddCredential(cred *Credential) { + s.mu.Lock() + defer s.mu.Unlock() + s.credentials[cred.Username] = cred +} + +// HandleClientFirst processes initial auth request +func (s *Server) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Check if user exists + cred, exists := s.credentials[msg.Username] + if !exists { + // Prevent user enumeration - still generate response + salt := make([]byte, 16) + rand.Read(salt) + serverNonce := generateNonce() + + return &ServerFirst{ + FullNonce: msg.ClientNonce + serverNonce, + Salt: base64.StdEncoding.EncodeToString(salt), + ArgonTime: s.DefaultTime, + ArgonMemory: s.DefaultMemory, + ArgonThreads: s.DefaultThreads, + }, fmt.Errorf("invalid credentials") + } + + // Generate server nonce + serverNonce := generateNonce() + fullNonce := msg.ClientNonce + serverNonce + + // Store handshake state + state := &HandshakeState{ + Username: msg.Username, + ClientNonce: msg.ClientNonce, + ServerNonce: serverNonce, + FullNonce: fullNonce, + Credential: cred, + CreatedAt: time.Now(), + } + s.handshakes[fullNonce] = state + + // Cleanup old handshakes + s.cleanupHandshakes() + + return &ServerFirst{ + FullNonce: fullNonce, + Salt: base64.StdEncoding.EncodeToString(cred.Salt), + ArgonTime: cred.ArgonTime, + ArgonMemory: cred.ArgonMemory, + ArgonThreads: cred.ArgonThreads, + }, nil +} + +// HandleClientFinal verifies client proof +func (s *Server) 
HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) { + s.mu.Lock() + defer s.mu.Unlock() + + state, exists := s.handshakes[msg.FullNonce] + if !exists { + return nil, fmt.Errorf("invalid nonce or expired handshake") + } + defer delete(s.handshakes, msg.FullNonce) + + // Check timeout + if time.Since(state.CreatedAt) > 60*time.Second { + return nil, fmt.Errorf("handshake timeout") + } + + // Decode client proof + clientProof, err := base64.StdEncoding.DecodeString(msg.ClientProof) + if err != nil { + return nil, fmt.Errorf("invalid proof encoding") + } + + // Build auth message + clientFirstBare := fmt.Sprintf("u=%s,n=%s", state.Username, state.ClientNonce) + serverFirst := &ServerFirst{ + FullNonce: state.FullNonce, + Salt: base64.StdEncoding.EncodeToString(state.Credential.Salt), + ArgonTime: state.Credential.ArgonTime, + ArgonMemory: state.Credential.ArgonMemory, + ArgonThreads: state.Credential.ArgonThreads, + } + clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce) + authMessage := clientFirstBare + "," + serverFirst.Marshal() + "," + clientFinalBare + + // Compute client signature + clientSignature := computeHMAC(state.Credential.StoredKey, []byte(authMessage)) + + // XOR to get ClientKey + clientKey := xorBytes(clientProof, clientSignature) + + // Verify by computing StoredKey + computedStoredKey := sha256.Sum256(clientKey) + if subtle.ConstantTimeCompare(computedStoredKey[:], state.Credential.StoredKey) != 1 { + return nil, fmt.Errorf("authentication failed") + } + + // Generate server signature for mutual auth + serverSignature := computeHMAC(state.Credential.ServerKey, []byte(authMessage)) + + return &ServerFinal{ + ServerSignature: base64.StdEncoding.EncodeToString(serverSignature), + SessionID: generateSessionID(), + }, nil +} + +func (s *Server) cleanupHandshakes() { + cutoff := time.Now().Add(-60 * time.Second) + for nonce, state := range s.handshakes { + if state.CreatedAt.Before(cutoff) { + delete(s.handshakes, nonce) + } + } +} + +func 
generateNonce() string { + b := make([]byte, 32) + rand.Read(b) + return base64.StdEncoding.EncodeToString(b) +} + +func generateSessionID() string { + b := make([]byte, 24) + rand.Read(b) + return base64.URLEncoding.EncodeToString(b) +} \ No newline at end of file diff --git a/src/internal/service/service.go b/src/internal/service/service.go index 34a8058..8fc5e78 100644 --- a/src/internal/service/service.go +++ b/src/internal/service/service.go @@ -281,10 +281,8 @@ func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter) return sink.NewTCPClientSink(cfg.Options, s.logger, formatter) case "file": return sink.NewFileSink(cfg.Options, s.logger, formatter) - case "stdout": - return sink.NewStdoutSink(cfg.Options, s.logger, formatter) - case "stderr": - return sink.NewStderrSink(cfg.Options, s.logger, formatter) + case "console": + return sink.NewConsoleSink(cfg.Options, s.logger, formatter) default: return nil, fmt.Errorf("unknown sink type: %s", cfg.Type) } diff --git a/src/internal/sink/console.go b/src/internal/sink/console.go index 3afd8b4..39c61a3 100644 --- a/src/internal/sink/console.go +++ b/src/internal/sink/console.go @@ -2,9 +2,9 @@ package sink import ( + "bytes" "context" - "io" - "os" + "fmt" "strings" "sync/atomic" "time" @@ -16,20 +16,13 @@ import ( "github.com/lixenwraith/log" ) -// Holds common configuration for console sinks -type ConsoleConfig struct { - Target string // "stdout", "stderr", or "split" - BufferSize int64 -} - -// Writes log entries to stdout -type StdoutSink struct { +// ConsoleSink writes log entries to the console (stdout/stderr) using an dedicated logger instance +type ConsoleSink struct { input chan core.LogEntry - config ConsoleConfig - output io.Writer + writer *log.Logger // Dedicated internal logger instance for console writing done chan struct{} startTime time.Time - logger *log.Logger + logger *log.Logger // Application logger for app logs formatter format.Formatter // Statistics @@ -37,29 +30,38 @@ 
type StdoutSink struct { lastProcessed atomic.Value // time.Time } -// Creates a new stdout sink -func NewStdoutSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*StdoutSink, error) { - config := ConsoleConfig{ - Target: "stdout", - BufferSize: 1000, +// Creates a new console sink +func NewConsoleSink(options map[string]any, appLogger *log.Logger, formatter format.Formatter) (*ConsoleSink, error) { + target := "stdout" + if t, ok := options["target"].(string); ok { + target = t } - // Check for split mode configuration - if target, ok := options["target"].(string); ok { - config.Target = target + bufferSize := int64(1000) + if buf, ok := options["buffer_size"].(int64); ok && buf > 0 { + bufferSize = buf } - if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { - config.BufferSize = bufSize + // Dedicated logger instance as console writer + writer, err := log.NewBuilder(). + EnableFile(false). + EnableConsole(true). + ConsoleTarget(target). + Format("raw"). // Passthrough pre-formatted messages + ShowTimestamp(false). // Disable writer's own timestamp + ShowLevel(false). // Disable writer's own level prefix + Build() + + if err != nil { + return nil, fmt.Errorf("failed to create console writer: %w", err) } - s := &StdoutSink{ - input: make(chan core.LogEntry, config.BufferSize), - config: config, - output: os.Stdout, + s := &ConsoleSink{ + input: make(chan core.LogEntry, bufferSize), + writer: writer, done: make(chan struct{}), startTime: time.Now(), - logger: logger, + logger: appLogger, formatter: formatter, } s.lastProcessed.Store(time.Time{}) @@ -67,39 +69,52 @@ func NewStdoutSink(options map[string]any, logger *log.Logger, formatter format. 
return s, nil } -func (s *StdoutSink) Input() chan<- core.LogEntry { +func (s *ConsoleSink) Input() chan<- core.LogEntry { return s.input } -func (s *StdoutSink) Start(ctx context.Context) error { +func (s *ConsoleSink) Start(ctx context.Context) error { + // Start the internal writer's processing goroutine. + if err := s.writer.Start(); err != nil { + return fmt.Errorf("failed to start console writer: %w", err) + } go s.processLoop(ctx) - s.logger.Info("msg", "Stdout sink started", - "component", "stdout_sink", - "target", s.config.Target) + s.logger.Info("msg", "Console sink started", + "component", "console_sink", + "target", s.writer.GetConfig().ConsoleTarget) return nil } -func (s *StdoutSink) Stop() { - s.logger.Info("msg", "Stopping stdout sink") +func (s *ConsoleSink) Stop() { + target := s.writer.GetConfig().ConsoleTarget + s.logger.Info("msg", "Stopping console sink", "target", target) close(s.done) - s.logger.Info("msg", "Stdout sink stopped") + + // Shutdown the internal writer with a timeout. + if err := s.writer.Shutdown(2 * time.Second); err != nil { + s.logger.Error("msg", "Error shutting down console writer", + "component", "console_sink", + "error", err) + } + s.logger.Info("msg", "Console sink stopped", "target", target) } -func (s *StdoutSink) GetStats() SinkStats { +func (s *ConsoleSink) GetStats() SinkStats { lastProc, _ := s.lastProcessed.Load().(time.Time) return SinkStats{ - Type: "stdout", + Type: "console", TotalProcessed: s.totalProcessed.Load(), StartTime: s.startTime, LastProcessed: lastProc, Details: map[string]any{ - "target": s.config.Target, + "target": s.writer.GetConfig().ConsoleTarget, }, } } -func (s *StdoutSink) processLoop(ctx context.Context) { +// processLoop reads entries, formats them, and passes them to the internal writer. 
+func (s *ConsoleSink) processLoop(ctx context.Context) { for { select { case entry, ok := <-s.input: @@ -110,24 +125,30 @@ func (s *StdoutSink) processLoop(ctx context.Context) { s.totalProcessed.Add(1) s.lastProcessed.Store(time.Now()) - // Handle split mode - only process INFO/DEBUG for stdout - if s.config.Target == "split" { - upperLevel := strings.ToUpper(entry.Level) - if upperLevel == "ERROR" || upperLevel == "WARN" || upperLevel == "WARNING" { - // Skip ERROR/WARN levels in stdout when in split mode - continue - } - } - - // Format and write + // Format the entry using the pipeline's configured formatter. formatted, err := s.formatter.Format(entry) if err != nil { - s.logger.Error("msg", "Failed to format log entry for stdout", - "component", "stdout_sink", + s.logger.Error("msg", "Failed to format log entry for console", + "component", "console_sink", "error", err) continue } - s.output.Write(formatted) + + // Convert to string to prevent hex encoding of []byte by log package + // Strip new line, writer adds it + message := string(bytes.TrimSuffix(formatted, []byte{'\n'})) + switch strings.ToUpper(entry.Level) { + case "DEBUG": + s.writer.Debug(message) + case "INFO": + s.writer.Info(message) + case "WARN", "WARNING": + s.writer.Warn(message) + case "ERROR", "FATAL": + s.writer.Error(message) + default: + s.writer.Message(message) + } case <-ctx.Done(): return @@ -137,125 +158,6 @@ func (s *StdoutSink) processLoop(ctx context.Context) { } } -// Writes log entries to stderr -type StderrSink struct { - input chan core.LogEntry - config ConsoleConfig - output io.Writer - done chan struct{} - startTime time.Time - logger *log.Logger - formatter format.Formatter - - // Statistics - totalProcessed atomic.Uint64 - lastProcessed atomic.Value // time.Time -} - -// Creates a new stderr sink -func NewStderrSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*StderrSink, error) { - config := ConsoleConfig{ - Target: "stderr", - BufferSize: 
1000, - } - - // Check for split mode configuration - if target, ok := options["target"].(string); ok { - config.Target = target - } - - if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { - config.BufferSize = bufSize - } - - s := &StderrSink{ - input: make(chan core.LogEntry, config.BufferSize), - config: config, - output: os.Stderr, - done: make(chan struct{}), - startTime: time.Now(), - logger: logger, - formatter: formatter, - } - s.lastProcessed.Store(time.Time{}) - - return s, nil -} - -func (s *StderrSink) Input() chan<- core.LogEntry { - return s.input -} - -func (s *StderrSink) Start(ctx context.Context) error { - go s.processLoop(ctx) - s.logger.Info("msg", "Stderr sink started", - "component", "stderr_sink", - "target", s.config.Target) - return nil -} - -func (s *StderrSink) Stop() { - s.logger.Info("msg", "Stopping stderr sink") - close(s.done) - s.logger.Info("msg", "Stderr sink stopped") -} - -func (s *StderrSink) GetStats() SinkStats { - lastProc, _ := s.lastProcessed.Load().(time.Time) - - return SinkStats{ - Type: "stderr", - TotalProcessed: s.totalProcessed.Load(), - StartTime: s.startTime, - LastProcessed: lastProc, - Details: map[string]any{ - "target": s.config.Target, - }, - } -} - -func (s *StderrSink) processLoop(ctx context.Context) { - for { - select { - case entry, ok := <-s.input: - if !ok { - return - } - - s.totalProcessed.Add(1) - s.lastProcessed.Store(time.Now()) - - // Handle split mode - only process ERROR/WARN for stderr - if s.config.Target == "split" { - upperLevel := strings.ToUpper(entry.Level) - if upperLevel != "ERROR" && upperLevel != "WARN" && upperLevel != "WARNING" { - // Skip non-ERROR/WARN levels in stderr when in split mode - continue - } - } - - // Format and write - formatted, err := s.formatter.Format(entry) - if err != nil { - s.logger.Error("msg", "Failed to format log entry for stderr", - "component", "stderr_sink", - "error", err) - continue - } - s.output.Write(formatted) - - case 
<-ctx.Done(): - return - case <-s.done: - return - } - } -} - -func (s *StdoutSink) SetAuth(auth *config.AuthConfig) { - // Authentication does not apply to stdout sink -} - -func (s *StderrSink) SetAuth(auth *config.AuthConfig) { - // Authentication does not apply to stderr sink +func (s *ConsoleSink) SetAuth(auth *config.AuthConfig) { + // Authentication does not apply to the console sink. } \ No newline at end of file diff --git a/src/internal/sink/file.go b/src/internal/sink/file.go index c85e6b1..28815f6 100644 --- a/src/internal/sink/file.go +++ b/src/internal/sink/file.go @@ -2,6 +2,7 @@ package sink import ( + "bytes" "context" "fmt" "logwisp/src/internal/config" @@ -32,12 +33,14 @@ type FileSink struct { func NewFileSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*FileSink, error) { directory, ok := options["directory"].(string) if !ok || directory == "" { - return nil, fmt.Errorf("file sink requires 'directory' option") + directory = "./" + logger.Warn("No directory or invalid directory provided, current directory will be used") } name, ok := options["name"].(string) if !ok || name == "" { - return nil, fmt.Errorf("file sink requires 'name' option") + name = "logwisp.output" + logger.Warn(fmt.Sprintf("No filename provided, %s will be used", name)) } // Create configuration for the internal log writer @@ -77,7 +80,7 @@ func NewFileSink(options map[string]any, logger *log.Logger, formatter format.Fo } // Buffer size for input channel - // TODO: Make this configurable + // TODO: Centralized constant file in core package bufferSize := int64(1000) if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { bufferSize = bufSize @@ -152,11 +155,9 @@ func (fs *FileSink) processLoop(ctx context.Context) { continue } - // Write formatted bytes (strip newline as writer adds it) - message := string(formatted) - if len(message) > 0 && message[len(message)-1] == '\n' { - message = message[:len(message)-1] - } + // Convert to 
string to prevent hex encoding of []byte by log package + // Strip new line, writer adds it + message := string(bytes.TrimSuffix(formatted, []byte{'\n'})) fs.writer.Message(message) case <-ctx.Done(): diff --git a/src/internal/sink/http.go b/src/internal/sink/http.go index 20bcf50..d9c3f25 100644 --- a/src/internal/sink/http.go +++ b/src/internal/sink/http.go @@ -471,6 +471,21 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { } } + // Enforce TLS for authentication + if h.authenticator != nil && h.authConfig.Type != "none" { + isTLS := ctx.IsTLS() || h.tlsManager != nil + + if !isTLS { + ctx.SetStatusCode(fasthttp.StatusForbidden) + ctx.SetContentType("application/json") + json.NewEncoder(ctx).Encode(map[string]string{ + "error": "TLS required for authentication", + "hint": "Use HTTPS for authenticated connections", + }) + return + } + } + path := string(ctx.Path()) // Status endpoint doesn't require auth @@ -811,7 +826,7 @@ func (h *HTTPSink) SetAuth(authCfg *config.AuthConfig) { } h.authConfig = authCfg - authenticator, err := auth.New(authCfg, h.logger) + authenticator, err := auth.NewAuthenticator(authCfg, h.logger) if err != nil { h.logger.Error("msg", "Failed to initialize authenticator for HTTP sink", "component", "http_sink", diff --git a/src/internal/sink/http_client.go b/src/internal/sink/http_client.go index fee738d..fbfae32 100644 --- a/src/internal/sink/http_client.go +++ b/src/internal/sink/http_client.go @@ -52,27 +52,29 @@ type HTTPClientSink struct { // TODO: missing toml tags type HTTPClientConfig struct { // Config - URL string - BufferSize int64 - BatchSize int64 - BatchDelay time.Duration - Timeout time.Duration - Headers map[string]string + URL string `toml:"url"` + BufferSize int64 `toml:"buffer_size"` + BatchSize int64 `toml:"batch_size"` + BatchDelay time.Duration `toml:"batch_delay_ms"` + Timeout time.Duration `toml:"timeout_seconds"` + Headers map[string]string `toml:"headers"` // Retry configuration - MaxRetries int64 - 
RetryDelay time.Duration - RetryBackoff float64 // Multiplier for exponential backoff + MaxRetries int64 `toml:"max_retries"` + RetryDelay time.Duration `toml:"retry_delay"` + RetryBackoff float64 `toml:"retry_backoff"` // Multiplier for exponential backoff // Security - Username string - Password string + AuthType string `toml:"auth_type"` // "none", "basic", "bearer", "mtls" + Username string `toml:"username"` // For basic auth + Password string `toml:"password"` // For basic auth + BearerToken string `toml:"bearer_token"` // For bearer auth // TLS configuration - InsecureSkipVerify bool - CAFile string - CertFile string - KeyFile string + InsecureSkipVerify bool `toml:"insecure_skip_verify"` + CAFile string `toml:"ca_file"` + CertFile string `toml:"cert_file"` + KeyFile string `toml:"key_file"` } // Creates a new HTTP client sink @@ -129,12 +131,64 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for if insecure, ok := options["insecure_skip_verify"].(bool); ok { cfg.InsecureSkipVerify = insecure } + if authType, ok := options["auth_type"].(string); ok { + switch authType { + case "none", "basic", "bearer", "mtls": + cfg.AuthType = authType + default: + return nil, fmt.Errorf("http_client sink: invalid auth_type '%s'", authType) + } + } else { + cfg.AuthType = "none" + } if username, ok := options["username"].(string); ok { cfg.Username = username } if password, ok := options["password"].(string); ok { cfg.Password = password // TODO: change to Argon2 hashed password } + if token, ok := options["bearer_token"].(string); ok { + cfg.BearerToken = token + } + + // Validate auth configuration and TLS enforcement + isHTTPS := strings.HasPrefix(cfg.URL, "https://") + + switch cfg.AuthType { + case "basic": + if cfg.Username == "" || cfg.Password == "" { + return nil, fmt.Errorf("http_client sink: username and password required for basic auth") + } + if !isHTTPS { + return nil, fmt.Errorf("http_client sink: basic auth requires HTTPS 
(security: credentials would be sent in plaintext)") + } + + case "bearer": + if cfg.BearerToken == "" { + return nil, fmt.Errorf("http_client sink: bearer_token required for bearer auth") + } + if !isHTTPS { + return nil, fmt.Errorf("http_client sink: bearer auth requires HTTPS (security: token would be sent in plaintext)") + } + + case "mtls": + if !isHTTPS { + return nil, fmt.Errorf("http_client sink: mTLS requires HTTPS") + } + if cfg.CertFile == "" || cfg.KeyFile == "" { + return nil, fmt.Errorf("http_client sink: cert_file and key_file required for mTLS") + } + + case "none": + // Clear any credentials if auth is "none" + if cfg.Username != "" || cfg.Password != "" || cfg.BearerToken != "" { + logger.Warn("msg", "Credentials provided but auth_type is 'none', ignoring", + "component", "http_client_sink") + cfg.Username = "" + cfg.Password = "" + cfg.BearerToken = "" + } + } // Extract headers if headers, ok := options["headers"].(map[string]any); ok { @@ -416,6 +470,7 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { var lastErr error retryDelay := h.config.RetryDelay + // TODO: verify retry loop placement is correct or should it be after acquiring resources (req :=....) 
for attempt := int64(0); attempt <= h.config.MaxRetries; attempt++ { if attempt > 0 { // Wait before retry @@ -444,11 +499,22 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { req.Header.Set("User-Agent", fmt.Sprintf("LogWisp/%s", version.Short())) - // Add Basic Auth header if credentials configured - if h.config.Username != "" && h.config.Password != "" { + // Add authentication based on auth type + switch h.config.AuthType { + case "basic": creds := h.config.Username + ":" + h.config.Password encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds)) req.Header.Set("Authorization", "Basic "+encodedCreds) + + case "bearer": + req.Header.Set("Authorization", "Bearer "+h.config.BearerToken) + + case "mtls": + // mTLS auth is handled at TLS layer via client certificates + // No Authorization header needed + + case "none": + // No authentication } // Set headers diff --git a/src/internal/sink/tcp.go b/src/internal/sink/tcp.go index 08398e9..5d8085c 100644 --- a/src/internal/sink/tcp.go +++ b/src/internal/sink/tcp.go @@ -602,7 +602,7 @@ func (t *TCPSink) SetAuth(authCfg *config.AuthConfig) { return } - authenticator, err := auth.New(authCfg, t.logger) + authenticator, err := auth.NewAuthenticator(authCfg, t.logger) if err != nil { t.logger.Error("msg", "Failed to initialize authenticator for TCP sink", "component", "tcp_sink", diff --git a/src/internal/sink/tcp_client.go b/src/internal/sink/tcp_client.go index 5a1c243..657e55c 100644 --- a/src/internal/sink/tcp_client.go +++ b/src/internal/sink/tcp_client.go @@ -4,7 +4,7 @@ package sink import ( "bufio" "context" - "encoding/base64" + "encoding/json" "errors" "fmt" "net" @@ -13,26 +13,25 @@ import ( "sync/atomic" "time" - "logwisp/src/internal/auth" "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/format" + "logwisp/src/internal/scram" "github.com/lixenwraith/log" ) // Forwards log entries to a remote TCP endpoint type TCPClientSink struct { - input chan core.LogEntry 
- config TCPClientConfig - conn net.Conn - connMu sync.RWMutex - done chan struct{} - wg sync.WaitGroup - startTime time.Time - logger *log.Logger - formatter format.Formatter - authenticator *auth.Authenticator + input chan core.LogEntry + config TCPClientConfig + conn net.Conn + connMu sync.RWMutex + done chan struct{} + wg sync.WaitGroup + startTime time.Time + logger *log.Logger + formatter format.Formatter // Reconnection state reconnecting atomic.Bool @@ -49,24 +48,22 @@ type TCPClientSink struct { // Holds TCP client sink configuration type TCPClientConfig struct { - Address string - BufferSize int64 - DialTimeout time.Duration - WriteTimeout time.Duration - ReadTimeout time.Duration - KeepAlive time.Duration + Address string `toml:"address"` + BufferSize int64 `toml:"buffer_size"` + DialTimeout time.Duration `toml:"dial_timeout_seconds"` + WriteTimeout time.Duration `toml:"write_timeout_seconds"` + ReadTimeout time.Duration `toml:"read_timeout_seconds"` + KeepAlive time.Duration `toml:"keep_alive_seconds"` // Security - Username string - Password string + AuthType string `toml:"auth_type"` + Username string `toml:"username"` + Password string `toml:"password"` // Reconnection settings - ReconnectDelay time.Duration - MaxReconnectDelay time.Duration - ReconnectBackoff float64 - - // TLS config - TLS *config.TLSConfig + ReconnectDelay time.Duration `toml:"reconnect_delay_ms"` + MaxReconnectDelay time.Duration `toml:"max_reconnect_delay_seconds"` + ReconnectBackoff float64 `toml:"reconnect_backoff"` } // Creates a new TCP client sink @@ -120,11 +117,25 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 { cfg.ReconnectBackoff = backoff } - if username, ok := options["username"].(string); ok { - cfg.Username = username - } - if password, ok := options["password"].(string); ok { - cfg.Password = password + if authType, ok := 
options["auth_type"].(string); ok { + switch authType { + case "none": + cfg.AuthType = authType + case "scram": + cfg.AuthType = authType + if username, ok := options["username"].(string); ok && username != "" { + cfg.Username = username + } else { + return nil, fmt.Errorf("invalid scram username") + } + if password, ok := options["password"].(string); ok && password != "" { + cfg.Password = password + } else { + return nil, fmt.Errorf("invalid scram password") + } + default: + return nil, fmt.Errorf("tcp_client sink: invalid auth_type '%s' (must be 'none' or 'scram')", authType) + } } t := &TCPClientSink{ @@ -304,49 +315,115 @@ func (t *TCPClientSink) connect() (net.Conn, error) { tcpConn.SetKeepAlivePeriod(t.config.KeepAlive) } - // Handle authentication if credentials configured - if t.config.Username != "" && t.config.Password != "" { - // Read auth challenge - reader := bufio.NewReader(conn) - challenge, err := reader.ReadString('\n') - if err != nil { + // SCRAM authentication if credentials configured + if t.config.AuthType == "scram" { + if err := t.performSCRAMAuth(conn); err != nil { conn.Close() - return nil, fmt.Errorf("failed to read auth challenge: %w", err) - } - - if strings.TrimSpace(challenge) == "AUTH_REQUIRED" { - // Send credentials - creds := t.config.Username + ":" + t.config.Password - encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds)) - authCmd := fmt.Sprintf("AUTH basic %s\n", encodedCreds) - - if _, err := conn.Write([]byte(authCmd)); err != nil { - conn.Close() - return nil, fmt.Errorf("failed to send auth: %w", err) - } - - // Read response - response, err := reader.ReadString('\n') - if err != nil { - conn.Close() - return nil, fmt.Errorf("failed to read auth response: %w", err) - } - - if strings.TrimSpace(response) != "AUTH_OK" { - conn.Close() - return nil, fmt.Errorf("authentication failed: %s", response) - } - - t.logger.Debug("msg", "TCP authentication successful", - "component", "tcp_client_sink", - "address", 
t.config.Address, - "username", t.config.Username) + return nil, fmt.Errorf("SCRAM authentication failed: %w", err) } + t.logger.Debug("msg", "SCRAM authentication completed", + "component", "tcp_client_sink", + "address", t.config.Address) } return conn, nil } +func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error { + reader := bufio.NewReader(conn) + + // Create SCRAM client + scramClient := scram.NewClient(t.config.Username, t.config.Password) + + // Step 1: Send ClientFirst + clientFirst, err := scramClient.StartAuthentication() + if err != nil { + return fmt.Errorf("failed to start SCRAM: %w", err) + } + + clientFirstJSON, _ := json.Marshal(clientFirst) + msg := fmt.Sprintf("SCRAM-FIRST %s\n", clientFirstJSON) + + if _, err := conn.Write([]byte(msg)); err != nil { + return fmt.Errorf("failed to send SCRAM-FIRST: %w", err) + } + + // Step 2: Receive ServerFirst challenge + response, err := reader.ReadString('\n') + if err != nil { + return fmt.Errorf("failed to read SCRAM challenge: %w", err) + } + + parts := strings.Fields(strings.TrimSpace(response)) + if len(parts) != 2 || parts[0] != "SCRAM-CHALLENGE" { + return fmt.Errorf("unexpected server response: %s", response) + } + + var serverFirst scram.ServerFirst + if err := json.Unmarshal([]byte(parts[1]), &serverFirst); err != nil { + return fmt.Errorf("failed to parse server challenge: %w", err) + } + + // Step 3: Process challenge and send proof + clientFinal, err := scramClient.ProcessServerFirst(&serverFirst) + if err != nil { + return fmt.Errorf("failed to process challenge: %w", err) + } + + clientFinalJSON, _ := json.Marshal(clientFinal) + msg = fmt.Sprintf("SCRAM-PROOF %s\n", clientFinalJSON) + + if _, err := conn.Write([]byte(msg)); err != nil { + return fmt.Errorf("failed to send SCRAM-PROOF: %w", err) + } + + // Step 4: Receive ServerFinal + response, err = reader.ReadString('\n') + if err != nil { + return fmt.Errorf("failed to read SCRAM result: %w", err) + } + + parts = 
strings.Fields(strings.TrimSpace(response)) + if len(parts) < 1 { + return fmt.Errorf("empty server response") + } + + switch parts[0] { + case "SCRAM-OK": + if len(parts) != 2 { + return fmt.Errorf("invalid SCRAM-OK response") + } + + var serverFinal scram.ServerFinal + if err := json.Unmarshal([]byte(parts[1]), &serverFinal); err != nil { + return fmt.Errorf("failed to parse server signature: %w", err) + } + + // Verify server signature + if err := scramClient.VerifyServerFinal(&serverFinal); err != nil { + return fmt.Errorf("server signature verification failed: %w", err) + } + + t.logger.Info("msg", "SCRAM authentication successful", + "component", "tcp_client_sink", + "address", t.config.Address, + "username", t.config.Username, + "session_id", serverFinal.SessionID) + + return nil + + case "SCRAM-FAIL": + reason := "unknown" + if len(parts) > 1 { + reason = strings.Join(parts[1:], " ") + } + return fmt.Errorf("authentication failed: %s", reason) + + default: + return fmt.Errorf("unexpected response: %s", response) + } +} + func (t *TCPClientSink) monitorConnection(conn net.Conn) { // Simple connection monitoring by periodic zero-byte reads ticker := time.NewTicker(5 * time.Second) diff --git a/src/internal/source/http.go b/src/internal/source/http.go index 95fcd78..b1275af 100644 --- a/src/internal/source/http.go +++ b/src/internal/source/http.go @@ -117,12 +117,6 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok { cfg.MaxConnectionsPerIP = maxPerIP } - if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok { - cfg.MaxConnectionsPerUser = maxPerUser - } - if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok { - cfg.MaxConnectionsPerToken = maxPerToken - } if maxTotal, ok := nl["max_connections_total"].(int64); ok { cfg.MaxConnectionsTotal = maxTotal } @@ -313,6 +307,27 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) { } } + // 
2.5. Check TLS requirement for auth (early reject) + if h.authenticator != nil && h.authConfig.Type != "none" { + // Check if connection is TLS + isTLS := ctx.IsTLS() || h.tlsManager != nil + + if !isTLS { + h.logger.Error("msg", "Authentication configured but connection is not TLS", + "component", "http_source", + "remote_addr", remoteAddr, + "auth_type", h.authConfig.Type) + + ctx.SetStatusCode(fasthttp.StatusForbidden) + ctx.SetContentType("application/json") + json.NewEncoder(ctx).Encode(map[string]string{ + "error": "TLS required for authentication", + "hint": "Use HTTPS to submit authenticated requests", + }) + return + } + } + // 3. Path check (only process ingest path) path := string(ctx.Path()) if path != h.path { @@ -543,7 +558,7 @@ func (h *HTTPSource) SetAuth(authCfg *config.AuthConfig) { } h.authConfig = authCfg - authenticator, err := auth.New(authCfg, h.logger) + authenticator, err := auth.NewAuthenticator(authCfg, h.logger) if err != nil { h.logger.Error("msg", "Failed to initialize authenticator for HTTP source", "component", "http_source", diff --git a/src/internal/source/tcp.go b/src/internal/source/tcp.go index f000970..31f4fa4 100644 --- a/src/internal/source/tcp.go +++ b/src/internal/source/tcp.go @@ -4,6 +4,7 @@ package source import ( "bytes" "context" + "encoding/base64" "encoding/json" "fmt" "net" @@ -16,6 +17,7 @@ import ( "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/limit" + "logwisp/src/internal/scram" "github.com/lixenwraith/log" "github.com/lixenwraith/log/compat" @@ -29,19 +31,19 @@ const ( // Receives log entries via TCP connections type TCPSource struct { - host string - port int64 - bufferSize int64 - server *tcpSourceServer - subscribers []chan core.LogEntry - mu sync.RWMutex - done chan struct{} - engine *gnet.Engine - engineMu sync.Mutex - wg sync.WaitGroup - netLimiter *limit.NetLimiter - logger *log.Logger - authenticator *auth.Authenticator + host string + port int64 + bufferSize int64 + 
server *tcpSourceServer + subscribers []chan core.LogEntry + mu sync.RWMutex + done chan struct{} + engine *gnet.Engine + engineMu sync.Mutex + wg sync.WaitGroup + netLimiter *limit.NetLimiter + logger *log.Logger + scramManager *scram.ScramManager // Statistics totalEntries atomic.Uint64 @@ -255,12 +257,13 @@ func (t *TCPSource) publish(entry core.LogEntry) bool { // Represents a connected TCP client type tcpClient struct { conn gnet.Conn - buffer bytes.Buffer + buffer *bytes.Buffer authenticated bool authTimeout time.Time session *auth.Session maxBufferSeen int cumulativeEncrypted int64 + scramState *scram.HandshakeState } // Handles gnet events @@ -314,11 +317,9 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { // Create client state client := &tcpClient{ conn: c, - authenticated: s.source.authenticator == nil, - } - - if s.source.authenticator != nil { - client.authTimeout = time.Now().Add(30 * time.Second) + buffer: bytes.NewBuffer(nil), + authTimeout: time.Now().Add(30 * time.Second), + authenticated: s.source.scramManager == nil, } s.mu.Lock() @@ -330,12 +331,7 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { "component", "tcp_source", "remote_addr", remoteAddr, "active_connections", newCount, - "requires_auth", s.source.authenticator != nil) - - // Send auth challenge if required - if s.source.authenticator != nil { - return []byte("AUTH_REQUIRED\n"), gnet.None - } + "requires_auth", s.source.scramManager != nil) return nil, gnet.None } @@ -380,52 +376,107 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action { return gnet.Close } - // Authentication phase - if !client.authenticated { - if time.Now().After(client.authTimeout) { + // SCRAM Authentication phase + if !client.authenticated && s.source.scramManager != nil { + // Check auth timeout + if !client.authTimeout.IsZero() && time.Now().After(client.authTimeout) { s.source.logger.Warn("msg", "Authentication timeout", 
"component", "tcp_source", "remote_addr", c.RemoteAddr().String()) return gnet.Close } + if len(data) == 0 { + return gnet.None + } + client.buffer.Write(data) - // Look for auth line - if idx := bytes.IndexByte(client.buffer.Bytes(), '\n'); idx >= 0 { + // Look for complete line + for { + idx := bytes.IndexByte(client.buffer.Bytes(), '\n') + if idx < 0 { + break + } + line := client.buffer.Bytes()[:idx] client.buffer.Next(idx + 1) - parts := strings.SplitN(string(line), " ", 3) - if len(parts) != 3 || parts[0] != "AUTH" { - c.AsyncWrite([]byte("AUTH_FAIL\n"), nil) + // Parse SCRAM messages + parts := strings.Fields(string(line)) + if len(parts) < 2 { + c.AsyncWrite([]byte("SCRAM-FAIL Invalid message format\n"), nil) return gnet.Close } - session, err := s.source.authenticator.AuthenticateTCP(parts[1], parts[2], c.RemoteAddr().String()) - if err != nil { - s.source.authFailures.Add(1) - s.source.logger.Warn("msg", "Authentication failed", + switch parts[0] { + case "SCRAM-FIRST": + // Parse ClientFirst JSON + var clientFirst scram.ClientFirst + if err := json.Unmarshal([]byte(parts[1]), &clientFirst); err != nil { + c.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil) + return gnet.Close + } + + // Process with SCRAM server + serverFirst, err := s.source.scramManager.HandleClientFirst(&clientFirst) + if err != nil { + // Still send challenge to prevent user enumeration + response, _ := json.Marshal(serverFirst) + c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil) + return gnet.Close + } + + // Send ServerFirst challenge + response, _ := json.Marshal(serverFirst) + c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil) + + case "SCRAM-PROOF": + // Parse ClientFinal JSON + var clientFinal scram.ClientFinal + if err := json.Unmarshal([]byte(parts[1]), &clientFinal); err != nil { + c.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil) + return gnet.Close + } + + // Verify proof + serverFinal, err := 
s.source.scramManager.HandleClientFinal(&clientFinal) + if err != nil { + s.source.logger.Warn("msg", "SCRAM authentication failed", + "component", "tcp_source", + "remote_addr", c.RemoteAddr().String(), + "error", err) + c.AsyncWrite([]byte("SCRAM-FAIL Authentication failed\n"), nil) + return gnet.Close + } + + // Authentication successful + s.mu.Lock() + client.authenticated = true + client.session = &auth.Session{ + ID: serverFinal.SessionID, + Method: "scram-sha-256", + RemoteAddr: c.RemoteAddr().String(), + CreatedAt: time.Now(), + } + s.mu.Unlock() + + // Send ServerFinal with signature + response, _ := json.Marshal(serverFinal) + c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-OK %s\n", response)), nil) + + s.source.logger.Info("msg", "Client authenticated via SCRAM", "component", "tcp_source", "remote_addr", c.RemoteAddr().String(), - "error", err) - c.AsyncWrite([]byte("AUTH_FAIL\n"), nil) + "session_id", serverFinal.SessionID) + + // Clear auth buffer + client.buffer.Reset() + + default: + c.AsyncWrite([]byte("SCRAM-FAIL Unknown command\n"), nil) return gnet.Close } - - s.source.authSuccesses.Add(1) - s.mu.Lock() - client.authenticated = true - client.session = session - s.mu.Unlock() - - s.source.logger.Info("msg", "TCP client authenticated", - "component", "tcp_source", - "remote_addr", c.RemoteAddr().String(), - "username", session.Username) - - c.AsyncWrite([]byte("AUTH_OK\n"), nil) - client.buffer.Reset() } return gnet.None } @@ -522,22 +573,46 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action { return gnet.None } +func (t *TCPSource) InitSCRAMManager(authCfg *config.AuthConfig) { + if authCfg == nil || authCfg.Type != "scram" || authCfg.ScramAuth == nil { + return + } + + t.scramManager = scram.NewScramManager() + + // Load users from SCRAM config + for _, user := range authCfg.ScramAuth.Users { + storedKey, _ := base64.StdEncoding.DecodeString(user.StoredKey) + serverKey, _ := base64.StdEncoding.DecodeString(user.ServerKey) + salt, _ := 
base64.StdEncoding.DecodeString(user.Salt) + + cred := &scram.Credential{ + Username: user.Username, + StoredKey: storedKey, + ServerKey: serverKey, + Salt: salt, + ArgonTime: user.ArgonTime, + ArgonMemory: user.ArgonMemory, + ArgonThreads: user.ArgonThreads, + } + t.scramManager.AddCredential(cred) + } + + t.logger.Info("msg", "SCRAM authentication configured", + "component", "tcp_source", + "users", len(authCfg.ScramAuth.Users)) +} + // Configure TCP source auth func (t *TCPSource) SetAuth(authCfg *config.AuthConfig) { if authCfg == nil || authCfg.Type == "none" { return } - authenticator, err := auth.New(authCfg, t.logger) - if err != nil { - t.logger.Error("msg", "Failed to initialize authenticator for TCP source", - "component", "tcp_source", - "error", err) - return + // Initialize SCRAM manager + if authCfg.Type == "scram" { + t.InitSCRAMManager(authCfg) + t.logger.Info("msg", "SCRAM authentication configured for TCP source", + "component", "tcp_source") } - t.authenticator = authenticator - - t.logger.Info("msg", "Authentication configured for TCP source", - "component", "tcp_source", - "auth_type", authCfg.Type) } \ No newline at end of file diff --git a/src/internal/tls/generator.go b/src/internal/tls/generator.go index e6e0581..63d91c3 100644 --- a/src/internal/tls/generator.go +++ b/src/internal/tls/generator.go @@ -9,6 +9,7 @@ import ( "encoding/pem" "flag" "fmt" + "io" "math/big" "net" "os" @@ -16,14 +17,21 @@ import ( "time" ) -type CertGeneratorCommand struct{} - -func NewCertGeneratorCommand() *CertGeneratorCommand { - return &CertGeneratorCommand{} +type CertGeneratorCommand struct { + output io.Writer + errOut io.Writer } -func (c *CertGeneratorCommand) Execute(args []string) error { +func NewCertGeneratorCommand() *CertGeneratorCommand { + return &CertGeneratorCommand{ + output: os.Stdout, + errOut: os.Stderr, + } +} + +func (cg *CertGeneratorCommand) Execute(args []string) error { cmd := flag.NewFlagSet("tls", flag.ContinueOnError) + 
cmd.SetOutput(cg.errOut) // Subcommands var ( @@ -50,20 +58,21 @@ func (c *CertGeneratorCommand) Execute(args []string) error { ) cmd.Usage = func() { - fmt.Fprintln(os.Stderr, "Generate TLS certificates for LogWisp") - fmt.Fprintln(os.Stderr, "\nUsage: logwisp tls [options]") - fmt.Fprintln(os.Stderr, "\nExamples:") - fmt.Fprintln(os.Stderr, " # Generate self-signed certificate") - fmt.Fprintln(os.Stderr, " logwisp tls --self-signed --cn localhost --hosts localhost,127.0.0.1") - fmt.Fprintln(os.Stderr, " ") - fmt.Fprintln(os.Stderr, " # Generate CA certificate") - fmt.Fprintln(os.Stderr, " logwisp tls --ca --cn \"LogWisp CA\" --cert-out ca.crt --key-out ca.key") - fmt.Fprintln(os.Stderr, " ") - fmt.Fprintln(os.Stderr, " # Generate server certificate signed by CA") - fmt.Fprintln(os.Stderr, " logwisp tls --server --cn server.example.com --hosts server.example.com \\") - fmt.Fprintln(os.Stderr, " --ca-cert ca.crt --ca-key ca.key") - fmt.Fprintln(os.Stderr, "\nOptions:") + fmt.Fprintln(cg.errOut, "Generate TLS certificates for LogWisp") + fmt.Fprintln(cg.errOut, "\nUsage: logwisp tls [options]") + fmt.Fprintln(cg.errOut, "\nExamples:") + fmt.Fprintln(cg.errOut, " # Generate self-signed certificate") + fmt.Fprintln(cg.errOut, " logwisp tls --self-signed --cn localhost --hosts localhost,127.0.0.1") + fmt.Fprintln(cg.errOut, " ") + fmt.Fprintln(cg.errOut, " # Generate CA certificate") + fmt.Fprintln(cg.errOut, " logwisp tls --ca --cn \"LogWisp CA\" --cert-out ca.crt --key-out ca.key") + fmt.Fprintln(cg.errOut, " ") + fmt.Fprintln(cg.errOut, " # Generate server certificate signed by CA") + fmt.Fprintln(cg.errOut, " logwisp tls --server --cn server.example.com --hosts server.example.com \\") + fmt.Fprintln(cg.errOut, " --ca-cert ca.crt --ca-key ca.key") + fmt.Fprintln(cg.errOut, "\nOptions:") cmd.PrintDefaults() + fmt.Fprintln(cg.errOut) } if err := cmd.Parse(args); err != nil { @@ -79,13 +88,13 @@ func (c *CertGeneratorCommand) Execute(args []string) error { // Route to 
appropriate generator switch { case *genCA: - return c.generateCA(*commonName, *org, *country, *validDays, *keySize, *certOut, *keyOut) + return cg.generateCA(*commonName, *org, *country, *validDays, *keySize, *certOut, *keyOut) case *selfSign: - return c.generateSelfSigned(*commonName, *org, *country, *hosts, *validDays, *keySize, *certOut, *keyOut) + return cg.generateSelfSigned(*commonName, *org, *country, *hosts, *validDays, *keySize, *certOut, *keyOut) case *genServer: - return c.generateServerCert(*commonName, *org, *country, *hosts, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut) + return cg.generateServerCert(*commonName, *org, *country, *hosts, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut) case *genClient: - return c.generateClientCert(*commonName, *org, *country, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut) + return cg.generateClientCert(*commonName, *org, *country, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut) default: cmd.Usage() return fmt.Errorf("specify certificate type: --ca, --self-signed, --server, or --client") @@ -93,7 +102,7 @@ func (c *CertGeneratorCommand) Execute(args []string) error { } // Create and manage private CA -func (c *CertGeneratorCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error { +func (cg *CertGeneratorCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error { // Generate RSA key priv, err := rsa.GenerateKey(rand.Reader, bits) if err != nil { @@ -169,7 +178,7 @@ func parseHosts(hostList string) ([]string, []net.IP) { } // Generate self-signed certificate -func (c *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error { +func (cg *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error { // 1. 
Generate an RSA private key with the specified bit size priv, err := rsa.GenerateKey(rand.Reader, bits) if err != nil { @@ -236,7 +245,7 @@ func (c *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string } // Generate server cert with CA -func (c *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { +func (cg *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { caCert, caKey, err := loadCA(caFile, caKeyFile) if err != nil { return err @@ -299,7 +308,7 @@ func (c *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFil } // Generate client cert with CA -func (c *CertGeneratorCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { +func (cg *CertGeneratorCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { caCert, caKey, err := loadCA(caFile, caKeyFile) if err != nil { return err @@ -356,66 +365,105 @@ func (c *CertGeneratorCommand) generateClientCert(cn, org, country, caFile, caKe } // Load cert with CA -func loadCA(caFile, caKeyFile string) (*x509.Certificate, *rsa.PrivateKey, error) { - if caFile == "" || caKeyFile == "" { - return nil, nil, fmt.Errorf("--ca-cert and --ca-key are required for signing") - } - - caCertPEM, err := os.ReadFile(caFile) +func loadCA(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) { + // Load CA certificate + certPEM, err := os.ReadFile(certFile) if err != nil { return nil, nil, fmt.Errorf("failed to read CA certificate: %w", err) } - caCertBlock, _ := pem.Decode(caCertPEM) - if caCertBlock == nil { - return nil, nil, fmt.Errorf("failed to decode CA certificate PEM") + + certBlock, _ := pem.Decode(certPEM) + if certBlock == nil || certBlock.Type != "CERTIFICATE" { + return 
nil, nil, fmt.Errorf("invalid CA certificate format") } - caCert, err := x509.ParseCertificate(caCertBlock.Bytes) + + caCert, err := x509.ParseCertificate(certBlock.Bytes) if err != nil { return nil, nil, fmt.Errorf("failed to parse CA certificate: %w", err) } - if !caCert.IsCA { - return nil, nil, fmt.Errorf("provided certificate is not a valid CA") - } - - caKeyPEM, err := os.ReadFile(caKeyFile) + // Load CA private key + keyPEM, err := os.ReadFile(keyFile) if err != nil { return nil, nil, fmt.Errorf("failed to read CA key: %w", err) } - caKeyBlock, _ := pem.Decode(caKeyPEM) - if caKeyBlock == nil { - return nil, nil, fmt.Errorf("failed to decode CA key PEM") + + keyBlock, _ := pem.Decode(keyPEM) + if keyBlock == nil { + return nil, nil, fmt.Errorf("invalid CA key format") } - caKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes) + + var caKey *rsa.PrivateKey + switch keyBlock.Type { + case "RSA PRIVATE KEY": + caKey, err = x509.ParsePKCS1PrivateKey(keyBlock.Bytes) + case "PRIVATE KEY": + parsedKey, err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse CA key: %w", err) + } + var ok bool + caKey, ok = parsedKey.(*rsa.PrivateKey) + if !ok { + return nil, nil, fmt.Errorf("CA key is not RSA") + } + default: + return nil, nil, fmt.Errorf("unsupported CA key type: %s", keyBlock.Type) + } + if err != nil { return nil, nil, fmt.Errorf("failed to parse CA private key: %w", err) } - // Verify key matches certificate - if caCert.PublicKey.(*rsa.PublicKey).N.Cmp(caKey.N) != 0 { - return nil, nil, fmt.Errorf("CA private key does not match CA certificate") + // Verify CA certificate is actually a CA + if !caCert.IsCA { + return nil, nil, fmt.Errorf("certificate is not a CA certificate") } return caCert, caKey, nil } -func saveCert(filename string, derBytes []byte) error { - certOut, err := os.Create(filename) +func saveCert(filename string, certDER []byte) error { + certFile, err := os.Create(filename) if err != 
nil { - return fmt.Errorf("failed to create cert file %s: %w", filename, err) + return fmt.Errorf("failed to create certificate file: %w", err) } - defer certOut.Close() - return pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + defer certFile.Close() + + if err := pem.Encode(certFile, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }); err != nil { + return fmt.Errorf("failed to write certificate: %w", err) + } + + // Set readable permissions + if err := os.Chmod(filename, 0644); err != nil { + return fmt.Errorf("failed to set certificate permissions: %w", err) + } + + return nil } func saveKey(filename string, key *rsa.PrivateKey) error { - keyOut, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + keyFile, err := os.Create(filename) if err != nil { - return fmt.Errorf("failed to create key file %s: %w", filename, err) + return fmt.Errorf("failed to create key file: %w", err) } - defer keyOut.Close() - return pem.Encode(keyOut, &pem.Block{ + defer keyFile.Close() + + privKeyDER := x509.MarshalPKCS1PrivateKey(key) + if err := pem.Encode(keyFile, &pem.Block{ Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - }) + Bytes: privKeyDER, + }); err != nil { + return fmt.Errorf("failed to write private key: %w", err) + } + + // Set restricted permissions for private key + if err := os.Chmod(filename, 0600); err != nil { + return fmt.Errorf("failed to set key permissions: %w", err) + } + + return nil } \ No newline at end of file diff --git a/test/README.md b/test/README.md index de0fb6a..5bf2a65 100644 --- a/test/README.md +++ b/test/README.md @@ -6,7 +6,7 @@ ### Notes: -- The tests create configuration files and log files. Debug tests do not clean up these files. +- The tests create configuration files and log files. Most tests set logging at debug level and don't clean up their temp files that are created in the current execution directory. 
-- Some tests may require to be run on different hosts (containers can be used). +- Some tests may need to be run on different hosts (containers can be used). diff --git a/test/test-logwisp-auth-debug.sh b/test/test-basic-auth.sh similarity index 94% rename from test/test-logwisp-auth-debug.sh rename to test/test-basic-auth.sh index 8bddd43..0bc839a 100755 --- a/test/test-logwisp-auth-debug.sh +++ b/test/test-basic-auth.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -# FILE: test-logwisp-auth-debug.sh +# FILE: test-basic-auth.sh # Creates test directories and starts network services set -e @@ -32,7 +32,7 @@ fi cat > test-auth.toml << EOF # General LogWisp settings log_dir = "test-logs" -log_level = "debug" # CHANGED: Set to debug +log_level = "debug" data_dir = "test-data" # Logging configuration for troubleshooting @@ -41,7 +41,7 @@ target = "all" level = "debug" [logging.console] enabled = true -target = "stdout" # CHANGED: Log to stdout for visibility +target = "stdout" format = "txt" [[pipelines]] @@ -59,7 +59,9 @@ port = 5514 host = "127.0.0.1" [[pipelines.sinks]] -type = "stdout" +type = "console" +[pipelines.sinks.options] +target = "stdout" # Second pipeline for HTTP [[pipelines]] @@ -78,7 +80,10 @@ host = "127.0.0.1" path = "/ingest" [[pipelines.sinks]] -type = "stdout" # CHANGED: Simplify to stdout for debugging +type = "console" +[pipelines.sinks.options] +target = "stdout" + EOF # Start LogWisp with visible debug output