v0.4.0 authentication added and router mode removed

This commit is contained in:
2025-09-06 06:28:56 -04:00
parent ea75c4afed
commit 4248d399b3
26 changed files with 1527 additions and 620 deletions

1
.gitignore vendored
View File

@ -7,5 +7,6 @@ cert
bin bin
script script
build build
test
*.log *.log
*.toml *.toml

View File

@ -18,7 +18,6 @@ Examples:
```bash ```bash
LOGWISP_CONFIG_FILE=/etc/logwisp/config.toml LOGWISP_CONFIG_FILE=/etc/logwisp/config.toml
LOGWISP_CONFIG_DIR=/etc/logwisp LOGWISP_CONFIG_DIR=/etc/logwisp
LOGWISP_ROUTER=true
LOGWISP_BACKGROUND=true LOGWISP_BACKGROUND=true
LOGWISP_QUIET=true LOGWISP_QUIET=true
LOGWISP_DISABLE_STATUS_REPORTER=true LOGWISP_DISABLE_STATUS_REPORTER=true
@ -221,7 +220,6 @@ LOGWISP_PIPELINES_0_SINKS_0_OPTIONS_TARGET=stdout
#!/usr/bin/env bash #!/usr/bin/env bash
# General settings # General settings
export LOGWISP_ROUTER=true
export LOGWISP_DISABLE_STATUS_REPORTER=false export LOGWISP_DISABLE_STATUS_REPORTER=false
# Logging # Logging

13
go.mod
View File

@ -1,12 +1,15 @@
module logwisp module logwisp
go 1.24.5 go 1.25.1
require ( require (
github.com/lixenwraith/config v0.0.0-20250901201021-59a461e31cd4 github.com/golang-jwt/jwt/v5 v5.3.0
github.com/lixenwraith/log v0.0.0-20250722012845-16a3079e46e2 github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3
github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208
github.com/panjf2000/gnet/v2 v2.9.3 github.com/panjf2000/gnet/v2 v2.9.3
github.com/valyala/fasthttp v1.65.0 github.com/valyala/fasthttp v1.65.0
golang.org/x/crypto v0.42.0
golang.org/x/term v0.35.0
) )
require ( require (
@ -19,8 +22,8 @@ require (
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect go.uber.org/zap v1.27.0 // indirect
golang.org/x/sync v0.16.0 // indirect golang.org/x/sync v0.17.0 // indirect
golang.org/x/sys v0.35.0 // indirect golang.org/x/sys v0.36.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

22
go.sum
View File

@ -6,12 +6,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGLxofobw= github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGLxofobw=
github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw= github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/lixenwraith/config v0.0.0-20250901201021-59a461e31cd4 h1:SxqXt6J7ZLA39SP4zvJU0Jv3GbXLzM5iB7cgk5d7Pe4= github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0=
github.com/lixenwraith/config v0.0.0-20250901201021-59a461e31cd4/go.mod h1:l+1PZ8JsohLAXOJKu5loFa+zCdOSb/lXf3JUwa5ST/4= github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0=
github.com/lixenwraith/log v0.0.0-20250722012845-16a3079e46e2 h1:nP/12l+gKkZnZRoM3Vy4iT2anBQm1jCtrppyZq9pcq4= github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208 h1:IB1O/HLv9VR/4mL1Tkjlr91lk+r8anP6bab7rYdS/oE=
github.com/lixenwraith/log v0.0.0-20250722012845-16a3079e46e2/go.mod h1:sLCRfKeLInCj2LcMnAo2knULwfszU8QPuIFOQ8crcFo= github.com/lixenwraith/log v0.0.0-20250908085352-2df52dfb9208/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0=
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
github.com/panjf2000/ants/v2 v2.11.3/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek= github.com/panjf2000/ants/v2 v2.11.3/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek=
github.com/panjf2000/gnet/v2 v2.9.3 h1:auV3/A9Na3jiBDmYAAU00rPhFKnsAI+TnI1F7YUJMHQ= github.com/panjf2000/gnet/v2 v2.9.3 h1:auV3/A9Na3jiBDmYAAU00rPhFKnsAI+TnI1F7YUJMHQ=
@ -32,10 +34,14 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=

110
src/cmd/auth-gen/main.go Normal file
View File

@ -0,0 +1,110 @@
// FILE: logwisp/src/cmd/auth-gen/main.go
package main
import (
"crypto/rand"
"encoding/base64"
"flag"
"fmt"
"os"
"syscall"
"golang.org/x/crypto/bcrypt"
"golang.org/x/term"
)
func main() {
var (
username = flag.String("u", "", "Username for basic auth")
password = flag.String("p", "", "Password to hash (will prompt if not provided)")
cost = flag.Int("c", 10, "Bcrypt cost (10-31)")
genToken = flag.Bool("t", false, "Generate random bearer token")
tokenLen = flag.Int("l", 32, "Token length in bytes")
)
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "LogWisp Authentication Utility\n\n")
fmt.Fprintf(os.Stderr, "Usage:\n")
fmt.Fprintf(os.Stderr, " Generate bcrypt hash: %s -u <username> [-p <password>]\n", os.Args[0])
fmt.Fprintf(os.Stderr, " Generate bearer token: %s -t [-l <length>]\n", os.Args[0])
fmt.Fprintf(os.Stderr, "\nOptions:\n")
flag.PrintDefaults()
}
flag.Parse()
if *genToken {
generateToken(*tokenLen)
return
}
if *username == "" {
fmt.Fprintf(os.Stderr, "Error: Username required for basic auth\n")
flag.Usage()
os.Exit(1)
}
// Get password
pass := *password
if pass == "" {
pass = promptPassword("Enter password: ")
confirm := promptPassword("Confirm password: ")
if pass != confirm {
fmt.Fprintf(os.Stderr, "Error: Passwords don't match\n")
os.Exit(1)
}
}
// Generate bcrypt hash
hash, err := bcrypt.GenerateFromPassword([]byte(pass), *cost)
if err != nil {
fmt.Fprintf(os.Stderr, "Error generating hash: %v\n", err)
os.Exit(1)
}
// Output TOML config format
fmt.Println("\n# Add to logwisp.toml under [[pipelines.auth.basic_auth.users]]:")
fmt.Printf("[[pipelines.auth.basic_auth.users]]\n")
fmt.Printf("username = \"%s\"\n", *username)
fmt.Printf("password_hash = \"%s\"\n", string(hash))
// Also output for users file format
fmt.Println("\n# Or add to users file:")
fmt.Printf("%s:%s\n", *username, string(hash))
}
func promptPassword(prompt string) string {
fmt.Fprint(os.Stderr, prompt)
password, err := term.ReadPassword(int(syscall.Stdin))
fmt.Fprintln(os.Stderr)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading password: %v\n", err)
os.Exit(1)
}
return string(password)
}
func generateToken(length int) {
if length < 16 {
fmt.Fprintf(os.Stderr, "Warning: Token length < 16 bytes is insecure\n")
}
token := make([]byte, length)
if _, err := rand.Read(token); err != nil {
fmt.Fprintf(os.Stderr, "Error generating token: %v\n", err)
os.Exit(1)
}
// Output in various formats
b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token)
hex := fmt.Sprintf("%x", token)
fmt.Println("\n# Add to logwisp.toml under [pipelines.auth.bearer_auth]:")
fmt.Printf("tokens = [\"%s\"]\n", b64)
fmt.Println("\n# Alternative hex encoding:")
fmt.Printf("# tokens = [\"%s\"]\n", hex)
fmt.Printf("\n# Token (base64): %s\n", b64)
fmt.Printf("# Token (hex): %s\n", hex)
}

View File

@ -14,17 +14,10 @@ import (
) )
// bootstrapService creates and initializes the log transport service // bootstrapService creates and initializes the log transport service
func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service, *service.HTTPRouter, error) { func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service, error) {
// Create service with logger dependency injection // Create service with logger dependency injection
svc := service.New(ctx, logger) svc := service.New(ctx, logger)
// Create HTTP router if requested
var router *service.HTTPRouter
if cfg.UseRouter {
router = service.NewHTTPRouter(svc, logger)
logger.Info("msg", "HTTP router mode enabled")
}
// Initialize pipelines // Initialize pipelines
successCount := 0 successCount := 0
for _, pipelineCfg := range cfg.Pipelines { for _, pipelineCfg := range cfg.Pipelines {
@ -37,32 +30,19 @@ func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service
"error", err) "error", err)
continue continue
} }
// If using router mode, register HTTP sinks
if cfg.UseRouter {
pipeline, err := svc.GetPipeline(pipelineCfg.Name)
if err == nil && len(pipeline.HTTPSinks) > 0 {
if err := router.RegisterPipeline(pipeline); err != nil {
logger.Error("msg", "Failed to register pipeline with router",
"pipeline", pipelineCfg.Name,
"error", err)
}
}
}
successCount++ successCount++
displayPipelineEndpoints(pipelineCfg, cfg.UseRouter) displayPipelineEndpoints(pipelineCfg)
} }
if successCount == 0 { if successCount == 0 {
return nil, nil, fmt.Errorf("no pipelines successfully started (attempted %d)", len(cfg.Pipelines)) return nil, fmt.Errorf("no pipelines successfully started (attempted %d)", len(cfg.Pipelines))
} }
logger.Info("msg", "LogWisp started", logger.Info("msg", "LogWisp started",
"version", version.Short(), "version", version.Short(),
"pipelines", successCount) "pipelines", successCount)
return svc, router, nil return svc, nil
} }
// initializeLogger sets up the logger based on configuration // initializeLogger sets up the logger based on configuration

View File

@ -16,7 +16,6 @@ Application Control:
-v, --version Display version information and exit. -v, --version Display version information and exit.
-b, --background Run LogWisp in the background as a daemon. -b, --background Run LogWisp in the background as a daemon.
-q, --quiet Suppress all console output, including errors. -q, --quiet Suppress all console output, including errors.
--router Enable HTTP router mode for multiplexing pipelines.
Runtime Behavior: Runtime Behavior:
--disable-status-reporter Disable the periodic status reporter. --disable-status-reporter Disable the periodic status reporter.
@ -24,7 +23,7 @@ Runtime Behavior:
Configuration Sources (Precedence: CLI > Env > File > Defaults): Configuration Sources (Precedence: CLI > Env > File > Defaults):
- CLI flags override all other settings. - CLI flags override all other settings.
- Environment variables (e.g., LOGWISP_ROUTER=true) override file settings. - Environment variables override file settings.
- TOML configuration file is the primary method for defining pipelines. - TOML configuration file is the primary method for defining pipelines.
Logging ([logging] section or LOGWISP_LOGGING_* env vars): Logging ([logging] section or LOGWISP_LOGGING_* env vars):

View File

