diff --git a/.gitignore b/.gitignore index 64d71d5..e2a31d8 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,6 @@ cert bin script build +test *.log -*.toml \ No newline at end of file +*.toml diff --git a/doc/environment.md b/doc/environment.md index c7dfcf6..7410e4a 100644 --- a/doc/environment.md +++ b/doc/environment.md @@ -18,7 +18,6 @@ Examples: ```bash LOGWISP_CONFIG_FILE=/etc/logwisp/config.toml LOGWISP_CONFIG_DIR=/etc/logwisp -LOGWISP_ROUTER=true LOGWISP_BACKGROUND=true LOGWISP_QUIET=true LOGWISP_DISABLE_STATUS_REPORTER=true @@ -221,7 +220,6 @@ LOGWISP_PIPELINES_0_SINKS_0_OPTIONS_TARGET=stdout #!/usr/bin/env bash # General settings -export LOGWISP_ROUTER=true export LOGWISP_DISABLE_STATUS_REPORTER=false # Logging diff --git a/go.mod b/go.mod index 341a6cb..bc94d64 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,15 @@ module logwisp -go 1.24.5 +go 1.25.1 require ( - github.com/lixenwraith/config v0.0.0-20250901201021-59a461e31cd4 - github.com/lixenwraith/log v0.0.0-20250722012845-16a3079e46e2 + github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 + github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208 github.com/panjf2000/gnet/v2 v2.9.3 github.com/valyala/fasthttp v1.65.0 + golang.org/x/crypto v0.42.0 + golang.org/x/term v0.35.0 ) require ( @@ -19,8 +22,8 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect - golang.org/x/sync v0.16.0 // indirect - golang.org/x/sys v0.35.0 // indirect + golang.org/x/sync v0.17.0 // indirect + golang.org/x/sys v0.36.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c1e891b..9de57e3 100644 --- a/go.sum +++ b/go.sum @@ -6,12 +6,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGLxofobw= github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/lixenwraith/config v0.0.0-20250901201021-59a461e31cd4 h1:SxqXt6J7ZLA39SP4zvJU0Jv3GbXLzM5iB7cgk5d7Pe4= -github.com/lixenwraith/config v0.0.0-20250901201021-59a461e31cd4/go.mod h1:l+1PZ8JsohLAXOJKu5loFa+zCdOSb/lXf3JUwa5ST/4= -github.com/lixenwraith/log v0.0.0-20250722012845-16a3079e46e2 h1:nP/12l+gKkZnZRoM3Vy4iT2anBQm1jCtrppyZq9pcq4= -github.com/lixenwraith/log v0.0.0-20250722012845-16a3079e46e2/go.mod h1:sLCRfKeLInCj2LcMnAo2knULwfszU8QPuIFOQ8crcFo= +github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0= +github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0= +github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208 h1:IB1O/HLv9VR/4mL1Tkjlr91lk+r8anP6bab7rYdS/oE= +github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0= github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= github.com/panjf2000/ants/v2 v2.11.3/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek= github.com/panjf2000/gnet/v2 v2.9.3 h1:auV3/A9Na3jiBDmYAAU00rPhFKnsAI+TnI1F7YUJMHQ= @@ -32,10 +34,14 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= +golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= +golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= +golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/src/cmd/auth-gen/main.go b/src/cmd/auth-gen/main.go new file mode 100644 index 0000000..38564e2 --- /dev/null +++ b/src/cmd/auth-gen/main.go @@ -0,0 +1,110 @@ +// FILE: logwisp/src/cmd/auth-gen/main.go +package main + +import ( + "crypto/rand" + "encoding/base64" + "flag" + "fmt" + "os" + "syscall" + + "golang.org/x/crypto/bcrypt" + "golang.org/x/term" +) + +func main() { + var ( + username = flag.String("u", "", "Username for basic auth") + password = flag.String("p", "", "Password to hash (will prompt if not provided)") + cost = flag.Int("c", 10, "Bcrypt cost (10-31)") + genToken = flag.Bool("t", false, "Generate random bearer token") + tokenLen = flag.Int("l", 32, "Token length in bytes") + ) + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "LogWisp Authentication Utility\n\n") + fmt.Fprintf(os.Stderr, "Usage:\n") + fmt.Fprintf(os.Stderr, " Generate bcrypt hash: %s -u [-p ]\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " Generate bearer token: %s -t [-l ]\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "\nOptions:\n") + flag.PrintDefaults() + } + + flag.Parse() + + if *genToken { + generateToken(*tokenLen) + return + } + + if *username == "" { + fmt.Fprintf(os.Stderr, "Error: Username required for basic auth\n") + flag.Usage() + os.Exit(1) + } + + // Get password + pass := *password + if pass == "" { + pass = promptPassword("Enter password: ") + confirm := promptPassword("Confirm password: ") + if pass != confirm { + fmt.Fprintf(os.Stderr, "Error: Passwords don't match\n") + os.Exit(1) + } + } + + // Generate bcrypt hash + hash, err := bcrypt.GenerateFromPassword([]byte(pass), *cost) + if err != nil { + fmt.Fprintf(os.Stderr, "Error generating hash: %v\n", err) + os.Exit(1) + } + + // Output TOML config format + fmt.Println("\n# Add to logwisp.toml under [[pipelines.auth.basic_auth.users]]:") + fmt.Printf("[[pipelines.auth.basic_auth.users]]\n") + fmt.Printf("username = \"%s\"\n", *username) + fmt.Printf("password_hash = \"%s\"\n", string(hash)) + + // Also output for users file format + fmt.Println("\n# Or add to users file:") + fmt.Printf("%s:%s\n", *username, string(hash)) +} + +func promptPassword(prompt string) string { + fmt.Fprint(os.Stderr, prompt) + password, err := term.ReadPassword(int(syscall.Stdin)) + fmt.Fprintln(os.Stderr) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading password: %v\n", err) + os.Exit(1) + } + return string(password) +} + +func generateToken(length int) { + if length < 16 { + fmt.Fprintf(os.Stderr, "Warning: Token length < 16 bytes is insecure\n") + } + + token := make([]byte, length) + if _, err := rand.Read(token); err != nil { + fmt.Fprintf(os.Stderr, "Error generating token: %v\n", err) + os.Exit(1) + } + + // Output in various formats + b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token) + hex := fmt.Sprintf("%x", token) + + fmt.Println("\n# Add to logwisp.toml under [pipelines.auth.bearer_auth]:") + fmt.Printf("tokens = [\"%s\"]\n", b64) + + fmt.Println("\n# Alternative hex encoding:") + fmt.Printf("# tokens = [\"%s\"]\n", hex) + + fmt.Printf("\n# Token (base64): %s\n", b64) + fmt.Printf("# Token (hex): %s\n", hex) +} \ No newline at end of file diff --git a/src/cmd/logwisp/bootstrap.go b/src/cmd/logwisp/bootstrap.go index 59caee0..501be26 100644 --- a/src/cmd/logwisp/bootstrap.go +++ b/src/cmd/logwisp/bootstrap.go @@ -14,17 +14,10 @@ import ( ) // bootstrapService creates and initializes the log transport service -func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service, *service.HTTPRouter, error) { +func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service, error) { // Create service with logger dependency injection svc := service.New(ctx, logger) - // Create HTTP router if requested - var router *service.HTTPRouter - if cfg.UseRouter { - router = service.NewHTTPRouter(svc, logger) - logger.Info("msg", "HTTP router mode enabled") - } - // Initialize pipelines successCount := 0 for _, pipelineCfg := range cfg.Pipelines { @@ -37,32 +30,19 @@ func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service "error", err) continue } - - // If using router mode, register HTTP sinks - if cfg.UseRouter { - pipeline, err := svc.GetPipeline(pipelineCfg.Name) - if err == nil && len(pipeline.HTTPSinks) > 0 { - if err := router.RegisterPipeline(pipeline); err != nil { - logger.Error("msg", "Failed to register pipeline with router", - "pipeline", pipelineCfg.Name, - "error", err) - } - } - } - successCount++ - displayPipelineEndpoints(pipelineCfg, cfg.UseRouter) + displayPipelineEndpoints(pipelineCfg) } if successCount == 0 { - return nil, nil, fmt.Errorf("no pipelines successfully started (attempted %d)", len(cfg.Pipelines)) + return nil, fmt.Errorf("no pipelines successfully started (attempted %d)", len(cfg.Pipelines)) } logger.Info("msg", "LogWisp started", "version", version.Short(), "pipelines", successCount) - return svc, router, nil + return svc, nil } // initializeLogger sets up the logger based on configuration diff --git a/src/cmd/logwisp/help.go b/src/cmd/logwisp/help.go index 385cf8d..2d10aff 100644 --- a/src/cmd/logwisp/help.go +++ b/src/cmd/logwisp/help.go @@ -16,7 +16,6 @@ Application Control: -v, --version Display version information and exit. -b, --background Run LogWisp in the background as a daemon. -q, --quiet Suppress all console output, including errors. - --router Enable HTTP router mode for multiplexing pipelines. Runtime Behavior: --disable-status-reporter Disable the periodic status reporter. @@ -24,7 +23,7 @@ Runtime Behavior: Configuration Sources (Precedence: CLI > Env > File > Defaults): - CLI flags override all other settings. - - Environment variables (e.g., LOGWISP_ROUTER=true) override file settings. + - Environment variables override file settings. - TOML configuration file is the primary method for defining pipelines. Logging ([logging] section or LOGWISP_LOGGING_* env vars): diff --git a/src/cmd/logwisp/main.go b/src/cmd/logwisp/main.go index 75ebb0b..b600084 100644 --- a/src/cmd/logwisp/main.go +++ b/src/cmd/logwisp/main.go @@ -77,7 +77,6 @@ func main() { "version", version.String(), "config_file", cfg.ConfigFile, "log_output", cfg.Logging.Output, - "router_mode", cfg.UseRouter, "background_mode", cfg.Background) // Create context for shutdown @@ -117,7 +116,7 @@ func main() { // Traditional static bootstrap logger.Info("msg", "Config auto-reload disabled") - svc, router, err := bootstrapService(ctx, cfg) + svc, err := bootstrapService(ctx, cfg) if err != nil { logger.Error("msg", "Failed to bootstrap service", "error", err) os.Exit(1) @@ -142,12 +141,6 @@ func main() { logger.Info("msg", "Shutdown signal received, starting graceful shutdown...") - // Shutdown router first if using it - if router != nil { - logger.Info("msg", "Shutting down HTTP router...") - router.Shutdown() - } - // Shutdown service with timeout shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) defer shutdownCancel() diff --git a/src/cmd/logwisp/reload.go b/src/cmd/logwisp/reload.go index b8f64e6..a26dbb4 100644 --- a/src/cmd/logwisp/reload.go +++ b/src/cmd/logwisp/reload.go @@ -19,7 +19,6 @@ import ( type ReloadManager struct { configPath string service *service.Service - router *service.HTTPRouter cfg *config.Config lcfg *lconfig.Config logger *log.Logger @@ -47,14 +46,13 @@ func NewReloadManager(configPath string, initialCfg *config.Config, logger *log. // Start begins watching for configuration changes func (rm *ReloadManager) Start(ctx context.Context) error { // Bootstrap initial service - svc, router, err := bootstrapService(ctx, rm.cfg) + svc, err := bootstrapService(ctx, rm.cfg) if err != nil { return fmt.Errorf("failed to bootstrap initial service: %w", err) } rm.mu.Lock() rm.service = svc - rm.router = router rm.mu.Unlock() // Start status reporter for initial service @@ -149,11 +147,6 @@ func (rm *ReloadManager) shouldReload(path string) bool { return true } - // Router mode changes require reload - if path == "router" || path == "use_router" { - return true - } - // Logging changes don't require service reload if strings.HasPrefix(path, "logging.") { return false @@ -214,12 +207,11 @@ func (rm *ReloadManager) performReload(ctx context.Context) error { // Get current service snapshot rm.mu.RLock() oldService := rm.service - oldRouter := rm.router rm.mu.RUnlock() // Try to bootstrap with new configuration rm.logger.Debug("msg", "Bootstrapping new service with updated config") - newService, newRouter, err := bootstrapService(ctx, newCfg) + newService, err := bootstrapService(ctx, newCfg) if err != nil { // Bootstrap failed - keep old services running return fmt.Errorf("failed to bootstrap new service (old service still active): %w", err) @@ -228,7 +220,6 @@ func (rm *ReloadManager) performReload(ctx context.Context) error { // Bootstrap succeeded - swap services atomically rm.mu.Lock() rm.service = newService - rm.router = newRouter rm.cfg = newCfg rm.mu.Unlock() @@ -237,22 +228,17 @@ func (rm *ReloadManager) performReload(ctx context.Context) error { // Gracefully shutdown old services // This happens after the swap to minimize downtime - go rm.shutdownOldServices(oldRouter, oldService) + go rm.shutdownOldServices(oldService) return nil } // shutdownOldServices gracefully shuts down old services -func (rm *ReloadManager) shutdownOldServices(router *service.HTTPRouter, svc *service.Service) { +func (rm *ReloadManager) shutdownOldServices(svc *service.Service) { // Give connections time to drain rm.logger.Debug("msg", "Draining connections from old services") time.Sleep(2 * time.Second) - if router != nil { - rm.logger.Info("msg", "Shutting down old router") - router.Shutdown() - } - if svc != nil { rm.logger.Info("msg", "Shutting down old service") svc.Shutdown() @@ -337,15 +323,9 @@ func (rm *ReloadManager) Shutdown() { // Shutdown current services rm.mu.RLock() - currentRouter := rm.router currentService := rm.service rm.mu.RUnlock() - if currentRouter != nil { - rm.logger.Info("msg", "Shutting down router") - currentRouter.Shutdown() - } - if currentService != nil { rm.logger.Info("msg", "Shutting down service") currentService.Shutdown() @@ -357,11 +337,4 @@ func (rm *ReloadManager) GetService() *service.Service { rm.mu.RLock() defer rm.mu.RUnlock() return rm.service -} - -// GetRouter returns the current router (thread-safe) -func (rm *ReloadManager) GetRouter() *service.HTTPRouter { - rm.mu.RLock() - defer rm.mu.RUnlock() - return rm.router } \ No newline at end of file diff --git a/src/cmd/logwisp/status.go b/src/cmd/logwisp/status.go index 37b05f8..7233a20 100644 --- a/src/cmd/logwisp/status.go +++ b/src/cmd/logwisp/status.go @@ -109,7 +109,7 @@ func logPipelineStatus(name string, stats map[string]any) { } // displayPipelineEndpoints logs the configured endpoints for a pipeline -func displayPipelineEndpoints(cfg config.PipelineConfig, routerMode bool) { +func displayPipelineEndpoints(cfg config.PipelineConfig) { // Display sink endpoints for i, sinkCfg := range cfg.Sinks { switch sinkCfg.Type { @@ -144,19 +144,11 @@ func displayPipelineEndpoints(cfg config.PipelineConfig, routerMode bool) { statusPath = path } - if routerMode { - logger.Info("msg", "HTTP endpoints configured", - "pipeline", cfg.Name, - "sink_index", i, - "stream_path", fmt.Sprintf("/%s%s", cfg.Name, streamPath), - "status_path", fmt.Sprintf("/%s%s", cfg.Name, statusPath)) - } else { - logger.Info("msg", "HTTP endpoints configured", - "pipeline", cfg.Name, - "sink_index", i, - "stream_url", fmt.Sprintf("http://localhost:%d%s", port, streamPath), - "status_url", fmt.Sprintf("http://localhost:%d%s", port, statusPath)) - } + logger.Info("msg", "HTTP endpoints configured", + "pipeline", cfg.Name, + "sink_index", i, + "stream_url", fmt.Sprintf("http://localhost:%d%s", port, streamPath), + "status_url", fmt.Sprintf("http://localhost:%d%s", port, statusPath)) // Display net limit info if configured if rl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok { diff --git a/src/internal/auth/authenticator.go b/src/internal/auth/authenticator.go new file mode 100644 index 0000000..2e77bb6 --- /dev/null +++ b/src/internal/auth/authenticator.go @@ -0,0 +1,411 @@ +// FILE: logwisp/src/internal/auth/authenticator.go +package auth + +import ( + "bufio" + "encoding/base64" + "fmt" + "os" + "strings" + "sync" + "time" + + "logwisp/src/internal/config" + + "github.com/golang-jwt/jwt/v5" + "github.com/lixenwraith/log" + "golang.org/x/crypto/bcrypt" +) + +// Authenticator handles all authentication methods for a pipeline +type Authenticator struct { + config *config.AuthConfig + logger *log.Logger + basicUsers map[string]string // username -> password hash + bearerTokens map[string]bool // token -> valid + jwtParser *jwt.Parser + jwtKeyFunc jwt.Keyfunc + mu sync.RWMutex + + // Session tracking + sessions map[string]*Session + sessionMu sync.RWMutex +} + +// Session represents an authenticated connection +type Session struct { + ID string + Username string + Method string // basic, bearer, jwt, mtls + RemoteAddr string + CreatedAt time.Time + LastActivity time.Time + Metadata map[string]any +} + +// New creates a new authenticator from config +func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) { + if cfg == nil || cfg.Type == "none" { + return nil, nil + } + + a := &Authenticator{ + config: cfg, + logger: logger, + basicUsers: make(map[string]string), + bearerTokens: make(map[string]bool), + sessions: make(map[string]*Session), + } + + // Initialize Basic Auth users + if cfg.Type == "basic" && cfg.BasicAuth != nil { + for _, user := range cfg.BasicAuth.Users { + a.basicUsers[user.Username] = user.PasswordHash + } + + // Load users from file if specified + if cfg.BasicAuth.UsersFile != "" { + if err := a.loadUsersFile(cfg.BasicAuth.UsersFile); err != nil { + return nil, fmt.Errorf("failed to load users file: %w", err) + } + } + } + + // Initialize Bearer tokens + if cfg.Type == "bearer" && cfg.BearerAuth != nil { + for _, token := range cfg.BearerAuth.Tokens { + a.bearerTokens[token] = true + } + + // Setup JWT validation if configured + if cfg.BearerAuth.JWT != nil { + a.jwtParser = jwt.NewParser( + jwt.WithValidMethods([]string{"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}), + jwt.WithLeeway(5*time.Second), + ) + + // Setup key function + if cfg.BearerAuth.JWT.SigningKey != "" { + // Static key + key := []byte(cfg.BearerAuth.JWT.SigningKey) + a.jwtKeyFunc = func(token *jwt.Token) (interface{}, error) { + return key, nil + } + } else if cfg.BearerAuth.JWT.JWKSURL != "" { + // JWKS support would require additional implementation + // ☢ SECURITY: JWKS rotation not implemented - tokens won't refresh keys + return nil, fmt.Errorf("JWKS support not yet implemented") + } + } + } + + // Start session cleanup + go a.sessionCleanup() + + logger.Info("msg", "Authenticator initialized", + "component", "auth", + "type", cfg.Type) + + return a, nil +} + +// AuthenticateHTTP handles HTTP authentication headers +func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) { + if a == nil || a.config.Type == "none" { + return &Session{ + ID: generateSessionID(), + Method: "none", + RemoteAddr: remoteAddr, + CreatedAt: time.Now(), + }, nil + } + + switch a.config.Type { + case "basic": + return a.authenticateBasic(authHeader, remoteAddr) + case "bearer": + return a.authenticateBearer(authHeader, remoteAddr) + default: + return nil, fmt.Errorf("unsupported auth type: %s", a.config.Type) + } +} + +// AuthenticateTCP handles TCP connection authentication +func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string) (*Session, error) { + if a == nil || a.config.Type == "none" { + return &Session{ + ID: generateSessionID(), + Method: "none", + RemoteAddr: remoteAddr, + CreatedAt: time.Now(), + }, nil + } + + // TCP auth protocol: AUTH + switch strings.ToLower(method) { + case "token": + if a.config.Type != "bearer" { + return nil, fmt.Errorf("token auth not configured") + } + return a.validateToken(credentials, remoteAddr) + + case "basic": + if a.config.Type != "basic" { + return nil, fmt.Errorf("basic auth not configured") + } + // Expect base64(username:password) + decoded, err := base64.StdEncoding.DecodeString(credentials) + if err != nil { + return nil, fmt.Errorf("invalid credentials encoding") + } + parts := strings.SplitN(string(decoded), ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid credentials format") + } + return a.validateBasicAuth(parts[0], parts[1], remoteAddr) + + default: + return nil, fmt.Errorf("unsupported auth method: %s", method) + } +} + +func (a *Authenticator) authenticateBasic(authHeader, remoteAddr string) (*Session, error) { + if !strings.HasPrefix(authHeader, "Basic ") { + return nil, fmt.Errorf("invalid basic auth header") + } + + payload, err := base64.StdEncoding.DecodeString(authHeader[6:]) + if err != nil { + return nil, fmt.Errorf("invalid base64 encoding") + } + + parts := strings.SplitN(string(payload), ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid credentials format") + } + + return a.validateBasicAuth(parts[0], parts[1], remoteAddr) +} + +func (a *Authenticator) validateBasicAuth(username, password, remoteAddr string) (*Session, error) { + a.mu.RLock() + expectedHash, exists := a.basicUsers[username] + a.mu.RUnlock() + + if !exists { + // ☢ SECURITY: Perform bcrypt anyway to prevent timing attacks + bcrypt.CompareHashAndPassword([]byte("$2a$10$dummy.hash.to.prevent.timing.attacks"), []byte(password)) + return nil, fmt.Errorf("invalid credentials") + } + + if err := bcrypt.CompareHashAndPassword([]byte(expectedHash), []byte(password)); err != nil { + return nil, fmt.Errorf("invalid credentials") + } + + session := &Session{ + ID: generateSessionID(), + Username: username, + Method: "basic", + RemoteAddr: remoteAddr, + CreatedAt: time.Now(), + LastActivity: time.Now(), + } + + a.storeSession(session) + return session, nil +} + +func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) { + if !strings.HasPrefix(authHeader, "Bearer ") { + return nil, fmt.Errorf("invalid bearer auth header") + } + + token := authHeader[7:] + return a.validateToken(token, remoteAddr) +} + +func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error) { + // Check static tokens first + a.mu.RLock() + isStatic := a.bearerTokens[token] + a.mu.RUnlock() + + if isStatic { + session := &Session{ + ID: generateSessionID(), + Method: "bearer", + RemoteAddr: remoteAddr, + CreatedAt: time.Now(), + LastActivity: time.Now(), + Metadata: map[string]any{"token_type": "static"}, + } + a.storeSession(session) + return session, nil + } + + // Try JWT validation if configured + if a.jwtParser != nil && a.jwtKeyFunc != nil { + claims := jwt.MapClaims{} + parsedToken, err := a.jwtParser.ParseWithClaims(token, claims, a.jwtKeyFunc) + if err != nil { + return nil, fmt.Errorf("JWT validation failed: %w", err) + } + + if !parsedToken.Valid { + return nil, fmt.Errorf("invalid JWT token") + } + + // Check issuer if configured + if a.config.BearerAuth.JWT.Issuer != "" { + if iss, ok := claims["iss"].(string); !ok || iss != a.config.BearerAuth.JWT.Issuer { + return nil, fmt.Errorf("invalid token issuer") + } + } + + // Check audience if configured + if a.config.BearerAuth.JWT.Audience != "" { + if aud, ok := claims["aud"].(string); !ok || aud != a.config.BearerAuth.JWT.Audience { + return nil, fmt.Errorf("invalid token audience") + } + } + + username := "" + if sub, ok := claims["sub"].(string); ok { + username = sub + } + + session := &Session{ + ID: generateSessionID(), + Username: username, + Method: "jwt", + RemoteAddr: remoteAddr, + CreatedAt: time.Now(), + LastActivity: time.Now(), + Metadata: map[string]any{"claims": claims}, + } + a.storeSession(session) + return session, nil + } + + return nil, fmt.Errorf("invalid token") +} + +func (a *Authenticator) storeSession(session *Session) { + a.sessionMu.Lock() + a.sessions[session.ID] = session + a.sessionMu.Unlock() + + a.logger.Info("msg", "Session created", + "component", "auth", + "session_id", session.ID, + "username", session.Username, + "method", session.Method, + "remote_addr", session.RemoteAddr) +} + +func (a *Authenticator) sessionCleanup() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + a.sessionMu.Lock() + now := time.Now() + for id, session := range a.sessions { + if now.Sub(session.LastActivity) > 30*time.Minute { + delete(a.sessions, id) + a.logger.Debug("msg", "Session expired", + "component", "auth", + "session_id", id) + } + } + a.sessionMu.Unlock() + } +} + +func (a *Authenticator) loadUsersFile(path string) error { + file, err := os.Open(path) + if err != nil { + return fmt.Errorf("could not open users file: %w", err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + lineNumber := 0 + for scanner.Scan() { + lineNumber++ + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue // Skip empty lines and comments + } + + parts := strings.SplitN(line, ":", 2) + if len(parts) != 2 { + a.logger.Warn("msg", "Skipping malformed line in users file", + "component", "auth", + "path", path, + "line_number", lineNumber) + continue + } + username, hash := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) + if username != "" && hash != "" { + // File-based users can overwrite inline users if names conflict + a.basicUsers[username] = hash + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading users file: %w", err) + } + + a.logger.Info("msg", "Loaded users from file", + "component", "auth", + "path", path, + "user_count", len(a.basicUsers)) + + return nil +} + +func generateSessionID() string { + return fmt.Sprintf("%d-%d", time.Now().UnixNano(), time.Now().Unix()) +} + +// ValidateSession checks if a session is still valid +func (a *Authenticator) ValidateSession(sessionID string) bool { + if a == nil { + return true + } + + a.sessionMu.RLock() + session, exists := a.sessions[sessionID] + a.sessionMu.RUnlock() + + if !exists { + return false + } + + // Update activity + a.sessionMu.Lock() + session.LastActivity = time.Now() + a.sessionMu.Unlock() + + return true +} + +// GetStats returns authentication statistics +func (a *Authenticator) GetStats() map[string]any { + if a == nil { + return map[string]any{"enabled": false} + } + + a.sessionMu.RLock() + sessionCount := len(a.sessions) + a.sessionMu.RUnlock() + + return map[string]any{ + "enabled": true, + "type": a.config.Type, + "active_sessions": sessionCount, + "basic_users": len(a.basicUsers), + "static_tokens": len(a.bearerTokens), + } +} \ No newline at end of file diff --git a/src/internal/config/auth.go b/src/internal/config/auth.go index 8b69e3b..84f7544 100644 --- a/src/internal/config/auth.go +++ b/src/internal/config/auth.go @@ -1,7 +1,11 @@ // FILE: logwisp/src/internal/config/auth.go package config -import "fmt" +import ( + "fmt" + "net" + "strings" +) type AuthConfig struct { // Authentication type: "none", "basic", "bearer", "mtls" @@ -12,10 +16,6 @@ type AuthConfig struct { // Bearer token auth BearerAuth *BearerAuthConfig `toml:"bearer_auth"` - - // IP-based access control - IPWhitelist []string `toml:"ip_whitelist"` - IPBlacklist []string `toml:"ip_blacklist"` } type BasicAuthConfig struct { diff --git a/src/internal/config/config.go b/src/internal/config/config.go index 77a446a..37186a4 100644 --- a/src/internal/config/config.go +++ b/src/internal/config/config.go @@ -3,7 +3,6 @@ package config type Config struct { // Top-level flags for application control - UseRouter bool `toml:"router"` Background bool `toml:"background"` ShowVersion bool `toml:"version"` Quiet bool `toml:"quiet"` diff --git a/src/internal/config/ratelimit.go b/src/internal/config/limit.go similarity index 69% rename from src/internal/config/ratelimit.go rename to src/internal/config/limit.go index 8507031..05c8654 100644 --- a/src/internal/config/ratelimit.go +++ b/src/internal/config/limit.go @@ -3,6 +3,7 @@ package config import ( "fmt" + "net" "strings" ) @@ -28,6 +29,37 @@ type RateLimitConfig struct { MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"` } +func validateNetAccess(pipelineName string, cfg *NetAccessConfig) error { + if cfg == nil { + return nil + } + + // Validate CIDR notation + for _, cidr := range cfg.IPWhitelist { + if !strings.Contains(cidr, "/") { + cidr = cidr + "/32" + } + if _, _, err := net.ParseCIDR(cidr); err != nil { + if net.ParseIP(cidr) == nil { + return fmt.Errorf("pipeline '%s': invalid IP whitelist entry: %s", pipelineName, cidr) + } + } + } + + for _, cidr := range cfg.IPBlacklist { + if !strings.Contains(cidr, "/") { + cidr = cidr + "/32" + } + if _, _, err := net.ParseCIDR(cidr); err != nil { + if net.ParseIP(cidr) == nil { + return fmt.Errorf("pipeline '%s': invalid IP blacklist entry: %s", pipelineName, cidr) + } + } + } + + return nil +} + func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error { if cfg == nil { return nil diff --git a/src/internal/config/loader.go b/src/internal/config/loader.go index 6afc749..81f4ea4 100644 --- a/src/internal/config/loader.go +++ b/src/internal/config/loader.go @@ -19,7 +19,6 @@ type LoadContext struct { func defaults() *Config { return &Config{ // Top-level flag defaults - UseRouter: false, Background: false, ShowVersion: false, Quiet: false, diff --git a/src/internal/config/pipeline.go b/src/internal/config/pipeline.go index 22d4fda..9638838 100644 --- a/src/internal/config/pipeline.go +++ b/src/internal/config/pipeline.go @@ -20,6 +20,9 @@ type PipelineConfig struct { // Rate limiting RateLimit *RateLimitConfig `toml:"rate_limit"` + // Network access control (IP filtering) + NetAccess *NetAccessConfig `toml:"net_access"` + // Filter configuration Filters []FilterConfig `toml:"filters"` @@ -34,6 +37,12 @@ type PipelineConfig struct { Auth *AuthConfig `toml:"auth"` } +// NetAccessConfig defines IP-based access control lists +type NetAccessConfig struct { + IPWhitelist []string `toml:"ip_whitelist"` + IPBlacklist []string `toml:"ip_blacklist"` +} + // SourceConfig represents an input data source type SourceConfig struct { // Source type: "directory", "file", "stdin", etc. diff --git a/src/internal/config/validation.go b/src/internal/config/validation.go index b343683..010ff82 100644 --- a/src/internal/config/validation.go +++ b/src/internal/config/validation.go @@ -72,6 +72,11 @@ func (c *Config) validate() error { } } + // Validate net access if present + if err := validateNetAccess(pipeline.Name, pipeline.NetAccess); err != nil { + return err + } + // Validate auth if present if err := validateAuth(pipeline.Name, pipeline.Auth); err != nil { return err diff --git a/src/internal/limit/ip.go b/src/internal/limit/ip.go new file mode 100644 index 0000000..b76afa0 --- /dev/null +++ b/src/internal/limit/ip.go @@ -0,0 +1,173 @@ +// FILE: src/internal/limit/ip.go +package limit + +import ( + "net" + "strings" + + "logwisp/src/internal/config" + + "github.com/lixenwraith/log" +) + +// IPChecker handles IP-based access control lists +type IPChecker struct { + ipWhitelist []*net.IPNet + ipBlacklist []*net.IPNet + logger *log.Logger +} + +// NewIPChecker creates a new IPChecker. Returns nil if no rules are defined. +func NewIPChecker(cfg *config.NetAccessConfig, logger *log.Logger) *IPChecker { + if cfg == nil || (len(cfg.IPWhitelist) == 0 && len(cfg.IPBlacklist) == 0) { + return nil + } + + c := &IPChecker{ + ipWhitelist: make([]*net.IPNet, 0), + ipBlacklist: make([]*net.IPNet, 0), + logger: logger, + } + + // Parse whitelist entries + for _, cidr := range cfg.IPWhitelist { + if !strings.Contains(cidr, "/") { + cidr = cidr + "/32" + } + + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + // Try parsing as plain IP + if ip := net.ParseIP(cidr); ip != nil { + if ip.To4() != nil { + ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)} + } else { + ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + } + } else { + logger.Warn("msg", "Skipping invalid IP whitelist entry", + "component", "ip_checker", + "entry", cidr, + "error", err) + continue + } + } + c.ipWhitelist = append(c.ipWhitelist, ipNet) + } + + // Parse blacklist entries + for _, cidr := range cfg.IPBlacklist { + if !strings.Contains(cidr, "/") { + cidr = cidr + "/32" + } + + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + // Try parsing as plain IP + if ip := net.ParseIP(cidr); ip != nil { + if ip.To4() != nil { + ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)} + } else { + ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + } + } else { + logger.Warn("msg", "Skipping invalid IP blacklist entry", + "component", "ip_checker", + "entry", cidr, + "error", err) + continue + } + } + c.ipBlacklist = append(c.ipBlacklist, ipNet) + } + + logger.Info("msg", "IP checker initialized", + "component", "ip_checker", + "whitelist_rules", len(c.ipWhitelist), + "blacklist_rules", len(c.ipBlacklist)) + + return c +} + +// IsAllowed validates if a remote address is permitted +func (c *IPChecker) IsAllowed(remoteAddr net.Addr) bool { + if c == nil { + return true // No checker = allow all + } + + // No rules = allow all + if len(c.ipWhitelist) == 0 && len(c.ipBlacklist) == 0 { + return true + } + + // Extract IP from address + var ipStr string + switch addr := remoteAddr.(type) { + case *net.TCPAddr: + ipStr = addr.IP.String() + case *net.UDPAddr: + ipStr = addr.IP.String() + default: + // Try string parsing + addrStr := remoteAddr.String() + host, _, err := net.SplitHostPort(addrStr) + if err != nil { + ipStr = addrStr + } else { + ipStr = host + } + } + + ip := net.ParseIP(ipStr) + if ip == nil { + c.logger.Warn("msg", "Could not parse remote address to IP", + "component", "ip_checker", + "remote_addr", remoteAddr.String()) + return false // Deny unparseable addresses + } + + // Check blacklist first (deny takes precedence) + for _, ipNet := range c.ipBlacklist { + if ipNet.Contains(ip) { + c.logger.Warn("msg", "Blacklisted IP denied", + "component", "ip_checker", + "ip", ipStr, + "rule", ipNet.String()) + return false + } + } + + // If whitelist is configured, IP must be in it + if len(c.ipWhitelist) > 0 { + for _, ipNet := range c.ipWhitelist { + if ipNet.Contains(ip) { + c.logger.Debug("msg", "IP allowed by whitelist", + "component", "ip_checker", + "ip", ipStr, + "rule", ipNet.String()) + return true + } + } + // No whitelist match = deny + c.logger.Warn("msg", "IP not in whitelist", + "component", "ip_checker", + "ip", ipStr) + return false + } + + // No blacklist match + no whitelist configured = allow + return true +} + +// GetStats returns IP checker statistics +func (c *IPChecker) GetStats() map[string]any { + if c == nil { + return map[string]any{"enabled": false} + } + + return map[string]any{ + "enabled": true, + "whitelist_rules": len(c.ipWhitelist), + "blacklist_rules": len(c.ipBlacklist), + } +} \ No newline at end of file diff --git a/src/internal/service/httprouter.go b/src/internal/service/httprouter.go deleted file mode 100644 index 34b852d..0000000 --- a/src/internal/service/httprouter.go +++ /dev/null @@ -1,232 +0,0 @@ -// FILE: logwisp/src/internal/service/httprouter.go -package service - -import ( - "fmt" - "strings" - "sync" - "sync/atomic" - "time" - - "logwisp/src/internal/sink" - - "github.com/lixenwraith/log" - "github.com/valyala/fasthttp" -) - -// HTTPRouter manages HTTP routing for multiple pipelines -type HTTPRouter struct { - service *Service - servers map[int64]*routerServer // port -> server - mu sync.RWMutex - logger *log.Logger - - // Statistics - startTime time.Time - totalRequests atomic.Uint64 - routedRequests atomic.Uint64 - failedRequests atomic.Uint64 -} - -// NewHTTPRouter creates a new HTTP router -func NewHTTPRouter(service *Service, logger *log.Logger) *HTTPRouter { - return &HTTPRouter{ - service: service, - servers: make(map[int64]*routerServer), - startTime: time.Now(), - logger: logger, - } -} - -// RegisterPipeline registers a pipeline's HTTP sinks with the router -func (r *HTTPRouter) RegisterPipeline(pipeline *Pipeline) error { - // Register all HTTP sinks in the pipeline - for _, httpSink := range pipeline.HTTPSinks { - if err := r.registerHTTPSink(pipeline.Name, httpSink); err != nil { - return err - } - } - return nil -} - -// registerHTTPSink registers a single HTTP sink -func (r *HTTPRouter) registerHTTPSink(pipelineName string, httpSink *sink.HTTPSink) error { - // Get port from sink configuration - stats := httpSink.GetStats() - details := stats.Details - port := details["port"].(int64) - - r.mu.Lock() - rs, exists := r.servers[port] - if !exists { - // Create new server for this port - rs = &routerServer{ - port: port, - routes: make(map[string]*routedSink), - router: r, - startTime: time.Now(), - logger: r.logger, - } - rs.server = &fasthttp.Server{ - Handler: rs.requestHandler, - DisableKeepalive: false, - StreamRequestBody: true, - CloseOnShutdown: true, - } - r.servers[port] = rs - - // Startup sync channel - startupDone := make(chan error, 1) - - // Start server in background - go func() { - addr := fmt.Sprintf(":%d", port) - r.logger.Info("msg", "Starting router server", - "component", "http_router", - "port", port) - - // Signal that server is about to start - startupDone <- nil - - if err := rs.server.ListenAndServe(addr); err != nil { - r.logger.Error("msg", "Router server failed", - "component", "http_router", - "port", port, - "error", err) - } - }() - - // Wait for server startup signal with timeout - select { - case err := <-startupDone: - if err != nil { - r.mu.Unlock() - return fmt.Errorf("server startup failed: %w", err) - } - case <-time.After(5 * time.Second): - r.mu.Unlock() - return fmt.Errorf("server startup timeout on port %d", port) - } - } - r.mu.Unlock() - - // Register routes for this sink - rs.routeMu.Lock() - defer rs.routeMu.Unlock() - - // Use pipeline name as path prefix - pathPrefix := "/" + pipelineName - - // Check for conflicts - for existingPath, existing := range rs.routes { - if strings.HasPrefix(pathPrefix, existingPath) || strings.HasPrefix(existingPath, pathPrefix) { - return fmt.Errorf("path conflict: '%s' conflicts with existing pipeline '%s' at '%s'", - pathPrefix, existing.pipelineName, existingPath) - } - } - - // Set the sink to router mode - httpSink.SetRouterMode() - - rs.routes[pathPrefix] = &routedSink{ - pipelineName: pipelineName, - httpSink: httpSink, - } - - r.logger.Info("msg", "Registered pipeline route", - "component", "http_router", - "pipeline", pipelineName, - "path", pathPrefix, - "port", port) - return nil -} - -// UnregisterPipeline removes a pipeline's routes -func (r *HTTPRouter) UnregisterPipeline(pipelineName string) { - r.mu.RLock() - defer r.mu.RUnlock() - - for port, rs := range r.servers { - rs.routeMu.Lock() - for path, route := range rs.routes { - if route.pipelineName == pipelineName { - delete(rs.routes, path) - r.logger.Info("msg", "Unregistered pipeline route", - "component", "http_router", - "pipeline", pipelineName, - "path", path, - "port", port) - } - } - - // Check if server has no more routes - if len(rs.routes) == 0 { - r.logger.Info("msg", "No routes left on port, considering shutdown", - "component", "http_router", - "port", port) - } - rs.routeMu.Unlock() - } -} - -// Shutdown stops all router servers -func (r *HTTPRouter) Shutdown() { - r.logger.Info("msg", "Starting router shutdown...") - - r.mu.Lock() - defer r.mu.Unlock() - - var wg sync.WaitGroup - for port, rs := range r.servers { - wg.Add(1) - go func(p int64, s *routerServer) { - defer wg.Done() - r.logger.Info("msg", "Shutting down server", - "component", "http_router", - "port", p) - if err := s.server.Shutdown(); err != nil { - r.logger.Error("msg", "Error shutting down server", - "component", "http_router", - "port", p, - "error", err) - } - }(port, rs) - } - wg.Wait() - - r.logger.Info("msg", "Router shutdown complete") -} - -// GetStats returns router statistics -func (r *HTTPRouter) GetStats() map[string]any { - r.mu.RLock() - defer r.mu.RUnlock() - - serverStats := make(map[int64]any) - totalRoutes := 0 - - for port, rs := range r.servers { - rs.routeMu.RLock() - routes := make([]string, 0, len(rs.routes)) - for path := range rs.routes { - routes = append(routes, path) - totalRoutes++ - } - rs.routeMu.RUnlock() - - serverStats[port] = map[string]any{ - "routes": routes, - "requests": rs.requests.Load(), - "uptime": int(time.Since(rs.startTime).Seconds()), - } - } - - return map[string]any{ - "uptime_seconds": int64(time.Since(r.startTime).Seconds()), - "total_requests": r.totalRequests.Load(), - "routed_requests": r.routedRequests.Load(), - "failed_requests": r.failedRequests.Load(), - "servers": serverStats, - "total_routes": totalRoutes, - } -} \ No newline at end of file diff --git a/src/internal/service/pipeline.go b/src/internal/service/pipeline.go index 765b576..63f20e0 100644 --- a/src/internal/service/pipeline.go +++ b/src/internal/service/pipeline.go @@ -30,10 +30,6 @@ type Pipeline struct { ctx context.Context cancel context.CancelFunc wg sync.WaitGroup - - // For HTTP sinks in router mode - HTTPSinks []*sink.HTTPSink - TCPSinks []*sink.TCPSink } // PipelineStats contains statistics for a pipeline diff --git a/src/internal/service/routerserver.go b/src/internal/service/routerserver.go deleted file mode 100644 index 4117749..0000000 --- a/src/internal/service/routerserver.go +++ /dev/null @@ -1,192 +0,0 @@ -// FILE: logwisp/src/internal/service/routerserver.go -package service - -import ( - "encoding/json" - "fmt" - "strings" - "sync" - "sync/atomic" - "time" - - "logwisp/src/internal/sink" - "logwisp/src/internal/version" - - "github.com/lixenwraith/log" - "github.com/valyala/fasthttp" -) - -// routedSink represents a sink registered with the router -type routedSink struct { - pipelineName string - httpSink *sink.HTTPSink -} - -// routerServer handles HTTP requests for a specific port -type routerServer struct { - port int64 - server *fasthttp.Server - logger *log.Logger - routes map[string]*routedSink // path prefix -> sink - routeMu sync.RWMutex - router *HTTPRouter - startTime time.Time - requests atomic.Uint64 -} - -func (rs *routerServer) requestHandler(ctx *fasthttp.RequestCtx) { - rs.requests.Add(1) - rs.router.totalRequests.Add(1) - - path := string(ctx.Path()) - remoteAddr := ctx.RemoteAddr().String() - - // Log request for debugging - rs.logger.Debug("msg", "Router request", - "component", "router_server", - "method", string(ctx.Method()), - "path", path, - "remote_addr", remoteAddr) - - // Special case: global status at /status - if path == "/status" { - rs.handleGlobalStatus(ctx) - return - } - - // Find matching route - rs.routeMu.RLock() - var matchedSink *routedSink - var matchedPrefix string - var remainingPath string - - for prefix, route := range rs.routes { - if strings.HasPrefix(path, prefix) { - // Longest prefix match - if len(prefix) > len(matchedPrefix) { - matchedPrefix = prefix - matchedSink = route - remainingPath = strings.TrimPrefix(path, prefix) - // Ensure remaining path starts with / or is empty - if remainingPath != "" && !strings.HasPrefix(remainingPath, "/") { - remainingPath = "/" + remainingPath - } - } - } - } - rs.routeMu.RUnlock() - - if matchedSink == nil { - rs.router.failedRequests.Add(1) - rs.handleNotFound(ctx) - return - } - - rs.router.routedRequests.Add(1) - - // Route to sink's handler - if matchedSink.httpSink != nil { - // Save original path - originalPath := string(ctx.URI().Path()) - - // Rewrite path to remove pipeline prefix - if remainingPath == "" { - // Default to stream path if no remaining path - remainingPath = matchedSink.httpSink.GetStreamPath() - } - - rs.logger.Debug("msg", "Routing request to pipeline", - "component", "router_server", - "pipeline", matchedSink.pipelineName, - "original_path", originalPath, - "remaining_path", remainingPath) - - ctx.URI().SetPath(remainingPath) - matchedSink.httpSink.RouteRequest(ctx) - - // Restore original path - ctx.URI().SetPath(originalPath) - } else { - ctx.SetStatusCode(fasthttp.StatusServiceUnavailable) - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(map[string]string{ - "error": "Pipeline HTTP sink not available", - "pipeline": matchedSink.pipelineName, - }) - } -} - -func (rs *routerServer) handleGlobalStatus(ctx *fasthttp.RequestCtx) { - ctx.SetContentType("application/json") - - rs.routeMu.RLock() - pipelines := make(map[string]any) - for prefix, route := range rs.routes { - pipelineInfo := map[string]any{ - "path_prefix": prefix, - "endpoints": map[string]string{ - "stream": prefix + route.httpSink.GetStreamPath(), - "status": prefix + route.httpSink.GetStatusPath(), - }, - } - - // Get sink stats - sinkStats := route.httpSink.GetStats() - pipelineInfo["sink"] = map[string]any{ - "type": sinkStats.Type, - "total_processed": sinkStats.TotalProcessed, - "active_connections": sinkStats.ActiveConnections, - "details": sinkStats.Details, - } - - pipelines[route.pipelineName] = pipelineInfo - } - rs.routeMu.RUnlock() - - // Get router stats - routerStats := rs.router.GetStats() - - status := map[string]any{ - "service": "LogWisp Router", - "version": version.String(), - "port": rs.port, - "pipelines": pipelines, - "total_pipelines": len(pipelines), - "router": routerStats, - "endpoints": map[string]string{ - "global_status": "/status", - }, - } - - data, _ := json.MarshalIndent(status, "", " ") - ctx.SetBody(data) -} - -func (rs *routerServer) handleNotFound(ctx *fasthttp.RequestCtx) { - ctx.SetStatusCode(fasthttp.StatusNotFound) - ctx.SetContentType("application/json") - - rs.routeMu.RLock() - availableRoutes := make([]string, 0, len(rs.routes)*2+1) - availableRoutes = append(availableRoutes, "/status (global status)") - - for prefix, route := range rs.routes { - if route.httpSink != nil { - availableRoutes = append(availableRoutes, - fmt.Sprintf("%s%s (stream: %s)", prefix, route.httpSink.GetStreamPath(), route.pipelineName), - fmt.Sprintf("%s%s (status: %s)", prefix, route.httpSink.GetStatusPath(), route.pipelineName), - ) - } - } - rs.routeMu.RUnlock() - - response := map[string]any{ - "error": "Not Found", - "requested_path": string(ctx.Path()), - "available_routes": availableRoutes, - "hint": "Use /status for global router status", - } - - data, _ := json.MarshalIndent(response, "", " ") - ctx.SetBody(data) -} \ No newline at end of file diff --git a/src/internal/service/service.go b/src/internal/service/service.go index d54eeff..f9b481f 100644 --- a/src/internal/service/service.go +++ b/src/internal/service/service.go @@ -113,20 +113,12 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error { // Create sinks for i, sinkCfg := range cfg.Sinks { - sinkInst, err := s.createSink(sinkCfg, formatter) // Pass formatter + sinkInst, err := s.createSink(sinkCfg, formatter) if err != nil { pipelineCancel() return fmt.Errorf("failed to create sink[%d]: %w", i, err) } pipeline.Sinks = append(pipeline.Sinks, sinkInst) - - // Track HTTP/TCP sinks for router mode - switch s := sinkInst.(type) { - case *sink.HTTPSink: - pipeline.HTTPSinks = append(pipeline.HTTPSinks, s) - case *sink.TCPSink: - pipeline.TCPSinks = append(pipeline.TCPSinks, s) - } } // Start all sources @@ -145,6 +137,16 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error { } } + // Configure authentication for sinks that support it + for _, sinkInst := range pipeline.Sinks { + if setter, ok := sinkInst.(sink.NetAccessSetter); ok { + setter.SetNetAccessConfig(cfg.NetAccess) + } + if setter, ok := sinkInst.(sink.AuthSetter); ok { + setter.SetAuthConfig(cfg.Auth) + } + } + // Wire sources to sinks through filters s.wirePipeline(pipeline) @@ -152,7 +154,9 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error { pipeline.startStatsUpdater(pipelineCtx) s.pipelines[cfg.Name] = pipeline - s.logger.Info("msg", "Pipeline created successfully", "pipeline", cfg.Name) + s.logger.Info("msg", "Pipeline created successfully", + "pipeline", cfg.Name, + "auth_enabled", cfg.Auth != nil && cfg.Auth.Type != "none") return nil } @@ -268,19 +272,19 @@ func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter) switch cfg.Type { case "http": - return sink.NewHTTPSink(cfg.Options, s.logger, formatter) // needs implementation + return sink.NewHTTPSink(cfg.Options, s.logger, formatter) case "tcp": - return sink.NewTCPSink(cfg.Options, s.logger, formatter) // needs implementation + return sink.NewTCPSink(cfg.Options, s.logger, formatter) case "http_client": - return sink.NewHTTPClientSink(cfg.Options, s.logger, formatter) // needs verification + return sink.NewHTTPClientSink(cfg.Options, s.logger, formatter) case "tcp_client": - return sink.NewTCPClientSink(cfg.Options, s.logger, formatter) // needs implementation + return sink.NewTCPClientSink(cfg.Options, s.logger, formatter) case "file": return sink.NewFileSink(cfg.Options, s.logger, formatter) case "stdout": - return sink.NewStdoutSink(cfg.Options, s.logger, formatter) // needs implementation + return sink.NewStdoutSink(cfg.Options, s.logger, formatter) case "stderr": - return sink.NewStderrSink(cfg.Options, s.logger, formatter) // needs implementation + return sink.NewStderrSink(cfg.Options, s.logger, formatter) default: return nil, fmt.Errorf("unknown sink type: %s", cfg.Type) } diff --git a/src/internal/sink/http.go b/src/internal/sink/http.go index dc1bea0..9aefe76 100644 --- a/src/internal/sink/http.go +++ b/src/internal/sink/http.go @@ -11,10 +11,12 @@ import ( "sync/atomic" "time" + "logwisp/src/internal/auth" "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/format" "logwisp/src/internal/limit" + "logwisp/src/internal/tls" "logwisp/src/internal/version" "github.com/lixenwraith/log" @@ -35,19 +37,24 @@ type HTTPSink struct { logger *log.Logger formatter format.Formatter + // Security components + authenticator *auth.Authenticator + tlsManager *tls.Manager + authConfig *config.AuthConfig + // Path configuration streamPath string statusPath string - // For router integration - standalone bool - // Net limiting netLimiter *limit.NetLimiter + ipChecker *limit.IPChecker // Statistics totalProcessed atomic.Uint64 lastProcessed atomic.Value // time.Time + authFailures atomic.Uint64 + authSuccesses atomic.Uint64 } // HTTPConfig holds HTTP sink configuration @@ -98,6 +105,32 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo } } + // Extract SSL config + if ssl, ok := options["ssl"].(map[string]any); ok { + cfg.SSL = &config.SSLConfig{} + cfg.SSL.Enabled, _ = ssl["enabled"].(bool) + if certFile, ok := ssl["cert_file"].(string); ok { + cfg.SSL.CertFile = certFile + } + if keyFile, ok := ssl["key_file"].(string); ok { + cfg.SSL.KeyFile = keyFile + } + cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool) + if caFile, ok := ssl["client_ca_file"].(string); ok { + cfg.SSL.ClientCAFile = caFile + } + cfg.SSL.VerifyClientCert, _ = ssl["verify_client_cert"].(bool) + if minVer, ok := ssl["min_version"].(string); ok { + cfg.SSL.MinVersion = minVer + } + if maxVer, ok := ssl["max_version"].(string); ok { + cfg.SSL.MaxVersion = maxVer + } + if ciphers, ok := ssl["cipher_suites"].(string); ok { + cfg.SSL.CipherSuites = ciphers + } + } + // Extract net limit config if rl, ok := options["net_limit"].(map[string]any); ok { cfg.NetLimit = &config.NetLimitConfig{} @@ -132,7 +165,6 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo done: make(chan struct{}), streamPath: cfg.StreamPath, statusPath: cfg.StatusPath, - standalone: true, logger: logger, formatter: formatter, } @@ -151,13 +183,6 @@ func (h *HTTPSink) Input() chan<- core.LogEntry { } func (h *HTTPSink) Start(ctx context.Context) error { - if !h.standalone { - // In router mode, don't start our own server - h.logger.Debug("msg", "HTTP sink in router mode, skipping server start", - "component", "http_sink") - return nil - } - // Create fasthttp adapter for logging fasthttpLogger := compat.NewFastHTTPAdapter(h.logger) @@ -168,6 +193,12 @@ func (h *HTTPSink) Start(ctx context.Context) error { Logger: fasthttpLogger, } + // Configure TLS if enabled + if h.tlsManager != nil { + tlsConfig := h.tlsManager.GetHTTPConfig() + h.server.TLSConfig = tlsConfig + } + addr := fmt.Sprintf(":%d", h.config.Port) // Run server in separate goroutine to avoid blocking @@ -178,7 +209,16 @@ func (h *HTTPSink) Start(ctx context.Context) error { "port", h.config.Port, "stream_path", h.streamPath, "status_path", h.statusPath) - err := h.server.ListenAndServe(addr) + + var err error + if h.tlsManager != nil { + // HTTPS server + err = h.server.ListenAndServeTLS(addr, "", "") + } else { + // HTTP server + err = h.server.ListenAndServe(addr) + } + if err != nil { errChan <- err } @@ -210,8 +250,8 @@ func (h *HTTPSink) Stop() { // Signal all client handlers to stop close(h.done) - // Shutdown HTTP server if in standalone mode - if h.standalone && h.server != nil { + // Shutdown HTTP server + if h.server != nil { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() h.server.ShutdownWithContext(ctx) @@ -231,6 +271,18 @@ func (h *HTTPSink) GetStats() SinkStats { netLimitStats = h.netLimiter.GetStats() } + var authStats map[string]any + if h.authenticator != nil { + authStats = h.authenticator.GetStats() + authStats["failures"] = h.authFailures.Load() + authStats["successes"] = h.authSuccesses.Load() + } + + var tlsStats map[string]any + if h.tlsManager != nil { + tlsStats = h.tlsManager.GetStats() + } + return SinkStats{ Type: "http", TotalProcessed: h.totalProcessed.Load(), @@ -245,42 +297,83 @@ func (h *HTTPSink) GetStats() SinkStats { "status": h.statusPath, }, "net_limit": netLimitStats, + "auth": authStats, + "tls": tlsStats, }, } } -// SetRouterMode configures the sink for use with a router -func (h *HTTPSink) SetRouterMode() { - h.standalone = false - h.logger.Debug("msg", "HTTP sink set to router mode", - "component", "http_sink") -} - -// RouteRequest handles a request from the router -func (h *HTTPSink) RouteRequest(ctx *fasthttp.RequestCtx) { - h.requestHandler(ctx) -} - func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { - // Check net limit first remoteAddr := ctx.RemoteAddr().String() - if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed { - ctx.SetStatusCode(int(statusCode)) - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(map[string]any{ - "error": message, - "retry_after": "60", // seconds - }) - return + + // Check IP access control + if h.ipChecker != nil { + if !h.ipChecker.IsAllowed(ctx.RemoteAddr()) { + ctx.SetStatusCode(fasthttp.StatusForbidden) + ctx.SetContentType("text/plain") + ctx.SetBodyString("Forbidden") + return + } + } + + // Check net limit + if h.netLimiter != nil { + if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed { + ctx.SetStatusCode(int(statusCode)) + ctx.SetContentType("application/json") + json.NewEncoder(ctx).Encode(map[string]any{ + "error": message, + "retry_after": "60", // seconds + }) + return + } } path := string(ctx.Path()) + // Status endpoint doesn't require auth + if path == h.statusPath { + h.handleStatus(ctx) + return + } + + // Authenticate request + var session *auth.Session + if h.authenticator != nil { + authHeader := string(ctx.Request.Header.Peek("Authorization")) + var err error + session, err = h.authenticator.AuthenticateHTTP(authHeader, remoteAddr) + if err != nil { + h.authFailures.Add(1) + h.logger.Warn("msg", "Authentication failed", + "component", "http_sink", + "remote_addr", remoteAddr, + "error", err) + + // Return 401 with WWW-Authenticate header + ctx.SetStatusCode(fasthttp.StatusUnauthorized) + if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil { + realm := h.authConfig.BasicAuth.Realm + if realm == "" { + realm = "LogWisp" + } + ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm)) + } else if h.authConfig.Type == "bearer" { + ctx.Response.Header.Set("WWW-Authenticate", "Bearer") + } + + ctx.SetContentType("application/json") + json.NewEncoder(ctx).Encode(map[string]string{ + "error": "Authentication required", + }) + return + } + h.authSuccesses.Add(1) + } + switch path { case h.streamPath: - h.handleStream(ctx) - case h.statusPath: - h.handleStatus(ctx) + h.handleStream(ctx, session) default: ctx.SetStatusCode(fasthttp.StatusNotFound) ctx.SetContentType("application/json") @@ -292,7 +385,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { } } -func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) { +func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) { // Track connection for net limiting remoteAddr := ctx.RemoteAddr().String() if h.netLimiter != nil { @@ -330,7 +423,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) { case <-h.done: return default: - // Drop if client buffer full, may flood logging for slow client + // Drop if client buffer full h.logger.Debug("msg", "Dropped entry for slow client", "component", "http_sink", "remote_addr", remoteAddr) @@ -348,6 +441,8 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) { newCount := h.activeClients.Add(1) h.logger.Debug("msg", "HTTP client connected", "remote_addr", remoteAddr, + "username", session.Username, + "auth_method", session.Method, "active_clients", newCount) h.wg.Add(1) @@ -356,6 +451,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) { newCount := h.activeClients.Add(-1) h.logger.Debug("msg", "HTTP client disconnected", "remote_addr", remoteAddr, + "username", session.Username, "active_clients", newCount) h.wg.Done() }() @@ -364,12 +460,15 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) { clientID := fmt.Sprintf("%d", time.Now().UnixNano()) connectionInfo := map[string]any{ "client_id": clientID, + "username": session.Username, + "auth_method": session.Method, "stream_path": h.streamPath, "status_path": h.statusPath, "buffer_size": h.config.BufferSize, + "tls": h.tlsManager != nil, } data, _ := json.Marshal(connectionInfo) - fmt.Fprintf(w, "event: connected\ndata: %s\n", data) + fmt.Fprintf(w, "event: connected\ndata: %s\n\n", data) w.Flush() var ticker *time.Ticker @@ -402,6 +501,13 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) { } case <-tickerChan: + // Validate session is still active + if h.authenticator != nil && !h.authenticator.ValidateSession(session.ID) { + fmt.Fprintf(w, "event: disconnect\ndata: {\"reason\":\"session_expired\"}\n\n") + w.Flush() + return + } + heartbeatEntry := h.createHeartbeatEntry() if err := h.formatEntryForSSE(w, heartbeatEntry); err != nil { h.logger.Error("msg", "Failed to format heartbeat", @@ -437,8 +543,10 @@ func (h *HTTPSink) formatEntryForSSE(w *bufio.Writer, entry core.LogEntry) error lines := bytes.Split(formatted, []byte{'\n'}) for _, line := range lines { // SSE needs "data: " prefix for each line + // TODO: validate above, is 'data: ' really necessary? make it optional if it works without it? fmt.Fprintf(w, "data: %s\n", line) } + fmt.Fprintf(w, "\n") // Empty line to terminate event return nil } @@ -478,6 +586,26 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) { } } + var authStats any + if h.authenticator != nil { + authStats = h.authenticator.GetStats() + authStats.(map[string]any)["failures"] = h.authFailures.Load() + authStats.(map[string]any)["successes"] = h.authSuccesses.Load() + } else { + authStats = map[string]any{ + "enabled": false, + } + } + + var tlsStats any + if h.tlsManager != nil { + tlsStats = h.tlsManager.GetStats() + } else { + tlsStats = map[string]any{ + "enabled": false, + } + } + status := map[string]any{ "service": "LogWisp", "version": version.Short(), @@ -487,7 +615,6 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) { "active_clients": h.activeClients.Load(), "buffer_size": h.config.BufferSize, "uptime_seconds": int(time.Since(h.startTime).Seconds()), - "mode": map[string]bool{"standalone": h.standalone, "router": !h.standalone}, }, "endpoints": map[string]string{ "transport": h.streamPath, @@ -499,11 +626,15 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) { "interval": h.config.Heartbeat.IntervalSeconds, "format": h.config.Heartbeat.Format, }, - "ssl": map[string]bool{ - "enabled": h.config.SSL != nil && h.config.SSL.Enabled, - }, + "tls": tlsStats, + "auth": authStats, "net_limit": netLimitStats, }, + "statistics": map[string]any{ + "total_processed": h.totalProcessed.Load(), + "auth_failures": h.authFailures.Load(), + "auth_successes": h.authSuccesses.Load(), + }, } data, _ := json.Marshal(status) @@ -523,4 +654,34 @@ func (h *HTTPSink) GetStreamPath() string { // GetStatusPath returns the configured status endpoint path func (h *HTTPSink) GetStatusPath() string { return h.statusPath +} + +func (h *HTTPSink) SetNetAccessConfig(cfg *config.NetAccessConfig) { + h.ipChecker = limit.NewIPChecker(cfg, h.logger) + if h.ipChecker != nil { + h.logger.Info("msg", "IP access control configured for HTTP sink", + "component", "http_sink") + } +} + +// SetAuthConfig configures http sink authentication +func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) { + if authCfg == nil || authCfg.Type == "none" { + return + } + + h.authConfig = authCfg + authenticator, err := auth.New(authCfg, h.logger) + if err != nil { + h.logger.Error("msg", "Failed to initialize authenticator for HTTP sink", + "component", "http_sink", + "error", err) + // Continue without auth + return + } + h.authenticator = authenticator + + h.logger.Info("msg", "Authentication configured for HTTP sink", + "component", "http_sink", + "auth_type", authCfg.Type) } \ No newline at end of file diff --git a/src/internal/sink/sink.go b/src/internal/sink/sink.go index cb431f5..7533d89 100644 --- a/src/internal/sink/sink.go +++ b/src/internal/sink/sink.go @@ -5,6 +5,7 @@ import ( "context" "time" + "logwisp/src/internal/config" "logwisp/src/internal/core" ) @@ -31,4 +32,14 @@ type SinkStats struct { StartTime time.Time LastProcessed time.Time Details map[string]any +} + +// NetAccessSetter is an interface for sinks that can accept network access configuration +type NetAccessSetter interface { + SetNetAccessConfig(cfg *config.NetAccessConfig) +} + +// AuthSetter is an interface for sinks that can accept an AuthConfig. +type AuthSetter interface { + SetAuthConfig(auth *config.AuthConfig) } \ No newline at end of file diff --git a/src/internal/sink/tcp.go b/src/internal/sink/tcp.go index e31fb0f..8808796 100644 --- a/src/internal/sink/tcp.go +++ b/src/internal/sink/tcp.go @@ -2,18 +2,22 @@ package sink import ( + "bytes" "context" "encoding/json" "fmt" "net" + "strings" "sync" "sync/atomic" "time" + "logwisp/src/internal/auth" "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/format" "logwisp/src/internal/limit" + "logwisp/src/internal/tls" "github.com/lixenwraith/log" "github.com/lixenwraith/log/compat" @@ -32,12 +36,20 @@ type TCPSink struct { engineMu sync.Mutex wg sync.WaitGroup netLimiter *limit.NetLimiter + ipChecker *limit.IPChecker logger *log.Logger formatter format.Formatter + // Security components + authenticator *auth.Authenticator + tlsManager *tls.Manager + authConfig *config.AuthConfig + // Statistics totalProcessed atomic.Uint64 lastProcessed atomic.Value // time.Time + authFailures atomic.Uint64 + authSuccesses atomic.Uint64 } // TCPConfig holds TCP sink configuration @@ -78,6 +90,32 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For } } + // Extract SSL config + if ssl, ok := options["ssl"].(map[string]any); ok { + cfg.SSL = &config.SSLConfig{} + cfg.SSL.Enabled, _ = ssl["enabled"].(bool) + if certFile, ok := ssl["cert_file"].(string); ok { + cfg.SSL.CertFile = certFile + } + if keyFile, ok := ssl["key_file"].(string); ok { + cfg.SSL.KeyFile = keyFile + } + cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool) + if caFile, ok := ssl["client_ca_file"].(string); ok { + cfg.SSL.ClientCAFile = caFile + } + cfg.SSL.VerifyClientCert, _ = ssl["verify_client_cert"].(bool) + if minVer, ok := ssl["min_version"].(string); ok { + cfg.SSL.MinVersion = minVer + } + if maxVer, ok := ssl["max_version"].(string); ok { + cfg.SSL.MaxVersion = maxVer + } + if ciphers, ok := ssl["cipher_suites"].(string); ok { + cfg.SSL.CipherSuites = ciphers + } + } + // Extract net limit config if rl, ok := options["net_limit"].(map[string]any); ok { cfg.NetLimit = &config.NetLimitConfig{} @@ -115,6 +153,7 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For } t.lastProcessed.Store(time.Time{}) + // Initialize net limiter if cfg.NetLimit != nil && cfg.NetLimit.Enabled { t.netLimiter = limit.NewNetLimiter(*cfg.NetLimit, logger) } @@ -127,7 +166,10 @@ func (t *TCPSink) Input() chan<- core.LogEntry { } func (t *TCPSink) Start(ctx context.Context) error { - t.server = &tcpServer{sink: t} + t.server = &tcpServer{ + sink: t, + clients: make(map[gnet.Conn]*tcpClient), + } // Start log broadcast loop t.wg.Add(1) @@ -136,24 +178,39 @@ func (t *TCPSink) Start(ctx context.Context) error { t.broadcastLoop(ctx) }() - // Configure gnet + // Configure gnet options addr := fmt.Sprintf("tcp://:%d", t.config.Port) // Create a gnet adapter using the existing logger instance gnetLogger := compat.NewGnetAdapter(t.logger) + var opts []gnet.Option + opts = append(opts, + gnet.WithLogger(gnetLogger), + gnet.WithMulticore(true), + gnet.WithReusePort(true), + ) + + // Add TLS if configured + if t.tlsManager != nil { + // tlsConfig := t.tlsManager.GetTCPConfig() + // TODO: tlsConfig is not used, wrapper to be implemented, non-TLS stream to be available without wrapper + // ☢ SECURITY: gnet doesn't support TLS natively - would need wrapper + // This is a limitation that requires implementing TLS at application layer + t.logger.Warn("msg", "TLS configured but gnet doesn't support native TLS", + "component", "tcp_sink", + "workaround", "Use stunnel or nginx TCP proxy for TLS termination") + } + // Start gnet server errChan := make(chan error, 1) go func() { t.logger.Info("msg", "Starting TCP server", "component", "tcp_sink", - "port", t.config.Port) + "port", t.config.Port, + "auth", t.authenticator != nil) - err := gnet.Run(t.server, addr, - gnet.WithLogger(gnetLogger), - gnet.WithMulticore(true), - gnet.WithReusePort(true), - ) + err := gnet.Run(t.server, addr, opts...) if err != nil { t.logger.Error("msg", "TCP server failed", "component", "tcp_sink", @@ -219,6 +276,18 @@ func (t *TCPSink) GetStats() SinkStats { netLimitStats = t.netLimiter.GetStats() } + var authStats map[string]any + if t.authenticator != nil { + authStats = t.authenticator.GetStats() + authStats["failures"] = t.authFailures.Load() + authStats["successes"] = t.authSuccesses.Load() + } + + var tlsStats map[string]any + if t.tlsManager != nil { + tlsStats = t.tlsManager.GetStats() + } + return SinkStats{ Type: "tcp", TotalProcessed: t.totalProcessed.Load(), @@ -229,6 +298,8 @@ func (t *TCPSink) GetStats() SinkStats { "port": t.config.Port, "buffer_size": t.config.BufferSize, "net_limit": netLimitStats, + "auth": authStats, + "tls": tlsStats, }, } } @@ -263,11 +334,14 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) { continue } - t.server.connections.Range(func(key, value any) bool { - conn := key.(gnet.Conn) - conn.AsyncWrite(data, nil) - return true - }) + // Broadcast only to authenticated clients + t.server.mu.RLock() + for conn, client := range t.server.clients { + if client.authenticated { + conn.AsyncWrite(data, nil) + } + } + t.server.mu.RUnlock() case <-tickerChan: heartbeatEntry := t.createHeartbeatEntry() @@ -279,11 +353,21 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) { continue } - t.server.connections.Range(func(key, value any) bool { - conn := key.(gnet.Conn) - conn.AsyncWrite(data, nil) - return true - }) + t.server.mu.RLock() + for conn, client := range t.server.clients { + if client.authenticated { + // Validate session is still active + if t.authenticator != nil && client.session != nil { + if !t.authenticator.ValidateSession(client.session.ID) { + // Session expired, close connection + conn.Close() + continue + } + } + conn.AsyncWrite(data, nil) + } + } + t.server.mu.RUnlock() case <-t.done: return @@ -320,11 +404,21 @@ func (t *TCPSink) GetActiveConnections() int64 { return t.activeConns.Load() } -// tcpServer handles gnet events +// tcpClient represents a connected TCP client with auth state +type tcpClient struct { + conn gnet.Conn + buffer bytes.Buffer + authenticated bool + session *auth.Session + authTimeout time.Time +} + +// tcpServer handles gnet events with authentication type tcpServer struct { gnet.BuiltinEventEngine - sink *TCPSink - connections sync.Map + sink *TCPSink + clients map[gnet.Conn]*tcpClient + mu sync.RWMutex } func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action { @@ -343,9 +437,17 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { remoteAddr := c.RemoteAddr().String() s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr) + // Check IP access control first + if s.sink.ipChecker != nil { + if !s.sink.ipChecker.IsAllowed(c.RemoteAddr()) { + s.sink.logger.Warn("msg", "TCP connection denied by IP filter", + "remote_addr", remoteAddr) + return nil, gnet.Close + } + } + // Check net limit if s.sink.netLimiter != nil { - // Parse the remote address to get proper net.Addr remoteStr := c.RemoteAddr().String() tcpAddr, err := net.ResolveTCPAddr("tcp", remoteStr) if err != nil { @@ -358,7 +460,6 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { if !s.sink.netLimiter.CheckTCP(tcpAddr) { s.sink.logger.Warn("msg", "TCP connection net limited", "remote_addr", remoteAddr) - // Silently close connection when net limited return nil, gnet.Close } @@ -366,24 +467,43 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { s.sink.netLimiter.AddConnection(remoteStr) } - s.connections.Store(c, struct{}{}) + // Create client state + client := &tcpClient{ + conn: c, + authenticated: s.sink.authenticator == nil, // No auth = auto authenticated + authTimeout: time.Now().Add(30 * time.Second), // 30s to authenticate + } + + s.mu.Lock() + s.clients[c] = client + s.mu.Unlock() newCount := s.sink.activeConns.Add(1) s.sink.logger.Debug("msg", "TCP connection opened", "remote_addr", remoteAddr, - "active_connections", newCount) + "active_connections", newCount, + "requires_auth", s.sink.authenticator != nil) + + // Send auth prompt if authentication is required + if s.sink.authenticator != nil { + authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH \nMethods: basic, token\n") + return authPrompt, gnet.None + } return nil, gnet.None } func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action { - s.connections.Delete(c) - remoteAddr := c.RemoteAddr().String() + // Remove client state + s.mu.Lock() + delete(s.clients, c) + s.mu.Unlock() + // Remove connection tracking if s.sink.netLimiter != nil { - s.sink.netLimiter.RemoveConnection(c.RemoteAddr().String()) + s.sink.netLimiter.RemoveConnection(remoteAddr) } newCount := s.sink.activeConns.Add(-1) @@ -395,7 +515,114 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action { } func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action { - // We don't expect input from clients, just discard + s.mu.RLock() + client, exists := s.clients[c] + s.mu.RUnlock() + + if !exists { + return gnet.Close + } + + // Check auth timeout + if !client.authenticated && time.Now().After(client.authTimeout) { + s.sink.logger.Warn("msg", "Authentication timeout", + "remote_addr", c.RemoteAddr().String()) + c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil) + return gnet.Close + } + + // Read all available data + data, err := c.Next(-1) + if err != nil { + s.sink.logger.Error("msg", "Error reading from connection", + "component", "tcp_sink", + "error", err) + return gnet.Close + } + + // If not authenticated, expect auth command + if !client.authenticated { + client.buffer.Write(data) + + // Look for complete auth line + if line, err := client.buffer.ReadBytes('\n'); err == nil { + line = bytes.TrimSpace(line) + + // Parse AUTH command: AUTH + parts := strings.SplitN(string(line), " ", 3) + if len(parts) != 3 || parts[0] != "AUTH" { + c.AsyncWrite([]byte("ERROR: Invalid auth format\n"), nil) + return gnet.None + } + + // Authenticate + session, err := s.sink.authenticator.AuthenticateTCP(parts[1], parts[2], c.RemoteAddr().String()) + if err != nil { + s.sink.authFailures.Add(1) + s.sink.logger.Warn("msg", "TCP authentication failed", + "remote_addr", c.RemoteAddr().String(), + "method", parts[1], + "error", err) + c.AsyncWrite([]byte(fmt.Sprintf("AUTH FAILED: %v\n", err)), nil) + return gnet.Close + } + + // Authentication successful + s.sink.authSuccesses.Add(1) + s.mu.Lock() + client.authenticated = true + client.session = session + s.mu.Unlock() + + s.sink.logger.Info("msg", "TCP client authenticated", + "remote_addr", c.RemoteAddr().String(), + "username", session.Username, + "method", session.Method) + + c.AsyncWrite([]byte("AUTH OK\n"), nil) + + // Clear buffer after auth + client.buffer.Reset() + } + return gnet.None + } + + // Authenticated clients shouldn't send data, just discard c.Discard(-1) return gnet.None +} + +// SetAuthConfig configures tcp sink authentication +func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) { + if authCfg == nil || authCfg.Type == "none" { + return + } + + t.authConfig = authCfg + authenticator, err := auth.New(authCfg, t.logger) + if err != nil { + t.logger.Error("msg", "Failed to initialize authenticator for TCP sink", + "component", "tcp_sink", + "error", err) + return + } + t.authenticator = authenticator + + // Initialize TLS manager if SSL is configured + if t.config.SSL != nil && t.config.SSL.Enabled { + tlsManager, err := tls.New(t.config.SSL, t.logger) + if err != nil { + t.logger.Error("msg", "Failed to create TLS manager", + "component", "tcp_sink", + "error", err) + // Continue without TLS + return + } + t.tlsManager = tlsManager + } + + t.logger.Info("msg", "Authentication configured for TCP sink", + "component", "tcp_sink", + "auth_type", authCfg.Type, + "tls_enabled", t.tlsManager != nil) } \ No newline at end of file diff --git a/src/internal/tls/manager.go b/src/internal/tls/manager.go new file mode 100644 index 0000000..b6067d0 --- /dev/null +++ b/src/internal/tls/manager.go @@ -0,0 +1,249 @@ +// FILE: logwisp/src/internal/tls/manager.go +package tls + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "strings" + + "logwisp/src/internal/config" + + "github.com/lixenwraith/log" +) + +// Manager handles TLS configuration for servers +type Manager struct { + config *config.SSLConfig + tlsConfig *tls.Config + logger *log.Logger +} + +// New creates a TLS configuration from SSL config +func New(cfg *config.SSLConfig, logger *log.Logger) (*Manager, error) { + if cfg == nil || !cfg.Enabled { + return nil, nil + } + + m := &Manager{ + config: cfg, + logger: logger, + } + + // Load certificate and key + cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) + if err != nil { + return nil, fmt.Errorf("failed to load cert/key: %w", err) + } + + // Create base TLS config + m.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: parseTLSVersion(cfg.MinVersion, tls.VersionTLS12), + MaxVersion: parseTLSVersion(cfg.MaxVersion, tls.VersionTLS13), + } + + // Configure cipher suites if specified + if cfg.CipherSuites != "" { + m.tlsConfig.CipherSuites = parseCipherSuites(cfg.CipherSuites) + } else { + // Use secure defaults + m.tlsConfig.CipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + } + } + + // Configure client authentication (mTLS) + if cfg.ClientAuth { + if cfg.VerifyClientCert { + m.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } else { + m.tlsConfig.ClientAuth = tls.RequireAnyClientCert + } + + // Load client CA if specified + if cfg.ClientCAFile != "" { + caCert, err := os.ReadFile(cfg.ClientCAFile) + if err != nil { + return nil, fmt.Errorf("failed to read client CA: %w", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse client CA certificate") + } + m.tlsConfig.ClientCAs = caCertPool + } + } + + // Set secure defaults + m.tlsConfig.PreferServerCipherSuites = true + m.tlsConfig.SessionTicketsDisabled = false + m.tlsConfig.Renegotiation = tls.RenegotiateNever + + logger.Info("msg", "TLS manager initialized", + "component", "tls", + "min_version", cfg.MinVersion, + "max_version", cfg.MaxVersion, + "client_auth", cfg.ClientAuth, + "cipher_count", len(m.tlsConfig.CipherSuites)) + + return m, nil +} + +// GetConfig returns the TLS configuration +func (m *Manager) GetConfig() *tls.Config { + if m == nil { + return nil + } + // Return a clone to prevent modification + return m.tlsConfig.Clone() +} + +// GetHTTPConfig returns TLS config suitable for HTTP servers +func (m *Manager) GetHTTPConfig() *tls.Config { + if m == nil { + return nil + } + + cfg := m.tlsConfig.Clone() + // Enable HTTP/2 + cfg.NextProtos = []string{"h2", "http/1.1"} + return cfg +} + +// GetTCPConfig returns TLS config for raw TCP connections +func (m *Manager) GetTCPConfig() *tls.Config { + if m == nil { + return nil + } + + cfg := m.tlsConfig.Clone() + // No ALPN for raw TCP + cfg.NextProtos = nil + return cfg +} + +// ValidateClientCert validates a client certificate for mTLS +func (m *Manager) ValidateClientCert(rawCerts [][]byte) error { + if m == nil || !m.config.ClientAuth { + return nil + } + + if len(rawCerts) == 0 { + return fmt.Errorf("no client certificate provided") + } + + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return fmt.Errorf("failed to parse client certificate: %w", err) + } + + // Verify against CA if configured + if m.tlsConfig.ClientCAs != nil { + opts := x509.VerifyOptions{ + Roots: m.tlsConfig.ClientCAs, + Intermediates: x509.NewCertPool(), + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + // Add any intermediate certs + for i := 1; i < len(rawCerts); i++ { + intermediate, err := x509.ParseCertificate(rawCerts[i]) + if err != nil { + continue + } + opts.Intermediates.AddCert(intermediate) + } + + if _, err := cert.Verify(opts); err != nil { + return fmt.Errorf("client certificate verification failed: %w", err) + } + } + + m.logger.Debug("msg", "Client certificate validated", + "component", "tls", + "subject", cert.Subject.String(), + "serial", cert.SerialNumber.String()) + + return nil +} + +func parseTLSVersion(version string, defaultVersion uint16) uint16 { + switch strings.ToUpper(version) { + case "TLS1.0", "TLS10": + return tls.VersionTLS10 + case "TLS1.1", "TLS11": + return tls.VersionTLS11 + case "TLS1.2", "TLS12": + return tls.VersionTLS12 + case "TLS1.3", "TLS13": + return tls.VersionTLS13 + default: + return defaultVersion + } +} + +func parseCipherSuites(suites string) []uint16 { + var result []uint16 + + // Map of cipher suite names to IDs + suiteMap := map[string]uint16{ + // TLS 1.2 ECDHE suites (preferred) + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + + // RSA suites (less preferred) + "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + "TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + } + + for _, suite := range strings.Split(suites, ",") { + suite = strings.TrimSpace(suite) + if id, ok := suiteMap[suite]; ok { + result = append(result, id) + } + } + + return result +} + +// GetStats returns TLS statistics +func (m *Manager) GetStats() map[string]any { + if m == nil { + return map[string]any{"enabled": false} + } + + return map[string]any{ + "enabled": true, + "min_version": tlsVersionString(m.tlsConfig.MinVersion), + "max_version": tlsVersionString(m.tlsConfig.MaxVersion), + "client_auth": m.config.ClientAuth, + "cipher_suites": len(m.tlsConfig.CipherSuites), + } +} + +func tlsVersionString(version uint16) string { + switch version { + case tls.VersionTLS10: + return "TLS1.0" + case tls.VersionTLS11: + return "TLS1.1" + case tls.VersionTLS12: + return "TLS1.2" + case tls.VersionTLS13: + return "TLS1.3" + default: + return fmt.Sprintf("0x%04x", version) + } +} \ No newline at end of file