@ -77,7 +77,6 @@ func main() {
"version", version.String(), "version", version.String(),
"config_file", cfg.ConfigFile, "config_file", cfg.ConfigFile,
"log_output", cfg.Logging.Output, "log_output", cfg.Logging.Output,
"router_mode", cfg.UseRouter,
"background_mode", cfg.Background) "background_mode", cfg.Background)
// Create context for shutdown // Create context for shutdown
@ -117,7 +116,7 @@ func main() {
// Traditional static bootstrap // Traditional static bootstrap
logger.Info("msg", "Config auto-reload disabled") logger.Info("msg", "Config auto-reload disabled")
svc, router, err := bootstrapService(ctx, cfg) svc, err := bootstrapService(ctx, cfg)
if err != nil { if err != nil {
logger.Error("msg", "Failed to bootstrap service", "error", err) logger.Error("msg", "Failed to bootstrap service", "error", err)
os.Exit(1) os.Exit(1)
@ -142,12 +141,6 @@ func main() {
logger.Info("msg", "Shutdown signal received, starting graceful shutdown...") logger.Info("msg", "Shutdown signal received, starting graceful shutdown...")
// Shutdown router first if using it
if router != nil {
logger.Info("msg", "Shutting down HTTP router...")
router.Shutdown()
}
// Shutdown service with timeout // Shutdown service with timeout
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel() defer shutdownCancel()

View File

@ -19,7 +19,6 @@ import (
type ReloadManager struct { type ReloadManager struct {
configPath string configPath string
service *service.Service service *service.Service
router *service.HTTPRouter
cfg *config.Config cfg *config.Config
lcfg *lconfig.Config lcfg *lconfig.Config
logger *log.Logger logger *log.Logger
@ -47,14 +46,13 @@ func NewReloadManager(configPath string, initialCfg *config.Config, logger *log.
// Start begins watching for configuration changes // Start begins watching for configuration changes
func (rm *ReloadManager) Start(ctx context.Context) error { func (rm *ReloadManager) Start(ctx context.Context) error {
// Bootstrap initial service // Bootstrap initial service
svc, router, err := bootstrapService(ctx, rm.cfg) svc, err := bootstrapService(ctx, rm.cfg)
if err != nil { if err != nil {
return fmt.Errorf("failed to bootstrap initial service: %w", err) return fmt.Errorf("failed to bootstrap initial service: %w", err)
} }
rm.mu.Lock() rm.mu.Lock()
rm.service = svc rm.service = svc
rm.router = router
rm.mu.Unlock() rm.mu.Unlock()
// Start status reporter for initial service // Start status reporter for initial service
@ -149,11 +147,6 @@ func (rm *ReloadManager) shouldReload(path string) bool {
return true return true
} }
// Router mode changes require reload
if path == "router" || path == "use_router" {
return true
}
// Logging changes don't require service reload // Logging changes don't require service reload
if strings.HasPrefix(path, "logging.") { if strings.HasPrefix(path, "logging.") {
return false return false
@ -214,12 +207,11 @@ func (rm *ReloadManager) performReload(ctx context.Context) error {
// Get current service snapshot // Get current service snapshot
rm.mu.RLock() rm.mu.RLock()
oldService := rm.service oldService := rm.service
oldRouter := rm.router
rm.mu.RUnlock() rm.mu.RUnlock()
// Try to bootstrap with new configuration // Try to bootstrap with new configuration
rm.logger.Debug("msg", "Bootstrapping new service with updated config") rm.logger.Debug("msg", "Bootstrapping new service with updated config")
newService, newRouter, err := bootstrapService(ctx, newCfg) newService, err := bootstrapService(ctx, newCfg)
if err != nil { if err != nil {
// Bootstrap failed - keep old services running // Bootstrap failed - keep old services running
return fmt.Errorf("failed to bootstrap new service (old service still active): %w", err) return fmt.Errorf("failed to bootstrap new service (old service still active): %w", err)
@ -228,7 +220,6 @@ func (rm *ReloadManager) performReload(ctx context.Context) error {
// Bootstrap succeeded - swap services atomically // Bootstrap succeeded - swap services atomically
rm.mu.Lock() rm.mu.Lock()
rm.service = newService rm.service = newService
rm.router = newRouter
rm.cfg = newCfg rm.cfg = newCfg
rm.mu.Unlock() rm.mu.Unlock()
@ -237,22 +228,17 @@ func (rm *ReloadManager) performReload(ctx context.Context) error {
// Gracefully shutdown old services // Gracefully shutdown old services
// This happens after the swap to minimize downtime // This happens after the swap to minimize downtime
go rm.shutdownOldServices(oldRouter, oldService) go rm.shutdownOldServices(oldService)
return nil return nil
} }
// shutdownOldServices gracefully shuts down old services // shutdownOldServices gracefully shuts down old services
func (rm *ReloadManager) shutdownOldServices(router *service.HTTPRouter, svc *service.Service) { func (rm *ReloadManager) shutdownOldServices(svc *service.Service) {
// Give connections time to drain // Give connections time to drain
rm.logger.Debug("msg", "Draining connections from old services") rm.logger.Debug("msg", "Draining connections from old services")
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
if router != nil {
rm.logger.Info("msg", "Shutting down old router")
router.Shutdown()
}
if svc != nil { if svc != nil {
rm.logger.Info("msg", "Shutting down old service") rm.logger.Info("msg", "Shutting down old service")
svc.Shutdown() svc.Shutdown()
@ -337,15 +323,9 @@ func (rm *ReloadManager) Shutdown() {
// Shutdown current services // Shutdown current services
rm.mu.RLock() rm.mu.RLock()
currentRouter := rm.router
currentService := rm.service currentService := rm.service
rm.mu.RUnlock() rm.mu.RUnlock()
if currentRouter != nil {
rm.logger.Info("msg", "Shutting down router")
currentRouter.Shutdown()
}
if currentService != nil { if currentService != nil {
rm.logger.Info("msg", "Shutting down service") rm.logger.Info("msg", "Shutting down service")
currentService.Shutdown() currentService.Shutdown()
@ -358,10 +338,3 @@ func (rm *ReloadManager) GetService() *service.Service {
defer rm.mu.RUnlock() defer rm.mu.RUnlock()
return rm.service return rm.service
} }
// GetRouter returns the current router (thread-safe)
func (rm *ReloadManager) GetRouter() *service.HTTPRouter {
rm.mu.RLock()
defer rm.mu.RUnlock()
return rm.router
}

View File

@ -109,7 +109,7 @@ func logPipelineStatus(name string, stats map[string]any) {
} }
// displayPipelineEndpoints logs the configured endpoints for a pipeline // displayPipelineEndpoints logs the configured endpoints for a pipeline
func displayPipelineEndpoints(cfg config.PipelineConfig, routerMode bool) { func displayPipelineEndpoints(cfg config.PipelineConfig) {
// Display sink endpoints // Display sink endpoints
for i, sinkCfg := range cfg.Sinks { for i, sinkCfg := range cfg.Sinks {
switch sinkCfg.Type { switch sinkCfg.Type {
@ -144,19 +144,11 @@ func displayPipelineEndpoints(cfg config.PipelineConfig, routerMode bool) {
statusPath = path statusPath = path
} }
if routerMode { logger.Info("msg", "HTTP endpoints configured",
logger.Info("msg", "HTTP endpoints configured", "pipeline", cfg.Name,
"pipeline", cfg.Name, "sink_index", i,
"sink_index", i, "stream_url", fmt.Sprintf("http://localhost:%d%s", port, streamPath),
"stream_path", fmt.Sprintf("/%s%s", cfg.Name, streamPath), "status_url", fmt.Sprintf("http://localhost:%d%s", port, statusPath))
"status_path", fmt.Sprintf("/%s%s", cfg.Name, statusPath))
} else {
logger.Info("msg", "HTTP endpoints configured",
"pipeline", cfg.Name,
"sink_index", i,
"stream_url", fmt.Sprintf("http://localhost:%d%s", port, streamPath),
"status_url", fmt.Sprintf("http://localhost:%d%s", port, statusPath))
}
// Display net limit info if configured // Display net limit info if configured
if rl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok { if rl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {

View File

@ -0,0 +1,411 @@
// FILE: logwisp/src/internal/auth/authenticator.go
package auth
import (
"bufio"
"encoding/base64"
"fmt"
"os"
"strings"
"sync"
"time"
"logwisp/src/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/lixenwraith/log"
"golang.org/x/crypto/bcrypt"
)
// Authenticator handles all authentication methods for a pipeline
type Authenticator struct {
config *config.AuthConfig
logger *log.Logger
basicUsers map[string]string // username -> password hash
bearerTokens map[string]bool // token -> valid
jwtParser *jwt.Parser
jwtKeyFunc jwt.Keyfunc
mu sync.RWMutex
// Session tracking
sessions map[string]*Session
sessionMu sync.RWMutex
}
// Session represents an authenticated connection
type Session struct {
ID string
Username string
Method string // basic, bearer, jwt, mtls
RemoteAddr string
CreatedAt time.Time
LastActivity time.Time
Metadata map[string]any
}
// New creates a new authenticator from config
func New(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) {
if cfg == nil || cfg.Type == "none" {
return nil, nil
}
a := &Authenticator{
config: cfg,
logger: logger,
basicUsers: make(map[string]string),
bearerTokens: make(map[string]bool),
sessions: make(map[string]*Session),
}
// Initialize Basic Auth users
if cfg.Type == "basic" && cfg.BasicAuth != nil {
for _, user := range cfg.BasicAuth.Users {
a.basicUsers[user.Username] = user.PasswordHash
}
// Load users from file if specified
if cfg.BasicAuth.UsersFile != "" {
if err := a.loadUsersFile(cfg.BasicAuth.UsersFile); err != nil {
return nil, fmt.Errorf("failed to load users file: %w", err)
}
}
}
// Initialize Bearer tokens
if cfg.Type == "bearer" && cfg.BearerAuth != nil {
for _, token := range cfg.BearerAuth.Tokens {
a.bearerTokens[token] = true
}
// Setup JWT validation if configured
if cfg.BearerAuth.JWT != nil {
a.jwtParser = jwt.NewParser(
jwt.WithValidMethods([]string{"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}),
jwt.WithLeeway(5*time.Second),
)
// Setup key function
if cfg.BearerAuth.JWT.SigningKey != "" {
// Static key
key := []byte(cfg.BearerAuth.JWT.SigningKey)
a.jwtKeyFunc = func(token *jwt.Token) (interface{}, error) {
return key, nil
}
} else if cfg.BearerAuth.JWT.JWKSURL != "" {
// JWKS support would require additional implementation
// ☢ SECURITY: JWKS rotation not implemented - tokens won't refresh keys
return nil, fmt.Errorf("JWKS support not yet implemented")
}
}
}
// Start session cleanup
go a.sessionCleanup()
logger.Info("msg", "Authenticator initialized",
"component", "auth",
"type", cfg.Type)
return a, nil
}
// AuthenticateHTTP handles HTTP authentication headers
func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) {
if a == nil || a.config.Type == "none" {
return &Session{
ID: generateSessionID(),
Method: "none",
RemoteAddr: remoteAddr,
CreatedAt: time.Now(),
}, nil
}
switch a.config.Type {
case "basic":
return a.authenticateBasic(authHeader, remoteAddr)
case "bearer":
return a.authenticateBearer(authHeader, remoteAddr)
default:
return nil, fmt.Errorf("unsupported auth type: %s", a.config.Type)
}
}
// AuthenticateTCP handles TCP connection authentication
func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string) (*Session, error) {
if a == nil || a.config.Type == "none" {
return &Session{
ID: generateSessionID(),
Method: "none",
RemoteAddr: remoteAddr,
CreatedAt: time.Now(),
}, nil
}
// TCP auth protocol: AUTH <method> <credentials>
switch strings.ToLower(method) {
case "token":
if a.config.Type != "bearer" {
return nil, fmt.Errorf("token auth not configured")
}
return a.validateToken(credentials, remoteAddr)
case "basic":
if a.config.Type != "basic" {
return nil, fmt.Errorf("basic auth not configured")
}
// Expect base64(username:password)
decoded, err := base64.StdEncoding.DecodeString(credentials)
if err != nil {
return nil, fmt.Errorf("invalid credentials encoding")
}
parts := strings.SplitN(string(decoded), ":", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid credentials format")
}
return a.validateBasicAuth(parts[0], parts[1], remoteAddr)
default:
return nil, fmt.Errorf("unsupported auth method: %s", method)
}
}
func (a *Authenticator) authenticateBasic(authHeader, remoteAddr string) (*Session, error) {
if !strings.HasPrefix(authHeader, "Basic ") {
return nil, fmt.Errorf("invalid basic auth header")
}
payload, err := base64.StdEncoding.DecodeString(authHeader[6:])
if err != nil {
return nil, fmt.Errorf("invalid base64 encoding")
}
parts := strings.SplitN(string(payload), ":", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid credentials format")
}
return a.validateBasicAuth(parts[0], parts[1], remoteAddr)
}
func (a *Authenticator) validateBasicAuth(username, password, remoteAddr string) (*Session, error) {
a.mu.RLock()
expectedHash, exists := a.basicUsers[username]
a.mu.RUnlock()
if !exists {
// ☢ SECURITY: Perform bcrypt anyway to prevent timing attacks
bcrypt.CompareHashAndPassword([]byte("$2a$10$dummy.hash.to.prevent.timing.attacks"), []byte(password))
return nil, fmt.Errorf("invalid credentials")
}
if err := bcrypt.CompareHashAndPassword([]byte(expectedHash), []byte(password)); err != nil {
return nil, fmt.Errorf("invalid credentials")
}
session := &Session{
ID: generateSessionID(),
Username: username,
Method: "basic",
RemoteAddr: remoteAddr,
CreatedAt: time.Now(),
LastActivity: time.Now(),
}
a.storeSession(session)
return session, nil
}
func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) {
if !strings.HasPrefix(authHeader, "Bearer ") {
return nil, fmt.Errorf("invalid bearer auth header")
}
token := authHeader[7:]
return a.validateToken(token, remoteAddr)
}
func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error) {
// Check static tokens first
a.mu.RLock()
isStatic := a.bearerTokens[token]
a.mu.RUnlock()
if isStatic {
session := &Session{
ID: generateSessionID(),
Method: "bearer",
RemoteAddr: remoteAddr,
CreatedAt: time.Now(),
LastActivity: time.Now(),
Metadata: map[string]any{"token_type": "static"},
}
a.storeSession(session)
return session, nil
}
// Try JWT validation if configured
if a.jwtParser != nil && a.jwtKeyFunc != nil {
claims := jwt.MapClaims{}
parsedToken, err := a.jwtParser.ParseWithClaims(token, claims, a.jwtKeyFunc)
if err != nil {
return nil, fmt.Errorf("JWT validation failed: %w", err)
}
if !parsedToken.Valid {
return nil, fmt.Errorf("invalid JWT token")
}
// Check issuer if configured
if a.config.BearerAuth.JWT.Issuer != "" {
if iss, ok := claims["iss"].(string); !ok || iss != a.config.BearerAuth.JWT.Issuer {
return nil, fmt.Errorf("invalid token issuer")
}
}
// Check audience if configured
if a.config.BearerAuth.JWT.Audience != "" {
if aud, ok := claims["aud"].(string); !ok || aud != a.config.BearerAuth.JWT.Audience {
return nil, fmt.Errorf("invalid token audience")
}
}
username := ""
if sub, ok := claims["sub"].(string); ok {
username = sub
}
session := &Session{
ID: generateSessionID(),
Username: username,
Method: "jwt",
RemoteAddr: remoteAddr,
CreatedAt: time.Now(),
LastActivity: time.Now(),
Metadata: map[string]any{"claims": claims},
}
a.storeSession(session)
return session, nil
}
return nil, fmt.Errorf("invalid token")
}
func (a *Authenticator) storeSession(session *Session) {
a.sessionMu.Lock()
a.sessions[session.ID] = session
a.sessionMu.Unlock()
a.logger.Info("msg", "Session created",
"component", "auth",
"session_id", session.ID,
"username", session.Username,
"method", session.Method,
"remote_addr", session.RemoteAddr)
}
func (a *Authenticator) sessionCleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
a.sessionMu.Lock()
now := time.Now()
for id, session := range a.sessions {
if now.Sub(session.LastActivity) > 30*time.Minute {
delete(a.sessions, id)
a.logger.Debug("msg", "Session expired",
"component", "auth",
"session_id", id)
}
}
a.sessionMu.Unlock()
}
}
func (a *Authenticator) loadUsersFile(path string) error {
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("could not open users file: %w", err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
lineNumber := 0
for scanner.Scan() {
lineNumber++
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue // Skip empty lines and comments
}
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
a.logger.Warn("msg", "Skipping malformed line in users file",
"component", "auth",
"path", path,
"line_number", lineNumber)
continue
}
username, hash := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
if username != "" && hash != "" {
// File-based users can overwrite inline users if names conflict
a.basicUsers[username] = hash
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("error reading users file: %w", err)
}
a.logger.Info("msg", "Loaded users from file",
"component", "auth",
"path", path,
"user_count", len(a.basicUsers))
return nil
}
func generateSessionID() string {
return fmt.Sprintf("%d-%d", time.Now().UnixNano(), time.Now().Unix())
}
// ValidateSession checks if a session is still valid
func (a *Authenticator) ValidateSession(sessionID string) bool {
if a == nil {
return true
}
a.sessionMu.RLock()
session, exists := a.sessions[sessionID]
a.sessionMu.RUnlock()
if !exists {
return false
}
// Update activity
a.sessionMu.Lock()
session.LastActivity = time.Now()
a.sessionMu.Unlock()
return true
}
// GetStats returns authentication statistics
func (a *Authenticator) GetStats() map[string]any {
if a == nil {
return map[string]any{"enabled": false}
}
a.sessionMu.RLock()
sessionCount := len(a.sessions)
a.sessionMu.RUnlock()
return map[string]any{
"enabled": true,
"type": a.config.Type,
"active_sessions": sessionCount,
"basic_users": len(a.basicUsers),
"static_tokens": len(a.bearerTokens),
}
}

View File

@ -1,7 +1,11 @@
// FILE: logwisp/src/internal/config/auth.go // FILE: logwisp/src/internal/config/auth.go
package config package config
import "fmt" import (
"fmt"
"net"
"strings"
)
type AuthConfig struct { type AuthConfig struct {
// Authentication type: "none", "basic", "bearer", "mtls" // Authentication type: "none", "basic", "bearer", "mtls"
@ -12,10 +16,6 @@ type AuthConfig struct {
// Bearer token auth // Bearer token auth
BearerAuth *BearerAuthConfig `toml:"bearer_auth"` BearerAuth *BearerAuthConfig `toml:"bearer_auth"`
// IP-based access control
IPWhitelist []string `toml:"ip_whitelist"`
IPBlacklist []string `toml:"ip_blacklist"`
} }
type BasicAuthConfig struct { type BasicAuthConfig struct {

View File

@ -3,7 +3,6 @@ package config
type Config struct { type Config struct {
// Top-level flags for application control // Top-level flags for application control
UseRouter bool `toml:"router"`
Background bool `toml:"background"` Background bool `toml:"background"`
ShowVersion bool `toml:"version"` ShowVersion bool `toml:"version"`
Quiet bool `toml:"quiet"` Quiet bool `toml:"quiet"`

View File

@ -3,6 +3,7 @@ package config
import ( import (
"fmt" "fmt"
"net"
"strings" "strings"
) )
@ -28,6 +29,37 @@ type RateLimitConfig struct {
MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"` MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"`
} }
func validateNetAccess(pipelineName string, cfg *NetAccessConfig) error {
if cfg == nil {
return nil
}
// Validate CIDR notation
for _, cidr := range cfg.IPWhitelist {
if !strings.Contains(cidr, "/") {
cidr = cidr + "/32"
}
if _, _, err := net.ParseCIDR(cidr); err != nil {
if net.ParseIP(cidr) == nil {
return fmt.Errorf("pipeline '%s': invalid IP whitelist entry: %s", pipelineName, cidr)
}
}
}
for _, cidr := range cfg.IPBlacklist {
if !strings.Contains(cidr, "/") {
cidr = cidr + "/32"
}
if _, _, err := net.ParseCIDR(cidr); err != nil {
if net.ParseIP(cidr) == nil {
return fmt.Errorf("pipeline '%s': invalid IP blacklist entry: %s", pipelineName, cidr)
}
}
}
return nil
}
func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error { func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error {
if cfg == nil { if cfg == nil {
return nil return nil

View File

@ -19,7 +19,6 @@ type LoadContext struct {
func defaults() *Config { func defaults() *Config {
return &Config{ return &Config{
// Top-level flag defaults // Top-level flag defaults
UseRouter: false,
Background: false, Background: false,
ShowVersion: false, ShowVersion: false,
Quiet: false, Quiet: false,

View File

@ -20,6 +20,9 @@ type PipelineConfig struct {
// Rate limiting // Rate limiting
RateLimit *RateLimitConfig `toml:"rate_limit"` RateLimit *RateLimitConfig `toml:"rate_limit"`
// Network access control (IP filtering)
NetAccess *NetAccessConfig `toml:"net_access"`
// Filter configuration // Filter configuration
Filters []FilterConfig `toml:"filters"` Filters []FilterConfig `toml:"filters"`
@ -34,6 +37,12 @@ type PipelineConfig struct {
Auth *AuthConfig `toml:"auth"` Auth *AuthConfig `toml:"auth"`
} }
// NetAccessConfig defines IP-based access control lists
type NetAccessConfig struct {
IPWhitelist []string `toml:"ip_whitelist"`
IPBlacklist []string `toml:"ip_blacklist"`
}
// SourceConfig represents an input data source // SourceConfig represents an input data source
type SourceConfig struct { type SourceConfig struct {
// Source type: "directory", "file", "stdin", etc. // Source type: "directory", "file", "stdin", etc.

View File

@ -72,6 +72,11 @@ func (c *Config) validate() error {
} }
} }
// Validate net access if present
if err := validateNetAccess(pipeline.Name, pipeline.NetAccess); err != nil {
return err
}
// Validate auth if present // Validate auth if present
if err := validateAuth(pipeline.Name, pipeline.Auth); err != nil { if err := validateAuth(pipeline.Name, pipeline.Auth); err != nil {
return err return err

173
src/internal/limit/ip.go Normal file
View File

@ -0,0 +1,173 @@
// FILE: src/internal/limit/ip.go
package limit
import (
"net"
"strings"
"logwisp/src/internal/config"
"github.com/lixenwraith/log"
)
// IPChecker handles IP-based access control lists
type IPChecker struct {
ipWhitelist []*net.IPNet
ipBlacklist []*net.IPNet
logger *log.Logger
}
// NewIPChecker creates a new IPChecker. Returns nil if no rules are defined.
func NewIPChecker(cfg *config.NetAccessConfig, logger *log.Logger) *IPChecker {
if cfg == nil || (len(cfg.IPWhitelist) == 0 && len(cfg.IPBlacklist) == 0) {
return nil
}
c := &IPChecker{
ipWhitelist: make([]*net.IPNet, 0),
ipBlacklist: make([]*net.IPNet, 0),
logger: logger,
}
// Parse whitelist entries
for _, cidr := range cfg.IPWhitelist {
if !strings.Contains(cidr, "/") {
cidr = cidr + "/32"
}
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
// Try parsing as plain IP
if ip := net.ParseIP(cidr); ip != nil {
if ip.To4() != nil {
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)}
} else {
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
}
} else {
logger.Warn("msg", "Skipping invalid IP whitelist entry",
"component", "ip_checker",
"entry", cidr,
"error", err)
continue
}
}
c.ipWhitelist = append(c.ipWhitelist, ipNet)
}
// Parse blacklist entries
for _, cidr := range cfg.IPBlacklist {
if !strings.Contains(cidr, "/") {
cidr = cidr + "/32"
}
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
// Try parsing as plain IP
if ip := net.ParseIP(cidr); ip != nil {
if ip.To4() != nil {
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)}
} else {
ipNet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
}
} else {
logger.Warn("msg", "Skipping invalid IP blacklist entry",
"component", "ip_checker",
"entry", cidr,
"error", err)
continue
}
}
c.ipBlacklist = append(c.ipBlacklist, ipNet)
}
logger.Info("msg", "IP checker initialized",
"component", "ip_checker",
"whitelist_rules", len(c.ipWhitelist),
"blacklist_rules", len(c.ipBlacklist))
return c
}
// IsAllowed validates if a remote address is permitted
func (c *IPChecker) IsAllowed(remoteAddr net.Addr) bool {
if c == nil {
return true // No checker = allow all
}
// No rules = allow all
if len(c.ipWhitelist) == 0 && len(c.ipBlacklist) == 0 {
return true
}
// Extract IP from address
var ipStr string
switch addr := remoteAddr.(type) {
case *net.TCPAddr:
ipStr = addr.IP.String()
case *net.UDPAddr:
ipStr = addr.IP.String()
default:
// Try string parsing
addrStr := remoteAddr.String()
host, _, err := net.SplitHostPort(addrStr)
if err != nil {
ipStr = addrStr
} else {
ipStr = host
}
}
ip := net.ParseIP(ipStr)
if ip == nil {
c.logger.Warn("msg", "Could not parse remote address to IP",
"component", "ip_checker",
"remote_addr", remoteAddr.String())
return false // Deny unparseable addresses
}
// Check blacklist first (deny takes precedence)
for _, ipNet := range c.ipBlacklist {
if ipNet.Contains(ip) {
c.logger.Warn("msg", "Blacklisted IP denied",
"component", "ip_checker",
"ip", ipStr,
"rule", ipNet.String())
return false
}
}
// If whitelist is configured, IP must be in it
if len(c.ipWhitelist) > 0 {
for _, ipNet := range c.ipWhitelist {
if ipNet.Contains(ip) {
c.logger.Debug("msg", "IP allowed by whitelist",
"component", "ip_checker",
"ip", ipStr,
"rule", ipNet.String())
return true
}
}
// No whitelist match = deny
c.logger.Warn("msg", "IP not in whitelist",
"component", "ip_checker",
"ip", ipStr)
return false
}
// No blacklist match + no whitelist configured = allow
return true
}
// GetStats returns IP checker statistics
func (c *IPChecker) GetStats() map[string]any {
if c == nil {
return map[string]any{"enabled": false}
}
return map[string]any{
"enabled": true,
"whitelist_rules": len(c.ipWhitelist),
"blacklist_rules": len(c.ipBlacklist),
}
}

View File

@ -1,232 +0,0 @@
// FILE: logwisp/src/internal/service/httprouter.go
package service
import (
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"logwisp/src/internal/sink"
"github.com/lixenwraith/log"
"github.com/valyala/fasthttp"
)
// HTTPRouter manages HTTP routing for multiple pipelines
type HTTPRouter struct {
service *Service
servers map[int64]*routerServer // port -> server
mu sync.RWMutex
logger *log.Logger
// Statistics
startTime time.Time
totalRequests atomic.Uint64
routedRequests atomic.Uint64
failedRequests atomic.Uint64
}
// NewHTTPRouter creates a new HTTP router
func NewHTTPRouter(service *Service, logger *log.Logger) *HTTPRouter {
return &HTTPRouter{
service: service,
servers: make(map[int64]*routerServer),
startTime: time.Now(),
logger: logger,
}
}
// RegisterPipeline registers a pipeline's HTTP sinks with the router
func (r *HTTPRouter) RegisterPipeline(pipeline *Pipeline) error {
// Register all HTTP sinks in the pipeline
for _, httpSink := range pipeline.HTTPSinks {
if err := r.registerHTTPSink(pipeline.Name, httpSink); err != nil {
return err
}
}
return nil
}
// registerHTTPSink registers a single HTTP sink
func (r *HTTPRouter) registerHTTPSink(pipelineName string, httpSink *sink.HTTPSink) error {
// Get port from sink configuration
stats := httpSink.GetStats()
details := stats.Details
port := details["port"].(int64)
r.mu.Lock()
rs, exists := r.servers[port]
if !exists {
// Create new server for this port
rs = &routerServer{
port: port,
routes: make(map[string]*routedSink),
router: r,
startTime: time.Now(),
logger: r.logger,
}
rs.server = &fasthttp.Server{
Handler: rs.requestHandler,
DisableKeepalive: false,
StreamRequestBody: true,
CloseOnShutdown: true,
}
r.servers[port] = rs
// Startup sync channel
startupDone := make(chan error, 1)
// Start server in background
go func() {
addr := fmt.Sprintf(":%d", port)
r.logger.Info("msg", "Starting router server",
"component", "http_router",
"port", port)
// Signal that server is about to start
startupDone <- nil
if err := rs.server.ListenAndServe(addr); err != nil {
r.logger.Error("msg", "Router server failed",
"component", "http_router",
"port", port,
"error", err)
}
}()
// Wait for server startup signal with timeout
select {
case err := <-startupDone:
if err != nil {
r.mu.Unlock()
return fmt.Errorf("server startup failed: %w", err)
}
case <-time.After(5 * time.Second):
r.mu.Unlock()
return fmt.Errorf("server startup timeout on port %d", port)
}
}
r.mu.Unlock()
// Register routes for this sink
rs.routeMu.Lock()
defer rs.routeMu.Unlock()
// Use pipeline name as path prefix
pathPrefix := "/" + pipelineName
// Check for conflicts
for existingPath, existing := range rs.routes {
if strings.HasPrefix(pathPrefix, existingPath) || strings.HasPrefix(existingPath, pathPrefix) {
return fmt.Errorf("path conflict: '%s' conflicts with existing pipeline '%s' at '%s'",
pathPrefix, existing.pipelineName, existingPath)
}
}
// Set the sink to router mode
httpSink.SetRouterMode()
rs.routes[pathPrefix] = &routedSink{
pipelineName: pipelineName,
httpSink: httpSink,
}
r.logger.Info("msg", "Registered pipeline route",
"component", "http_router",
"pipeline", pipelineName,
"path", pathPrefix,
"port", port)
return nil
}
// UnregisterPipeline removes a pipeline's routes
func (r *HTTPRouter) UnregisterPipeline(pipelineName string) {
r.mu.RLock()
defer r.mu.RUnlock()
for port, rs := range r.servers {
rs.routeMu.Lock()
for path, route := range rs.routes {
if route.pipelineName == pipelineName {
delete(rs.routes, path)
r.logger.Info("msg", "Unregistered pipeline route",
"component", "http_router",
"pipeline", pipelineName,
"path", path,
"port", port)
}
}
// Check if server has no more routes
if len(rs.routes) == 0 {
r.logger.Info("msg", "No routes left on port, considering shutdown",
"component", "http_router",
"port", port)
}
rs.routeMu.Unlock()
}
}
// Shutdown stops all router servers
func (r *HTTPRouter) Shutdown() {
r.logger.Info("msg", "Starting router shutdown...")
r.mu.Lock()
defer r.mu.Unlock()
var wg sync.WaitGroup
for port, rs := range r.servers {
wg.Add(1)
go func(p int64, s *routerServer) {
defer wg.Done()
r.logger.Info("msg", "Shutting down server",
"component", "http_router",
"port", p)
if err := s.server.Shutdown(); err != nil {
r.logger.Error("msg", "Error shutting down server",
"component", "http_router",
"port", p,
"error", err)
}
}(port, rs)
}
wg.Wait()
r.logger.Info("msg", "Router shutdown complete")
}
// GetStats returns router statistics
func (r *HTTPRouter) GetStats() map[string]any {
r.mu.RLock()
defer r.mu.RUnlock()
serverStats := make(map[int64]any)
totalRoutes := 0
for port, rs := range r.servers {
rs.routeMu.RLock()
routes := make([]string, 0, len(rs.routes))
for path := range rs.routes {
routes = append(routes, path)
totalRoutes++
}
rs.routeMu.RUnlock()
serverStats[port] = map[string]any{
"routes": routes,
"requests": rs.requests.Load(),
"uptime": int(time.Since(rs.startTime).Seconds()),
}
}
return map[string]any{
"uptime_seconds": int64(time.Since(r.startTime).Seconds()),
"total_requests": r.totalRequests.Load(),
"routed_requests": r.routedRequests.Load(),
"failed_requests": r.failedRequests.Load(),
"servers": serverStats,
"total_routes": totalRoutes,
}
}

View File

@ -30,10 +30,6 @@ type Pipeline struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wg sync.WaitGroup wg sync.WaitGroup
// For HTTP sinks in router mode
HTTPSinks []*sink.HTTPSink
TCPSinks []*sink.TCPSink
} }
// PipelineStats contains statistics for a pipeline // PipelineStats contains statistics for a pipeline

View File

@ -1,192 +0,0 @@
// FILE: logwisp/src/internal/service/routerserver.go
package service
import (
"encoding/json"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"logwisp/src/internal/sink"
"logwisp/src/internal/version"
"github.com/lixenwraith/log"
"github.com/valyala/fasthttp"
)
// routedSink represents a sink registered with the router
type routedSink struct {
pipelineName string
httpSink *sink.HTTPSink
}
// routerServer handles HTTP requests for a specific port
type routerServer struct {
port int64
server *fasthttp.Server
logger *log.Logger
routes map[string]*routedSink // path prefix -> sink
routeMu sync.RWMutex
router *HTTPRouter
startTime time.Time
requests atomic.Uint64
}
func (rs *routerServer) requestHandler(ctx *fasthttp.RequestCtx) {
rs.requests.Add(1)
rs.router.totalRequests.Add(1)
path := string(ctx.Path())
remoteAddr := ctx.RemoteAddr().String()
// Log request for debugging
rs.logger.Debug("msg", "Router request",
"component", "router_server",
"method", string(ctx.Method()),
"path", path,
"remote_addr", remoteAddr)
// Special case: global status at /status
if path == "/status" {
rs.handleGlobalStatus(ctx)
return
}
// Find matching route
rs.routeMu.RLock()
var matchedSink *routedSink
var matchedPrefix string
var remainingPath string
for prefix, route := range rs.routes {
if strings.HasPrefix(path, prefix) {
// Longest prefix match
if len(prefix) > len(matchedPrefix) {
matchedPrefix = prefix
matchedSink = route
remainingPath = strings.TrimPrefix(path, prefix)
// Ensure remaining path starts with / or is empty
if remainingPath != "" && !strings.HasPrefix(remainingPath, "/") {
remainingPath = "/" + remainingPath
}
}
}
}
rs.routeMu.RUnlock()
if matchedSink == nil {
rs.router.failedRequests.Add(1)
rs.handleNotFound(ctx)
return
}
rs.router.routedRequests.Add(1)
// Route to sink's handler
if matchedSink.httpSink != nil {
// Save original path
originalPath := string(ctx.URI().Path())
// Rewrite path to remove pipeline prefix
if remainingPath == "" {
// Default to stream path if no remaining path
remainingPath = matchedSink.httpSink.GetStreamPath()
}
rs.logger.Debug("msg", "Routing request to pipeline",
"component", "router_server",
"pipeline", matchedSink.pipelineName,
"original_path", originalPath,
"remaining_path", remainingPath)
ctx.URI().SetPath(remainingPath)
matchedSink.httpSink.RouteRequest(ctx)
// Restore original path
ctx.URI().SetPath(originalPath)
} else {
ctx.SetStatusCode(fasthttp.StatusServiceUnavailable)
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{
"error": "Pipeline HTTP sink not available",
"pipeline": matchedSink.pipelineName,
})
}
}
func (rs *routerServer) handleGlobalStatus(ctx *fasthttp.RequestCtx) {
ctx.SetContentType("application/json")
rs.routeMu.RLock()
pipelines := make(map[string]any)
for prefix, route := range rs.routes {
pipelineInfo := map[string]any{
"path_prefix": prefix,
"endpoints": map[string]string{
"stream": prefix + route.httpSink.GetStreamPath(),
"status": prefix + route.httpSink.GetStatusPath(),
},
}
// Get sink stats
sinkStats := route.httpSink.GetStats()
pipelineInfo["sink"] = map[string]any{
"type": sinkStats.Type,
"total_processed": sinkStats.TotalProcessed,
"active_connections": sinkStats.ActiveConnections,
"details": sinkStats.Details,
}
pipelines[route.pipelineName] = pipelineInfo
}
rs.routeMu.RUnlock()
// Get router stats
routerStats := rs.router.GetStats()
status := map[string]any{
"service": "LogWisp Router",
"version": version.String(),
"port": rs.port,
"pipelines": pipelines,
"total_pipelines": len(pipelines),
"router": routerStats,
"endpoints": map[string]string{
"global_status": "/status",
},
}
data, _ := json.MarshalIndent(status, "", " ")
ctx.SetBody(data)
}
func (rs *routerServer) handleNotFound(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetContentType("application/json")
rs.routeMu.RLock()
availableRoutes := make([]string, 0, len(rs.routes)*2+1)
availableRoutes = append(availableRoutes, "/status (global status)")
for prefix, route := range rs.routes {
if route.httpSink != nil {
availableRoutes = append(availableRoutes,
fmt.Sprintf("%s%s (stream: %s)", prefix, route.httpSink.GetStreamPath(), route.pipelineName),
fmt.Sprintf("%s%s (status: %s)", prefix, route.httpSink.GetStatusPath(), route.pipelineName),
)
}
}
rs.routeMu.RUnlock()
response := map[string]any{
"error": "Not Found",
"requested_path": string(ctx.Path()),
"available_routes": availableRoutes,
"hint": "Use /status for global router status",
}
data, _ := json.MarshalIndent(response, "", " ")
ctx.SetBody(data)
}

View File

@ -113,20 +113,12 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
// Create sinks // Create sinks
for i, sinkCfg := range cfg.Sinks { for i, sinkCfg := range cfg.Sinks {
sinkInst, err := s.createSink(sinkCfg, formatter) // Pass formatter sinkInst, err := s.createSink(sinkCfg, formatter)
if err != nil { if err != nil {
pipelineCancel() pipelineCancel()
return fmt.Errorf("failed to create sink[%d]: %w", i, err) return fmt.Errorf("failed to create sink[%d]: %w", i, err)
} }
pipeline.Sinks = append(pipeline.Sinks, sinkInst) pipeline.Sinks = append(pipeline.Sinks, sinkInst)
// Track HTTP/TCP sinks for router mode
switch s := sinkInst.(type) {
case *sink.HTTPSink:
pipeline.HTTPSinks = append(pipeline.HTTPSinks, s)
case *sink.TCPSink:
pipeline.TCPSinks = append(pipeline.TCPSinks, s)
}
} }
// Start all sources // Start all sources
@ -145,6 +137,16 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
} }
} }
// Configure authentication for sinks that support it
for _, sinkInst := range pipeline.Sinks {
if setter, ok := sinkInst.(sink.NetAccessSetter); ok {
setter.SetNetAccessConfig(cfg.NetAccess)
}
if setter, ok := sinkInst.(sink.AuthSetter); ok {
setter.SetAuthConfig(cfg.Auth)
}
}
// Wire sources to sinks through filters // Wire sources to sinks through filters
s.wirePipeline(pipeline) s.wirePipeline(pipeline)
@ -152,7 +154,9 @@ func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
pipeline.startStatsUpdater(pipelineCtx) pipeline.startStatsUpdater(pipelineCtx)
s.pipelines[cfg.Name] = pipeline s.pipelines[cfg.Name] = pipeline
s.logger.Info("msg", "Pipeline created successfully", "pipeline", cfg.Name) s.logger.Info("msg", "Pipeline created successfully",
"pipeline", cfg.Name,
"auth_enabled", cfg.Auth != nil && cfg.Auth.Type != "none")
return nil return nil
} }
@ -268,19 +272,19 @@ func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter)
switch cfg.Type { switch cfg.Type {
case "http": case "http":
return sink.NewHTTPSink(cfg.Options, s.logger, formatter) // needs implementation return sink.NewHTTPSink(cfg.Options, s.logger, formatter)
case "tcp": case "tcp":
return sink.NewTCPSink(cfg.Options, s.logger, formatter) // needs implementation return sink.NewTCPSink(cfg.Options, s.logger, formatter)
case "http_client": case "http_client":
return sink.NewHTTPClientSink(cfg.Options, s.logger, formatter) // needs verification return sink.NewHTTPClientSink(cfg.Options, s.logger, formatter)
case "tcp_client": case "tcp_client":
return sink.NewTCPClientSink(cfg.Options, s.logger, formatter) // needs implementation return sink.NewTCPClientSink(cfg.Options, s.logger, formatter)
case "file": case "file":
return sink.NewFileSink(cfg.Options, s.logger, formatter) return sink.NewFileSink(cfg.Options, s.logger, formatter)
case "stdout": case "stdout":
return sink.NewStdoutSink(cfg.Options, s.logger, formatter) // needs implementation return sink.NewStdoutSink(cfg.Options, s.logger, formatter)
case "stderr": case "stderr":
return sink.NewStderrSink(cfg.Options, s.logger, formatter) // needs implementation return sink.NewStderrSink(cfg.Options, s.logger, formatter)
default: default:
return nil, fmt.Errorf("unknown sink type: %s", cfg.Type) return nil, fmt.Errorf("unknown sink type: %s", cfg.Type)
} }

View File

@ -11,10 +11,12 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"logwisp/src/internal/auth"
"logwisp/src/internal/config" "logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"logwisp/src/internal/format" "logwisp/src/internal/format"
"logwisp/src/internal/limit" "logwisp/src/internal/limit"
"logwisp/src/internal/tls"
"logwisp/src/internal/version" "logwisp/src/internal/version"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
@ -35,19 +37,24 @@ type HTTPSink struct {
logger *log.Logger logger *log.Logger
formatter format.Formatter formatter format.Formatter
// Security components
authenticator *auth.Authenticator
tlsManager *tls.Manager
authConfig *config.AuthConfig
// Path configuration // Path configuration
streamPath string streamPath string
statusPath string statusPath string
// For router integration
standalone bool
// Net limiting // Net limiting
netLimiter *limit.NetLimiter netLimiter *limit.NetLimiter
ipChecker *limit.IPChecker
// Statistics // Statistics
totalProcessed atomic.Uint64 totalProcessed atomic.Uint64
lastProcessed atomic.Value // time.Time lastProcessed atomic.Value // time.Time
authFailures atomic.Uint64
authSuccesses atomic.Uint64
} }
// HTTPConfig holds HTTP sink configuration // HTTPConfig holds HTTP sink configuration
@ -98,6 +105,32 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
} }
} }
// Extract SSL config
if ssl, ok := options["ssl"].(map[string]any); ok {
cfg.SSL = &config.SSLConfig{}
cfg.SSL.Enabled, _ = ssl["enabled"].(bool)
if certFile, ok := ssl["cert_file"].(string); ok {
cfg.SSL.CertFile = certFile
}
if keyFile, ok := ssl["key_file"].(string); ok {
cfg.SSL.KeyFile = keyFile
}
cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool)
if caFile, ok := ssl["client_ca_file"].(string); ok {
cfg.SSL.ClientCAFile = caFile
}
cfg.SSL.VerifyClientCert, _ = ssl["verify_client_cert"].(bool)
if minVer, ok := ssl["min_version"].(string); ok {
cfg.SSL.MinVersion = minVer
}
if maxVer, ok := ssl["max_version"].(string); ok {
cfg.SSL.MaxVersion = maxVer
}
if ciphers, ok := ssl["cipher_suites"].(string); ok {
cfg.SSL.CipherSuites = ciphers
}
}
// Extract net limit config // Extract net limit config
if rl, ok := options["net_limit"].(map[string]any); ok { if rl, ok := options["net_limit"].(map[string]any); ok {
cfg.NetLimit = &config.NetLimitConfig{} cfg.NetLimit = &config.NetLimitConfig{}
@ -132,7 +165,6 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo
done: make(chan struct{}), done: make(chan struct{}),
streamPath: cfg.StreamPath, streamPath: cfg.StreamPath,
statusPath: cfg.StatusPath, statusPath: cfg.StatusPath,
standalone: true,
logger: logger, logger: logger,
formatter: formatter, formatter: formatter,
} }
@ -151,13 +183,6 @@ func (h *HTTPSink) Input() chan<- core.LogEntry {
} }
func (h *HTTPSink) Start(ctx context.Context) error { func (h *HTTPSink) Start(ctx context.Context) error {
if !h.standalone {
// In router mode, don't start our own server
h.logger.Debug("msg", "HTTP sink in router mode, skipping server start",
"component", "http_sink")
return nil
}
// Create fasthttp adapter for logging // Create fasthttp adapter for logging
fasthttpLogger := compat.NewFastHTTPAdapter(h.logger) fasthttpLogger := compat.NewFastHTTPAdapter(h.logger)
@ -168,6 +193,12 @@ func (h *HTTPSink) Start(ctx context.Context) error {
Logger: fasthttpLogger, Logger: fasthttpLogger,
} }
// Configure TLS if enabled
if h.tlsManager != nil {
tlsConfig := h.tlsManager.GetHTTPConfig()
h.server.TLSConfig = tlsConfig
}
addr := fmt.Sprintf(":%d", h.config.Port) addr := fmt.Sprintf(":%d", h.config.Port)
// Run server in separate goroutine to avoid blocking // Run server in separate goroutine to avoid blocking
@ -178,7 +209,16 @@ func (h *HTTPSink) Start(ctx context.Context) error {
"port", h.config.Port, "port", h.config.Port,
"stream_path", h.streamPath, "stream_path", h.streamPath,
"status_path", h.statusPath) "status_path", h.statusPath)
err := h.server.ListenAndServe(addr)
var err error
if h.tlsManager != nil {
// HTTPS server
err = h.server.ListenAndServeTLS(addr, "", "")
} else {
// HTTP server
err = h.server.ListenAndServe(addr)
}
if err != nil { if err != nil {
errChan <- err errChan <- err
} }
@ -210,8 +250,8 @@ func (h *HTTPSink) Stop() {
// Signal all client handlers to stop // Signal all client handlers to stop
close(h.done) close(h.done)
// Shutdown HTTP server if in standalone mode // Shutdown HTTP server
if h.standalone && h.server != nil { if h.server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
h.server.ShutdownWithContext(ctx) h.server.ShutdownWithContext(ctx)
@ -231,6 +271,18 @@ func (h *HTTPSink) GetStats() SinkStats {
netLimitStats = h.netLimiter.GetStats() netLimitStats = h.netLimiter.GetStats()
} }
var authStats map[string]any
if h.authenticator != nil {
authStats = h.authenticator.GetStats()
authStats["failures"] = h.authFailures.Load()
authStats["successes"] = h.authSuccesses.Load()
}
var tlsStats map[string]any
if h.tlsManager != nil {
tlsStats = h.tlsManager.GetStats()
}
return SinkStats{ return SinkStats{
Type: "http", Type: "http",
TotalProcessed: h.totalProcessed.Load(), TotalProcessed: h.totalProcessed.Load(),
@ -245,42 +297,83 @@ func (h *HTTPSink) GetStats() SinkStats {
"status": h.statusPath, "status": h.statusPath,
}, },
"net_limit": netLimitStats, "net_limit": netLimitStats,
"auth": authStats,
"tls": tlsStats,
}, },
} }
} }
// SetRouterMode configures the sink for use with a router
func (h *HTTPSink) SetRouterMode() {
h.standalone = false
h.logger.Debug("msg", "HTTP sink set to router mode",
"component", "http_sink")
}
// RouteRequest handles a request from the router
func (h *HTTPSink) RouteRequest(ctx *fasthttp.RequestCtx) {
h.requestHandler(ctx)
}
func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
// Check net limit first
remoteAddr := ctx.RemoteAddr().String() remoteAddr := ctx.RemoteAddr().String()
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
ctx.SetStatusCode(int(statusCode)) // Check IP access control
ctx.SetContentType("application/json") if h.ipChecker != nil {
json.NewEncoder(ctx).Encode(map[string]any{ if !h.ipChecker.IsAllowed(ctx.RemoteAddr()) {
"error": message, ctx.SetStatusCode(fasthttp.StatusForbidden)
"retry_after": "60", // seconds ctx.SetContentType("text/plain")
}) ctx.SetBodyString("Forbidden")
return return
}
}
// Check net limit
if h.netLimiter != nil {
if allowed, statusCode, message := h.netLimiter.CheckHTTP(remoteAddr); !allowed {
ctx.SetStatusCode(int(statusCode))
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]any{
"error": message,
"retry_after": "60", // seconds
})
return
}
} }
path := string(ctx.Path()) path := string(ctx.Path())
// Status endpoint doesn't require auth
if path == h.statusPath {
h.handleStatus(ctx)
return
}
// Authenticate request
var session *auth.Session
if h.authenticator != nil {
authHeader := string(ctx.Request.Header.Peek("Authorization"))
var err error
session, err = h.authenticator.AuthenticateHTTP(authHeader, remoteAddr)
if err != nil {
h.authFailures.Add(1)
h.logger.Warn("msg", "Authentication failed",
"component", "http_sink",
"remote_addr", remoteAddr,
"error", err)
// Return 401 with WWW-Authenticate header
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
realm := h.authConfig.BasicAuth.Realm
if realm == "" {
realm = "LogWisp"
}
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm))
} else if h.authConfig.Type == "bearer" {
ctx.Response.Header.Set("WWW-Authenticate", "Bearer")
}
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{
"error": "Authentication required",
})
return
}
h.authSuccesses.Add(1)
}
switch path { switch path {
case h.streamPath: case h.streamPath:
h.handleStream(ctx) h.handleStream(ctx, session)
case h.statusPath:
h.handleStatus(ctx)
default: default:
ctx.SetStatusCode(fasthttp.StatusNotFound) ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
@ -292,7 +385,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
} }
} }
func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) { func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) {
// Track connection for net limiting // Track connection for net limiting
remoteAddr := ctx.RemoteAddr().String() remoteAddr := ctx.RemoteAddr().String()
if h.netLimiter != nil { if h.netLimiter != nil {
@ -330,7 +423,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) {
case <-h.done: case <-h.done:
return return
default: default:
// Drop if client buffer full, may flood logging for slow client // Drop if client buffer full
h.logger.Debug("msg", "Dropped entry for slow client", h.logger.Debug("msg", "Dropped entry for slow client",
"component", "http_sink", "component", "http_sink",
"remote_addr", remoteAddr) "remote_addr", remoteAddr)
@ -348,6 +441,8 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) {
newCount := h.activeClients.Add(1) newCount := h.activeClients.Add(1)
h.logger.Debug("msg", "HTTP client connected", h.logger.Debug("msg", "HTTP client connected",
"remote_addr", remoteAddr, "remote_addr", remoteAddr,
"username", session.Username,
"auth_method", session.Method,
"active_clients", newCount) "active_clients", newCount)
h.wg.Add(1) h.wg.Add(1)
@ -356,6 +451,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) {
newCount := h.activeClients.Add(-1) newCount := h.activeClients.Add(-1)
h.logger.Debug("msg", "HTTP client disconnected", h.logger.Debug("msg", "HTTP client disconnected",
"remote_addr", remoteAddr, "remote_addr", remoteAddr,
"username", session.Username,
"active_clients", newCount) "active_clients", newCount)
h.wg.Done() h.wg.Done()
}() }()
@ -364,12 +460,15 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) {
clientID := fmt.Sprintf("%d", time.Now().UnixNano()) clientID := fmt.Sprintf("%d", time.Now().UnixNano())
connectionInfo := map[string]any{ connectionInfo := map[string]any{
"client_id": clientID, "client_id": clientID,
"username": session.Username,
"auth_method": session.Method,
"stream_path": h.streamPath, "stream_path": h.streamPath,
"status_path": h.statusPath, "status_path": h.statusPath,
"buffer_size": h.config.BufferSize, "buffer_size": h.config.BufferSize,
"tls": h.tlsManager != nil,
} }
data, _ := json.Marshal(connectionInfo) data, _ := json.Marshal(connectionInfo)
fmt.Fprintf(w, "event: connected\ndata: %s\n", data) fmt.Fprintf(w, "event: connected\ndata: %s\n\n", data)
w.Flush() w.Flush()
var ticker *time.Ticker var ticker *time.Ticker
@ -402,6 +501,13 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx) {
} }
case <-tickerChan: case <-tickerChan:
// Validate session is still active
if h.authenticator != nil && !h.authenticator.ValidateSession(session.ID) {
fmt.Fprintf(w, "event: disconnect\ndata: {\"reason\":\"session_expired\"}\n\n")
w.Flush()
return
}
heartbeatEntry := h.createHeartbeatEntry() heartbeatEntry := h.createHeartbeatEntry()
if err := h.formatEntryForSSE(w, heartbeatEntry); err != nil { if err := h.formatEntryForSSE(w, heartbeatEntry); err != nil {
h.logger.Error("msg", "Failed to format heartbeat", h.logger.Error("msg", "Failed to format heartbeat",
@ -437,8 +543,10 @@ func (h *HTTPSink) formatEntryForSSE(w *bufio.Writer, entry core.LogEntry) error
lines := bytes.Split(formatted, []byte{'\n'}) lines := bytes.Split(formatted, []byte{'\n'})
for _, line := range lines { for _, line := range lines {
// SSE needs "data: " prefix for each line // SSE needs "data: " prefix for each line
// TODO: validate above, is 'data: ' really necessary? make it optional if it works without it?
fmt.Fprintf(w, "data: %s\n", line) fmt.Fprintf(w, "data: %s\n", line)
} }
fmt.Fprintf(w, "\n") // Empty line to terminate event
return nil return nil
} }
@ -478,6 +586,26 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
} }
} }
var authStats any
if h.authenticator != nil {
authStats = h.authenticator.GetStats()
authStats.(map[string]any)["failures"] = h.authFailures.Load()
authStats.(map[string]any)["successes"] = h.authSuccesses.Load()
} else {
authStats = map[string]any{
"enabled": false,
}
}
var tlsStats any
if h.tlsManager != nil {
tlsStats = h.tlsManager.GetStats()
} else {
tlsStats = map[string]any{
"enabled": false,
}
}
status := map[string]any{ status := map[string]any{
"service": "LogWisp", "service": "LogWisp",
"version": version.Short(), "version": version.Short(),
@ -487,7 +615,6 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
"active_clients": h.activeClients.Load(), "active_clients": h.activeClients.Load(),
"buffer_size": h.config.BufferSize, "buffer_size": h.config.BufferSize,
"uptime_seconds": int(time.Since(h.startTime).Seconds()), "uptime_seconds": int(time.Since(h.startTime).Seconds()),
"mode": map[string]bool{"standalone": h.standalone, "router": !h.standalone},
}, },
"endpoints": map[string]string{ "endpoints": map[string]string{
"transport": h.streamPath, "transport": h.streamPath,
@ -499,11 +626,15 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
"interval": h.config.Heartbeat.IntervalSeconds, "interval": h.config.Heartbeat.IntervalSeconds,
"format": h.config.Heartbeat.Format, "format": h.config.Heartbeat.Format,
}, },
"ssl": map[string]bool{ "tls": tlsStats,
"enabled": h.config.SSL != nil && h.config.SSL.Enabled, "auth": authStats,
},
"net_limit": netLimitStats, "net_limit": netLimitStats,
}, },
"statistics": map[string]any{
"total_processed": h.totalProcessed.Load(),
"auth_failures": h.authFailures.Load(),
"auth_successes": h.authSuccesses.Load(),
},
} }
data, _ := json.Marshal(status) data, _ := json.Marshal(status)
@ -524,3 +655,33 @@ func (h *HTTPSink) GetStreamPath() string {
func (h *HTTPSink) GetStatusPath() string { func (h *HTTPSink) GetStatusPath() string {
return h.statusPath return h.statusPath
} }
func (h *HTTPSink) SetNetAccessConfig(cfg *config.NetAccessConfig) {
h.ipChecker = limit.NewIPChecker(cfg, h.logger)
if h.ipChecker != nil {
h.logger.Info("msg", "IP access control configured for HTTP sink",
"component", "http_sink")
}
}
// SetAuthConfig configures http sink authentication
func (h *HTTPSink) SetAuthConfig(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" {
return
}
h.authConfig = authCfg
authenticator, err := auth.New(authCfg, h.logger)
if err != nil {
h.logger.Error("msg", "Failed to initialize authenticator for HTTP sink",
"component", "http_sink",
"error", err)
// Continue without auth
return
}
h.authenticator = authenticator
h.logger.Info("msg", "Authentication configured for HTTP sink",
"component", "http_sink",
"auth_type", authCfg.Type)
}

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"time" "time"
"logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
) )
@ -32,3 +33,13 @@ type SinkStats struct {
LastProcessed time.Time LastProcessed time.Time
Details map[string]any Details map[string]any
} }
// NetAccessSetter is an interface for sinks that can accept network access configuration
type NetAccessSetter interface {
SetNetAccessConfig(cfg *config.NetAccessConfig)
}
// AuthSetter is an interface for sinks that can accept an AuthConfig.
type AuthSetter interface {
SetAuthConfig(auth *config.AuthConfig)
}

View File

@ -2,18 +2,22 @@
package sink package sink
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"logwisp/src/internal/auth"
"logwisp/src/internal/config" "logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"logwisp/src/internal/format" "logwisp/src/internal/format"
"logwisp/src/internal/limit" "logwisp/src/internal/limit"
"logwisp/src/internal/tls"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
"github.com/lixenwraith/log/compat" "github.com/lixenwraith/log/compat"
@ -32,12 +36,20 @@ type TCPSink struct {
engineMu sync.Mutex engineMu sync.Mutex
wg sync.WaitGroup wg sync.WaitGroup
netLimiter *limit.NetLimiter netLimiter *limit.NetLimiter
ipChecker *limit.IPChecker
logger *log.Logger logger *log.Logger
formatter format.Formatter formatter format.Formatter
// Security components
authenticator *auth.Authenticator
tlsManager *tls.Manager
authConfig *config.AuthConfig
// Statistics // Statistics
totalProcessed atomic.Uint64 totalProcessed atomic.Uint64
lastProcessed atomic.Value // time.Time lastProcessed atomic.Value // time.Time
authFailures atomic.Uint64
authSuccesses atomic.Uint64
} }
// TCPConfig holds TCP sink configuration // TCPConfig holds TCP sink configuration
@ -78,6 +90,32 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
} }
} }
// Extract SSL config
if ssl, ok := options["ssl"].(map[string]any); ok {
cfg.SSL = &config.SSLConfig{}
cfg.SSL.Enabled, _ = ssl["enabled"].(bool)
if certFile, ok := ssl["cert_file"].(string); ok {
cfg.SSL.CertFile = certFile
}
if keyFile, ok := ssl["key_file"].(string); ok {
cfg.SSL.KeyFile = keyFile
}
cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool)
if caFile, ok := ssl["client_ca_file"].(string); ok {
cfg.SSL.ClientCAFile = caFile
}
cfg.SSL.VerifyClientCert, _ = ssl["verify_client_cert"].(bool)
if minVer, ok := ssl["min_version"].(string); ok {
cfg.SSL.MinVersion = minVer
}
if maxVer, ok := ssl["max_version"].(string); ok {
cfg.SSL.MaxVersion = maxVer
}
if ciphers, ok := ssl["cipher_suites"].(string); ok {
cfg.SSL.CipherSuites = ciphers
}
}
// Extract net limit config // Extract net limit config
if rl, ok := options["net_limit"].(map[string]any); ok { if rl, ok := options["net_limit"].(map[string]any); ok {
cfg.NetLimit = &config.NetLimitConfig{} cfg.NetLimit = &config.NetLimitConfig{}
@ -115,6 +153,7 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
} }
t.lastProcessed.Store(time.Time{}) t.lastProcessed.Store(time.Time{})
// Initialize net limiter
if cfg.NetLimit != nil && cfg.NetLimit.Enabled { if cfg.NetLimit != nil && cfg.NetLimit.Enabled {
t.netLimiter = limit.NewNetLimiter(*cfg.NetLimit, logger) t.netLimiter = limit.NewNetLimiter(*cfg.NetLimit, logger)
} }
@ -127,7 +166,10 @@ func (t *TCPSink) Input() chan<- core.LogEntry {
} }
func (t *TCPSink) Start(ctx context.Context) error { func (t *TCPSink) Start(ctx context.Context) error {
t.server = &tcpServer{sink: t} t.server = &tcpServer{
sink: t,
clients: make(map[gnet.Conn]*tcpClient),
}
// Start log broadcast loop // Start log broadcast loop
t.wg.Add(1) t.wg.Add(1)
@ -136,24 +178,39 @@ func (t *TCPSink) Start(ctx context.Context) error {
t.broadcastLoop(ctx) t.broadcastLoop(ctx)
}() }()
// Configure gnet // Configure gnet options
addr := fmt.Sprintf("tcp://:%d", t.config.Port) addr := fmt.Sprintf("tcp://:%d", t.config.Port)
// Create a gnet adapter using the existing logger instance // Create a gnet adapter using the existing logger instance
gnetLogger := compat.NewGnetAdapter(t.logger) gnetLogger := compat.NewGnetAdapter(t.logger)
var opts []gnet.Option
opts = append(opts,
gnet.WithLogger(gnetLogger),
gnet.WithMulticore(true),
gnet.WithReusePort(true),
)
// Add TLS if configured
if t.tlsManager != nil {
// tlsConfig := t.tlsManager.GetTCPConfig()
// TODO: tlsConfig is not used, wrapper to be implemented, non-TLS stream to be available without wrapper
// ☢ SECURITY: gnet doesn't support TLS natively - would need wrapper
// This is a limitation that requires implementing TLS at application layer
t.logger.Warn("msg", "TLS configured but gnet doesn't support native TLS",
"component", "tcp_sink",
"workaround", "Use stunnel or nginx TCP proxy for TLS termination")
}
// Start gnet server // Start gnet server
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
t.logger.Info("msg", "Starting TCP server", t.logger.Info("msg", "Starting TCP server",
"component", "tcp_sink", "component", "tcp_sink",
"port", t.config.Port) "port", t.config.Port,
"auth", t.authenticator != nil)
err := gnet.Run(t.server, addr, err := gnet.Run(t.server, addr, opts...)
gnet.WithLogger(gnetLogger),
gnet.WithMulticore(true),
gnet.WithReusePort(true),
)
if err != nil { if err != nil {
t.logger.Error("msg", "TCP server failed", t.logger.Error("msg", "TCP server failed",
"component", "tcp_sink", "component", "tcp_sink",
@ -219,6 +276,18 @@ func (t *TCPSink) GetStats() SinkStats {
netLimitStats = t.netLimiter.GetStats() netLimitStats = t.netLimiter.GetStats()
} }
var authStats map[string]any
if t.authenticator != nil {
authStats = t.authenticator.GetStats()
authStats["failures"] = t.authFailures.Load()
authStats["successes"] = t.authSuccesses.Load()
}
var tlsStats map[string]any
if t.tlsManager != nil {
tlsStats = t.tlsManager.GetStats()
}
return SinkStats{ return SinkStats{
Type: "tcp", Type: "tcp",
TotalProcessed: t.totalProcessed.Load(), TotalProcessed: t.totalProcessed.Load(),
@ -229,6 +298,8 @@ func (t *TCPSink) GetStats() SinkStats {
"port": t.config.Port, "port": t.config.Port,
"buffer_size": t.config.BufferSize, "buffer_size": t.config.BufferSize,
"net_limit": netLimitStats, "net_limit": netLimitStats,
"auth": authStats,
"tls": tlsStats,
}, },
} }
} }
@ -263,11 +334,14 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
continue continue
} }
t.server.connections.Range(func(key, value any) bool { // Broadcast only to authenticated clients
conn := key.(gnet.Conn) t.server.mu.RLock()
conn.AsyncWrite(data, nil) for conn, client := range t.server.clients {
return true if client.authenticated {
}) conn.AsyncWrite(data, nil)
}
}
t.server.mu.RUnlock()
case <-tickerChan: case <-tickerChan:
heartbeatEntry := t.createHeartbeatEntry() heartbeatEntry := t.createHeartbeatEntry()
@ -279,11 +353,21 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
continue continue
} }
t.server.connections.Range(func(key, value any) bool { t.server.mu.RLock()
conn := key.(gnet.Conn) for conn, client := range t.server.clients {
conn.AsyncWrite(data, nil) if client.authenticated {
return true // Validate session is still active
}) if t.authenticator != nil && client.session != nil {
if !t.authenticator.ValidateSession(client.session.ID) {
// Session expired, close connection
conn.Close()
continue
}
}
conn.AsyncWrite(data, nil)
}
}
t.server.mu.RUnlock()
case <-t.done: case <-t.done:
return return
@ -320,11 +404,21 @@ func (t *TCPSink) GetActiveConnections() int64 {
return t.activeConns.Load() return t.activeConns.Load()
} }
// tcpServer handles gnet events // tcpClient represents a connected TCP client with auth state
type tcpClient struct {
conn gnet.Conn
buffer bytes.Buffer
authenticated bool
session *auth.Session
authTimeout time.Time
}
// tcpServer handles gnet events with authentication
type tcpServer struct { type tcpServer struct {
gnet.BuiltinEventEngine gnet.BuiltinEventEngine
sink *TCPSink sink *TCPSink
connections sync.Map clients map[gnet.Conn]*tcpClient
mu sync.RWMutex
} }
func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action { func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action {
@ -343,9 +437,17 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
remoteAddr := c.RemoteAddr().String() remoteAddr := c.RemoteAddr().String()
s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr) s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr)
// Check IP access control first
if s.sink.ipChecker != nil {
if !s.sink.ipChecker.IsAllowed(c.RemoteAddr()) {
s.sink.logger.Warn("msg", "TCP connection denied by IP filter",
"remote_addr", remoteAddr)
return nil, gnet.Close
}
}
// Check net limit // Check net limit
if s.sink.netLimiter != nil { if s.sink.netLimiter != nil {
// Parse the remote address to get proper net.Addr
remoteStr := c.RemoteAddr().String() remoteStr := c.RemoteAddr().String()
tcpAddr, err := net.ResolveTCPAddr("tcp", remoteStr) tcpAddr, err := net.ResolveTCPAddr("tcp", remoteStr)
if err != nil { if err != nil {
@ -358,7 +460,6 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
if !s.sink.netLimiter.CheckTCP(tcpAddr) { if !s.sink.netLimiter.CheckTCP(tcpAddr) {
s.sink.logger.Warn("msg", "TCP connection net limited", s.sink.logger.Warn("msg", "TCP connection net limited",
"remote_addr", remoteAddr) "remote_addr", remoteAddr)
// Silently close connection when net limited
return nil, gnet.Close return nil, gnet.Close
} }
@ -366,24 +467,43 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
s.sink.netLimiter.AddConnection(remoteStr) s.sink.netLimiter.AddConnection(remoteStr)
} }
s.connections.Store(c, struct{}{}) // Create client state
client := &tcpClient{
conn: c,
authenticated: s.sink.authenticator == nil, // No auth = auto authenticated
authTimeout: time.Now().Add(30 * time.Second), // 30s to authenticate
}
s.mu.Lock()
s.clients[c] = client
s.mu.Unlock()
newCount := s.sink.activeConns.Add(1) newCount := s.sink.activeConns.Add(1)
s.sink.logger.Debug("msg", "TCP connection opened", s.sink.logger.Debug("msg", "TCP connection opened",
"remote_addr", remoteAddr, "remote_addr", remoteAddr,
"active_connections", newCount) "active_connections", newCount,
"requires_auth", s.sink.authenticator != nil)
// Send auth prompt if authentication is required
if s.sink.authenticator != nil {
authPrompt := []byte("AUTH REQUIRED\nFormat: AUTH <method> <credentials>\nMethods: basic, token\n")
return authPrompt, gnet.None
}
return nil, gnet.None return nil, gnet.None
} }
func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action { func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
s.connections.Delete(c)
remoteAddr := c.RemoteAddr().String() remoteAddr := c.RemoteAddr().String()
// Remove client state
s.mu.Lock()
delete(s.clients, c)
s.mu.Unlock()
// Remove connection tracking // Remove connection tracking
if s.sink.netLimiter != nil { if s.sink.netLimiter != nil {
s.sink.netLimiter.RemoveConnection(c.RemoteAddr().String()) s.sink.netLimiter.RemoveConnection(remoteAddr)
} }
newCount := s.sink.activeConns.Add(-1) newCount := s.sink.activeConns.Add(-1)
@ -395,7 +515,114 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
} }
func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action { func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
// We don't expect input from clients, just discard s.mu.RLock()
client, exists := s.clients[c]
s.mu.RUnlock()
if !exists {
return gnet.Close
}
// Check auth timeout
if !client.authenticated && time.Now().After(client.authTimeout) {
s.sink.logger.Warn("msg", "Authentication timeout",
"remote_addr", c.RemoteAddr().String())
c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil)
return gnet.Close
}
// Read all available data
data, err := c.Next(-1)
if err != nil {
s.sink.logger.Error("msg", "Error reading from connection",
"component", "tcp_sink",
"error", err)
return gnet.Close
}
// If not authenticated, expect auth command
if !client.authenticated {
client.buffer.Write(data)
// Look for complete auth line
if line, err := client.buffer.ReadBytes('\n'); err == nil {
line = bytes.TrimSpace(line)
// Parse AUTH command: AUTH <method> <credentials>
parts := strings.SplitN(string(line), " ", 3)
if len(parts) != 3 || parts[0] != "AUTH" {
c.AsyncWrite([]byte("ERROR: Invalid auth format\n"), nil)
return gnet.None
}
// Authenticate
session, err := s.sink.authenticator.AuthenticateTCP(parts[1], parts[2], c.RemoteAddr().String())
if err != nil {
s.sink.authFailures.Add(1)
s.sink.logger.Warn("msg", "TCP authentication failed",
"remote_addr", c.RemoteAddr().String(),
"method", parts[1],
"error", err)
c.AsyncWrite([]byte(fmt.Sprintf("AUTH FAILED: %v\n", err)), nil)
return gnet.Close
}
// Authentication successful
s.sink.authSuccesses.Add(1)
s.mu.Lock()
client.authenticated = true
client.session = session
s.mu.Unlock()
s.sink.logger.Info("msg", "TCP client authenticated",
"remote_addr", c.RemoteAddr().String(),
"username", session.Username,
"method", session.Method)
c.AsyncWrite([]byte("AUTH OK\n"), nil)
// Clear buffer after auth
client.buffer.Reset()
}
return gnet.None
}
// Authenticated clients shouldn't send data, just discard
c.Discard(-1) c.Discard(-1)
return gnet.None return gnet.None
} }
// SetAuthConfig configures tcp sink authentication
func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" {
return
}
t.authConfig = authCfg
authenticator, err := auth.New(authCfg, t.logger)
if err != nil {
t.logger.Error("msg", "Failed to initialize authenticator for TCP sink",
"component", "tcp_sink",
"error", err)
return
}
t.authenticator = authenticator
// Initialize TLS manager if SSL is configured
if t.config.SSL != nil && t.config.SSL.Enabled {
tlsManager, err := tls.New(t.config.SSL, t.logger)
if err != nil {
t.logger.Error("msg", "Failed to create TLS manager",
"component", "tcp_sink",
"error", err)
// Continue without TLS
return
}
t.tlsManager = tlsManager
}
t.logger.Info("msg", "Authentication configured for TCP sink",
"component", "tcp_sink",
"auth_type", authCfg.Type,
"tls_enabled", t.tlsManager != nil)
}

249
src/internal/tls/manager.go Normal file
View File

@ -0,0 +1,249 @@
// FILE: logwisp/src/internal/tls/manager.go
package tls
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"strings"
"logwisp/src/internal/config"
"github.com/lixenwraith/log"
)
// Manager handles TLS configuration for servers
type Manager struct {
config *config.SSLConfig
tlsConfig *tls.Config
logger *log.Logger
}
// New creates a TLS configuration from SSL config
func New(cfg *config.SSLConfig, logger *log.Logger) (*Manager, error) {
if cfg == nil || !cfg.Enabled {
return nil, nil
}
m := &Manager{
config: cfg,
logger: logger,
}
// Load certificate and key
cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load cert/key: %w", err)
}
// Create base TLS config
m.tlsConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: parseTLSVersion(cfg.MinVersion, tls.VersionTLS12),
MaxVersion: parseTLSVersion(cfg.MaxVersion, tls.VersionTLS13),
}
// Configure cipher suites if specified
if cfg.CipherSuites != "" {
m.tlsConfig.CipherSuites = parseCipherSuites(cfg.CipherSuites)
} else {
// Use secure defaults
m.tlsConfig.CipherSuites = []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
}
}
// Configure client authentication (mTLS)
if cfg.ClientAuth {
if cfg.VerifyClientCert {
m.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
} else {
m.tlsConfig.ClientAuth = tls.RequireAnyClientCert
}
// Load client CA if specified
if cfg.ClientCAFile != "" {
caCert, err := os.ReadFile(cfg.ClientCAFile)
if err != nil {
return nil, fmt.Errorf("failed to read client CA: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("failed to parse client CA certificate")
}
m.tlsConfig.ClientCAs = caCertPool
}
}
// Set secure defaults
m.tlsConfig.PreferServerCipherSuites = true
m.tlsConfig.SessionTicketsDisabled = false
m.tlsConfig.Renegotiation = tls.RenegotiateNever
logger.Info("msg", "TLS manager initialized",
"component", "tls",
"min_version", cfg.MinVersion,
"max_version", cfg.MaxVersion,
"client_auth", cfg.ClientAuth,
"cipher_count", len(m.tlsConfig.CipherSuites))
return m, nil
}
// GetConfig returns the TLS configuration
func (m *Manager) GetConfig() *tls.Config {
if m == nil {
return nil
}
// Return a clone to prevent modification
return m.tlsConfig.Clone()
}
// GetHTTPConfig returns TLS config suitable for HTTP servers
func (m *Manager) GetHTTPConfig() *tls.Config {
if m == nil {
return nil
}
cfg := m.tlsConfig.Clone()
// Enable HTTP/2
cfg.NextProtos = []string{"h2", "http/1.1"}
return cfg
}
// GetTCPConfig returns TLS config for raw TCP connections
func (m *Manager) GetTCPConfig() *tls.Config {
if m == nil {
return nil
}
cfg := m.tlsConfig.Clone()
// No ALPN for raw TCP
cfg.NextProtos = nil
return cfg
}
// ValidateClientCert validates a client certificate for mTLS
func (m *Manager) ValidateClientCert(rawCerts [][]byte) error {
if m == nil || !m.config.ClientAuth {
return nil
}
if len(rawCerts) == 0 {
return fmt.Errorf("no client certificate provided")
}
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return fmt.Errorf("failed to parse client certificate: %w", err)
}
// Verify against CA if configured
if m.tlsConfig.ClientCAs != nil {
opts := x509.VerifyOptions{
Roots: m.tlsConfig.ClientCAs,
Intermediates: x509.NewCertPool(),
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
// Add any intermediate certs
for i := 1; i < len(rawCerts); i++ {
intermediate, err := x509.ParseCertificate(rawCerts[i])
if err != nil {
continue
}
opts.Intermediates.AddCert(intermediate)
}
if _, err := cert.Verify(opts); err != nil {
return fmt.Errorf("client certificate verification failed: %w", err)
}
}
m.logger.Debug("msg", "Client certificate validated",
"component", "tls",
"subject", cert.Subject.String(),
"serial", cert.SerialNumber.String())
return nil
}
func parseTLSVersion(version string, defaultVersion uint16) uint16 {
switch strings.ToUpper(version) {
case "TLS1.0", "TLS10":
return tls.VersionTLS10
case "TLS1.1", "TLS11":
return tls.VersionTLS11
case "TLS1.2", "TLS12":
return tls.VersionTLS12
case "TLS1.3", "TLS13":
return tls.VersionTLS13
default:
return defaultVersion
}
}
func parseCipherSuites(suites string) []uint16 {
var result []uint16
// Map of cipher suite names to IDs
suiteMap := map[string]uint16{
// TLS 1.2 ECDHE suites (preferred)
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
// RSA suites (less preferred)
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
}
for _, suite := range strings.Split(suites, ",") {
suite = strings.TrimSpace(suite)
if id, ok := suiteMap[suite]; ok {
result = append(result, id)
}
}
return result
}
// GetStats returns TLS statistics
func (m *Manager) GetStats() map[string]any {
if m == nil {
return map[string]any{"enabled": false}
}
return map[string]any{
"enabled": true,
"min_version": tlsVersionString(m.tlsConfig.MinVersion),
"max_version": tlsVersionString(m.tlsConfig.MaxVersion),
"client_auth": m.config.ClientAuth,
"cipher_suites": len(m.tlsConfig.CipherSuites),
}
}
func tlsVersionString(version uint16) string {
switch version {
case tls.VersionTLS10:
return "TLS1.0"
case tls.VersionTLS11:
return "TLS1.1"
case tls.VersionTLS12:
return "TLS1.2"
case tls.VersionTLS13:
return "TLS1.3"
default:
return fmt.Sprintf("0x%04x", version)
}
}