diff --git a/go.mod b/go.mod index 25600e0..5cb7cfb 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,12 @@ module logwisp go 1.25.1 require ( - github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 + github.com/lixenwraith/config v0.0.0-20251003140149-580459b815f6 github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2 github.com/panjf2000/gnet/v2 v2.9.4 - github.com/stretchr/testify v1.10.0 - github.com/valyala/fasthttp v1.66.0 - golang.org/x/crypto v0.42.0 - golang.org/x/term v0.35.0 - golang.org/x/time v0.13.0 + github.com/valyala/fasthttp v1.67.0 + golang.org/x/crypto v0.43.0 + golang.org/x/term v0.36.0 ) require ( @@ -20,12 +18,11 @@ require ( github.com/klauspost/compress v1.18.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/panjf2000/ants/v2 v2.11.3 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect golang.org/x/sync v0.17.0 // indirect - golang.org/x/sys v0.36.0 // indirect + golang.org/x/sys v0.37.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d7220f2..ec1d8c2 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGL github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0= -github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0= +github.com/lixenwraith/config v0.0.0-20251003140149-580459b815f6 h1:G9qP8biXBT6bwBOjEe1tZwjA0gPuB5DC+fLBRXDNXqo= +github.com/lixenwraith/config v0.0.0-20251003140149-580459b815f6/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0= github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2 h1:9Qf+BR83sKjok2E1Nct+3Sfzoj2dLGwC/zyQDVNmmqs= github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0= github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= @@ -22,8 +22,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.66.0 h1:M87A0Z7EayeyNaV6pfO3tUTUiYO0dZfEJnRGXTVNuyU= -github.com/valyala/fasthttp v1.66.0/go.mod h1:Y4eC+zwoocmXSVCB1JmhNbYtS7tZPRI2ztPB72EVObs= +github.com/valyala/fasthttp v1.67.0 h1:tqKlJMUP6iuNG8hGjK/s9J4kadH7HLV4ijEcPGsezac= +github.com/valyala/fasthttp v1.67.0/go.mod h1:qYSIpqt/0XNmShgo/8Aq8E3UYWVVwNS2QYmzd8WIEPM= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -32,16 +32,14 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= -golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= +golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= +golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= -golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= -golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= -golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI= -golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q= +golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/src/cmd/logwisp/bootstrap.go b/src/cmd/logwisp/bootstrap.go index e81dcea..5404f55 100644 --- a/src/cmd/logwisp/bootstrap.go +++ b/src/cmd/logwisp/bootstrap.go @@ -24,7 +24,7 @@ func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service logger.Info("msg", "Initializing pipeline", "pipeline", pipelineCfg.Name) // Create the pipeline - if err := svc.NewPipeline(pipelineCfg); err != nil { + if err := svc.NewPipeline(&pipelineCfg); err != nil { logger.Error("msg", "Failed to create pipeline", "pipeline", pipelineCfg.Name, "error", err) diff --git a/src/cmd/logwisp/commands.go b/src/cmd/logwisp/commands.go deleted file mode 100644 index d0898e4..0000000 --- a/src/cmd/logwisp/commands.go +++ /dev/null @@ -1,138 +0,0 @@ -// FILE: src/cmd/logwisp/commands.go -package main - -import ( - "fmt" - "os" - - "logwisp/src/internal/auth" - "logwisp/src/internal/tls" - "logwisp/src/internal/version" -) - -// Handles subcommand routing before main app initialization -type CommandRouter struct { - commands map[string]CommandHandler -} - -// Defines the interface for subcommands -type CommandHandler interface { - Execute(args []string) error - Description() string -} - -// Creates and initializes the command router -func NewCommandRouter() *CommandRouter { - router := &CommandRouter{ - commands: make(map[string]CommandHandler), - } - - // Register available commands - router.commands["auth"] = &authCommand{} - router.commands["version"] = &versionCommand{} - router.commands["help"] = &helpCommand{} - router.commands["tls"] = &tlsCommand{} - - return router -} - -// Checks for and executes subcommands -func (r *CommandRouter) Route(args []string) error { - if len(args) < 1 { - return nil - } - - // Check for help flags anywhere in args - for _, arg := range args[1:] { // Skip program name - if arg == "-h" || arg == "--help" || arg == "help" { - // Show main help and exit regardless of other flags - r.commands["help"].Execute(nil) - os.Exit(0) - } - } - - // Check for commands - if len(args) > 1 { - cmdName := args[1] - - if handler, exists := r.commands[cmdName]; exists { - if err := handler.Execute(args[2:]); err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } - os.Exit(0) - } - - // Check if it looks like a mistyped command (not a flag) - if cmdName[0] != '-' { - fmt.Fprintf(os.Stderr, "Unknown command: %s\n", cmdName) - fmt.Fprintln(os.Stderr, "\nAvailable commands:") - r.ShowCommands() - os.Exit(1) - } - } - - return nil -} - -// Displays available subcommands -func (r *CommandRouter) ShowCommands() { - fmt.Fprintln(os.Stderr, " auth Generate authentication credentials") - fmt.Fprintln(os.Stderr, " tls Generate TLS certificates") - fmt.Fprintln(os.Stderr, " version Show version information") - fmt.Fprintln(os.Stderr, " help Display help information") - fmt.Fprintln(os.Stderr, "\nUse 'logwisp --help' for command-specific help") -} - -// TODO: Future: refactor with a new command interface -type helpCommand struct{} - -func (c *helpCommand) Execute(args []string) error { - // Check if help is requested for a specific command - if len(args) > 0 { - // TODO: Future: show command-specific help - // For now, just show general help - } - fmt.Print(helpText) - return nil -} - -func (c *helpCommand) Description() string { - return "Display help information" -} - -// authCommand wrapper -type authCommand struct{} - -func (c *authCommand) Execute(args []string) error { - gen := auth.NewAuthGeneratorCommand() - return gen.Execute(args) -} - -func (c *authCommand) Description() string { - return "Generate authentication credentials (passwords, tokens)" -} - -// versionCommand wrapper -type versionCommand struct{} - -func (c *versionCommand) Execute(args []string) error { - fmt.Println(version.String()) - return nil -} - -func (c *versionCommand) Description() string { - return "Show version information" -} - -// tlsCommand wrapper -type tlsCommand struct{} - -func (c *tlsCommand) Execute(args []string) error { - gen := tls.NewCertGeneratorCommand() - return gen.Execute(args) -} - -func (c *tlsCommand) Description() string { - return "Generate TLS certificates (CA, server, client)" -} \ No newline at end of file diff --git a/src/cmd/logwisp/commands/auth.go b/src/cmd/logwisp/commands/auth.go new file mode 100644 index 0000000..d7b9c05 --- /dev/null +++ b/src/cmd/logwisp/commands/auth.go @@ -0,0 +1,355 @@ +// FILE: src/cmd/logwisp/commands/auth.go +package commands + +import ( + "crypto/rand" + "encoding/base64" + "flag" + "fmt" + "io" + "os" + "strings" + "syscall" + + "logwisp/src/internal/auth" + "logwisp/src/internal/core" + + "golang.org/x/term" +) + +type AuthCommand struct { + output io.Writer + errOut io.Writer +} + +func NewAuthCommand() *AuthCommand { + return &AuthCommand{ + output: os.Stdout, + errOut: os.Stderr, + } +} + +func (ac *AuthCommand) Execute(args []string) error { + cmd := flag.NewFlagSet("auth", flag.ContinueOnError) + cmd.SetOutput(ac.errOut) + + var ( + // User credentials + username = cmd.String("u", "", "Username") + usernameLong = cmd.String("user", "", "Username") + password = cmd.String("p", "", "Password (will prompt if not provided)") + passwordLong = cmd.String("password", "", "Password (will prompt if not provided)") + + // Auth type selection (multiple ways to specify) + authType = cmd.String("t", "", "Auth type: basic, scram, or token") + authTypeLong = cmd.String("type", "", "Auth type: basic, scram, or token") + useScram = cmd.Bool("s", false, "Generate SCRAM credentials (TCP)") + useScramLong = cmd.Bool("scram", false, "Generate SCRAM credentials (TCP)") + useBasic = cmd.Bool("b", false, "Generate basic auth credentials (HTTP)") + useBasicLong = cmd.Bool("basic", false, "Generate basic auth credentials (HTTP)") + + // Token generation + genToken = cmd.Bool("k", false, "Generate random bearer token") + genTokenLong = cmd.Bool("token", false, "Generate random bearer token") + tokenLen = cmd.Int("l", 32, "Token length in bytes") + tokenLenLong = cmd.Int("length", 32, "Token length in bytes") + + // Migration option + migrate = cmd.Bool("m", false, "Convert basic auth PHC to SCRAM") + migrateLong = cmd.Bool("migrate", false, "Convert basic auth PHC to SCRAM") + phcHash = cmd.String("phc", "", "PHC hash to migrate (required with --migrate)") + ) + + cmd.Usage = func() { + fmt.Fprintln(ac.errOut, "Generate authentication credentials for LogWisp") + fmt.Fprintln(ac.errOut, "\nUsage: logwisp auth [options]") + fmt.Fprintln(ac.errOut, "\nExamples:") + fmt.Fprintln(ac.errOut, " # Generate basic auth hash for HTTP sources/sinks") + fmt.Fprintln(ac.errOut, " logwisp auth -u admin -b") + fmt.Fprintln(ac.errOut, " logwisp auth --user=admin --basic") + fmt.Fprintln(ac.errOut, " ") + fmt.Fprintln(ac.errOut, " # Generate SCRAM credentials for TCP") + fmt.Fprintln(ac.errOut, " logwisp auth -u tcpuser -s") + fmt.Fprintln(ac.errOut, " logwisp auth --user=tcpuser --scram") + fmt.Fprintln(ac.errOut, " ") + fmt.Fprintln(ac.errOut, " # Generate bearer token") + fmt.Fprintln(ac.errOut, " logwisp auth -k -l 64") + fmt.Fprintln(ac.errOut, " logwisp auth --token --length=64") + fmt.Fprintln(ac.errOut, "\nOptions:") + cmd.PrintDefaults() + } + + if err := cmd.Parse(args); err != nil { + return err + } + + // Check for unparsed arguments + if cmd.NArg() > 0 { + return fmt.Errorf("unexpected argument(s): %s", strings.Join(cmd.Args(), " ")) + } + + // Merge short and long form values + finalUsername := coalesceString(*username, *usernameLong) + finalPassword := coalesceString(*password, *passwordLong) + finalAuthType := coalesceString(*authType, *authTypeLong) + finalGenToken := coalesceBool(*genToken, *genTokenLong) + finalTokenLen := coalesceInt(*tokenLen, *tokenLenLong, core.DefaultTokenLength) + finalUseScram := coalesceBool(*useScram, *useScramLong) + finalUseBasic := coalesceBool(*useBasic, *useBasicLong) + finalMigrate := coalesceBool(*migrate, *migrateLong) + + // Handle migration mode + if finalMigrate { + if *phcHash == "" || finalUsername == "" || finalPassword == "" { + return fmt.Errorf("--migrate requires --user, --password, and --phc flags") + } + return ac.migrateToScram(finalUsername, finalPassword, *phcHash) + } + + // Determine auth type from flags + if finalGenToken || finalAuthType == "token" { + return ac.generateToken(finalTokenLen) + } + + // Determine credential type + credType := "basic" // default + + // Check explicit type flags + if finalUseScram || finalAuthType == "scram" { + credType = "scram" + } else if finalUseBasic || finalAuthType == "basic" { + credType = "basic" + } else if finalAuthType != "" { + return fmt.Errorf("invalid auth type: %s (valid: basic, scram, token)", finalAuthType) + } + + // Username required for password-based auth + if finalUsername == "" { + cmd.Usage() + return fmt.Errorf("username required for %s auth generation", credType) + } + + return ac.generatePasswordHash(finalUsername, finalPassword, credType) +} + +func (ac *AuthCommand) Description() string { + return "Generate authentication credentials (passwords, tokens, SCRAM)" +} + +func (ac *AuthCommand) Help() string { + return `Auth Command - Generate authentication credentials for LogWisp + +Usage: + logwisp auth [options] + +Authentication Types: + HTTP/HTTPS Sources & Sinks (TLS required): + - Basic Auth: Username/password with Argon2id hashing + - Bearer Token: Random cryptographic tokens + + TCP Sources & Sinks (No TLS): + - SCRAM: Argon2-SCRAM-SHA256 for plaintext connections + +Options: + -u, --user Username for credential generation + -p, --password Password (will prompt if not provided) + -t, --type Auth type: "basic", "scram", or "token" + -b, --basic Generate basic auth credentials (HTTP/HTTPS) + -s, --scram Generate SCRAM credentials (TCP) + -k, --token Generate random bearer token + -l, --length Token length in bytes (default: 32) + +Examples: +Examples: + # Generate basic auth hash for HTTP/HTTPS (with TLS) + logwisp auth -u admin -b + logwisp auth --user=admin --basic + + # Generate SCRAM credentials for TCP (without TLS) + logwisp auth -u tcpuser -s + logwisp auth --user=tcpuser --type=scram + + # Generate 64-byte bearer token + logwisp auth -k -l 64 + logwisp auth --token --length=64 + + # Convert existing basic auth to SCRAM (HTTPS to TCP conversion) + logwisp auth -u admin -m --phc='$argon2id$v=19$m=65536...' --password='secret' + +Output: + The command outputs configuration snippets ready to paste into logwisp.toml + and the raw credential values for external auth files. + +Security Notes: + - Basic auth and tokens require TLS encryption for HTTP connections + - SCRAM provides authentication but NOT encryption for TCP connections + - Use strong passwords (12+ characters with mixed case, numbers, symbols) + - Store credentials securely and never commit them to version control +` +} + +func (ac *AuthCommand) generatePasswordHash(username, password, credType string) error { + // Get password if not provided + if password == "" { + var err error + password, err = ac.promptForPassword() + if err != nil { + return err + } + } + + switch credType { + case "basic": + return ac.generateBasicAuth(username, password) + case "scram": + return ac.generateScramAuth(username, password) + default: + return fmt.Errorf("invalid credential type: %s", credType) + } +} + +// promptForPassword handles password prompting with confirmation +func (ac *AuthCommand) promptForPassword() (string, error) { + pass1 := ac.promptPassword("Enter password: ") + pass2 := ac.promptPassword("Confirm password: ") + if pass1 != pass2 { + return "", fmt.Errorf("passwords don't match") + } + return pass1, nil +} + +func (ac *AuthCommand) promptPassword(prompt string) string { + fmt.Fprint(ac.errOut, prompt) + password, err := term.ReadPassword(syscall.Stdin) + fmt.Fprintln(ac.errOut) + if err != nil { + fmt.Fprintf(ac.errOut, "Failed to read password: %v\n", err) + os.Exit(1) + } + return string(password) +} + +// generateBasicAuth creates Argon2id hash for HTTP basic auth +func (ac *AuthCommand) generateBasicAuth(username, password string) error { + // Generate salt + salt := make([]byte, core.Argon2SaltLen) + if _, err := rand.Read(salt); err != nil { + return fmt.Errorf("failed to generate salt: %w", err) + } + + // Generate Argon2id hash + cred, err := auth.DeriveCredential(username, password, salt, + core.Argon2Time, core.Argon2Memory, core.Argon2Threads) + if err != nil { + return fmt.Errorf("failed to derive credential: %w", err) + } + + // Output configuration snippets + fmt.Fprintln(ac.output, "\n# Basic Auth Configuration (HTTP sources/sinks)") + fmt.Fprintln(ac.output, "# REQUIRES HTTPS/TLS for security") + fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:") + fmt.Fprintln(ac.output, "") + fmt.Fprintln(ac.output, "[pipelines.auth]") + fmt.Fprintln(ac.output, `type = "basic"`) + fmt.Fprintln(ac.output, "") + fmt.Fprintln(ac.output, "[[pipelines.auth.basic_auth.users]]") + fmt.Fprintf(ac.output, "username = %q\n", username) + fmt.Fprintf(ac.output, "password_hash = %q\n\n", cred.PHCHash) + + fmt.Fprintln(ac.output, "# For external users file:") + fmt.Fprintf(ac.output, "%s:%s\n", username, cred.PHCHash) + + return nil +} + +// generateScramAuth creates Argon2id-SCRAM-SHA256 credentials for TCP +func (ac *AuthCommand) generateScramAuth(username, password string) error { + // Generate salt + salt := make([]byte, core.Argon2SaltLen) + if _, err := rand.Read(salt); err != nil { + return fmt.Errorf("failed to generate salt: %w", err) + } + + // Use internal auth package to derive SCRAM credentials + cred, err := auth.DeriveCredential(username, password, salt, + core.Argon2Time, core.Argon2Memory, core.Argon2Threads) + if err != nil { + return fmt.Errorf("failed to derive SCRAM credential: %w", err) + } + + // Output SCRAM configuration + fmt.Fprintln(ac.output, "\n# SCRAM Auth Configuration (TCP sources/sinks)") + fmt.Fprintln(ac.output, "# Provides authentication but NOT encryption") + fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:") + fmt.Fprintln(ac.output, "") + fmt.Fprintln(ac.output, "[pipelines.auth]") + fmt.Fprintln(ac.output, `type = "scram"`) + fmt.Fprintln(ac.output, "") + fmt.Fprintln(ac.output, "[[pipelines.auth.scram_auth.users]]") + fmt.Fprintf(ac.output, "username = %q\n", username) + fmt.Fprintf(ac.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey)) + fmt.Fprintf(ac.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey)) + fmt.Fprintf(ac.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt)) + fmt.Fprintf(ac.output, "argon_time = %d\n", cred.ArgonTime) + fmt.Fprintf(ac.output, "argon_memory = %d\n", cred.ArgonMemory) + fmt.Fprintf(ac.output, "argon_threads = %d\n\n", cred.ArgonThreads) + + fmt.Fprintln(ac.output, "# Note: SCRAM provides authentication only.") + fmt.Fprintln(ac.output, "# Use TLS/mTLS for encryption if needed.") + + return nil +} + +func (ac *AuthCommand) generateToken(length int) error { + if length < 16 { + fmt.Fprintln(ac.errOut, "Warning: tokens < 16 bytes are cryptographically weak") + } + if length > 512 { + return fmt.Errorf("token length exceeds maximum (512 bytes)") + } + + token := make([]byte, length) + if _, err := rand.Read(token); err != nil { + return fmt.Errorf("failed to generate random bytes: %w", err) + } + + b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token) + hex := fmt.Sprintf("%x", token) + + fmt.Fprintln(ac.output, "\n# Token Configuration") + fmt.Fprintln(ac.output, "# Add to logwisp.toml:") + fmt.Fprintf(ac.output, "tokens = [%q]\n\n", b64) + + fmt.Fprintln(ac.output, "# Generated Token:") + fmt.Fprintf(ac.output, "Base64: %s\n", b64) + fmt.Fprintf(ac.output, "Hex: %s\n", hex) + + return nil +} + +// migrateToScram converts basic auth PHC hash to SCRAM credentials +func (ac *AuthCommand) migrateToScram(username, password, phcHash string) error { + // CHANGED: Moved from internal/auth to CLI command layer + cred, err := auth.MigrateFromPHC(username, password, phcHash) + if err != nil { + return fmt.Errorf("migration failed: %w", err) + } + + // Output SCRAM configuration (reuse format from generateScramAuth) + fmt.Fprintln(ac.output, "\n# Migrated SCRAM Credentials") + fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:") + fmt.Fprintln(ac.output, "") + fmt.Fprintln(ac.output, "[pipelines.auth]") + fmt.Fprintln(ac.output, `type = "scram"`) + fmt.Fprintln(ac.output, "") + fmt.Fprintln(ac.output, "[[pipelines.auth.scram_auth.users]]") + fmt.Fprintf(ac.output, "username = %q\n", username) + fmt.Fprintf(ac.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey)) + fmt.Fprintf(ac.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey)) + fmt.Fprintf(ac.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt)) + fmt.Fprintf(ac.output, "argon_time = %d\n", cred.ArgonTime) + fmt.Fprintf(ac.output, "argon_memory = %d\n", cred.ArgonMemory) + fmt.Fprintf(ac.output, "argon_threads = %d\n", cred.ArgonThreads) + + return nil +} \ No newline at end of file diff --git a/src/cmd/logwisp/commands/help.go b/src/cmd/logwisp/commands/help.go new file mode 100644 index 0000000..26bac8f --- /dev/null +++ b/src/cmd/logwisp/commands/help.go @@ -0,0 +1,123 @@ +// FILE: src/cmd/logwisp/commands/help.go +package commands + +import ( + "fmt" + "sort" + "strings" +) + +const generalHelpTemplate = `LogWisp: A flexible log transport and processing tool. + +Usage: + logwisp [command] [options] + logwisp [options] + +Commands: +%s + +Application Options: + -c, --config Path to configuration file (default: logwisp.toml) + -h, --help Display this help message and exit + -v, --version Display version information and exit + -b, --background Run LogWisp in the background as a daemon + -q, --quiet Suppress all console output, including errors + +Runtime Options: + --disable-status-reporter Disable the periodic status reporter + --config-auto-reload Enable config reload on file change + +For command-specific help: + logwisp help + logwisp --help + +Configuration Sources (Precedence: CLI > Env > File > Defaults): + - CLI flags override all other settings + - Environment variables override file settings + - TOML configuration file is the primary method + +Examples: + # Generate password for admin user + logwisp auth -u admin + + # Start service with custom config + logwisp -c /etc/logwisp/prod.toml + + # Run in background with config reload + logwisp -b --config-auto-reload + +For detailed configuration options, please refer to the documentation. +` + +// HelpCommand handles help display +type HelpCommand struct { + router *CommandRouter +} + +// NewHelpCommand creates a new help command +func NewHelpCommand(router *CommandRouter) *HelpCommand { + return &HelpCommand{router: router} +} + +// Execute displays help information +func (c *HelpCommand) Execute(args []string) error { + // Check if help is requested for a specific command + if len(args) > 0 && args[0] != "" { + cmdName := args[0] + + if handler, exists := c.router.GetCommand(cmdName); exists { + fmt.Print(handler.Help()) + return nil + } + + return fmt.Errorf("unknown command: %s", cmdName) + } + + // Display general help with command list + fmt.Printf(generalHelpTemplate, c.formatCommandList()) + return nil +} + +// formatCommandList creates a formatted list of available commands +func (c *HelpCommand) formatCommandList() string { + commands := c.router.GetCommands() + + // Sort command names for consistent output + names := make([]string, 0, len(commands)) + maxLen := 0 + for name := range commands { + names = append(names, name) + if len(name) > maxLen { + maxLen = len(name) + } + } + sort.Strings(names) + + // Format each command with aligned descriptions + var lines []string + for _, name := range names { + handler := commands[name] + padding := strings.Repeat(" ", maxLen-len(name)+2) + lines = append(lines, fmt.Sprintf(" %s%s%s", name, padding, handler.Description())) + } + + return strings.Join(lines, "\n") +} + +func (c *HelpCommand) Description() string { + return "Display help information" +} + +func (c *HelpCommand) Help() string { + return `Help Command - Display help information + +Usage: + logwisp help Show general help + logwisp help Show help for a specific command + +Examples: + logwisp help # Show general help + logwisp help auth # Show auth command help + logwisp auth --help # Alternative way to get command help +` +} \ No newline at end of file diff --git a/src/cmd/logwisp/commands/router.go b/src/cmd/logwisp/commands/router.go new file mode 100644 index 0000000..a68df00 --- /dev/null +++ b/src/cmd/logwisp/commands/router.go @@ -0,0 +1,118 @@ +// FILE: src/cmd/logwisp/commands/router.go +package commands + +import ( + "fmt" + "os" +) + +// Handler defines the interface for subcommands +type Handler interface { + Execute(args []string) error + Description() string + Help() string +} + +// CommandRouter handles subcommand routing before main app initialization +type CommandRouter struct { + commands map[string]Handler +} + +// NewCommandRouter creates and initializes the command router +func NewCommandRouter() *CommandRouter { + router := &CommandRouter{ + commands: make(map[string]Handler), + } + + // Register available commands + router.commands["auth"] = NewAuthCommand() + router.commands["tls"] = NewTLSCommand() + router.commands["version"] = NewVersionCommand() + router.commands["help"] = NewHelpCommand(router) + + return router +} + +// Route checks for and executes subcommands +func (r *CommandRouter) Route(args []string) (bool, error) { + if len(args) < 2 { + return false, nil // No command specified, let main app continue + } + + cmdName := args[1] + + // Special case: help flag at any position shows general help + for _, arg := range args[1:] { + if arg == "-h" || arg == "--help" { + // If it's after a valid command, show command-specific help + if handler, exists := r.commands[cmdName]; exists && cmdName != "help" { + fmt.Print(handler.Help()) + return true, nil + } + // Otherwise show general help + return true, r.commands["help"].Execute(nil) + } + } + + // Check if this is a known command + handler, exists := r.commands[cmdName] + if !exists { + // Check if it looks like a mistyped command (not a flag) + if cmdName[0] != '-' { + return false, fmt.Errorf("unknown command: %s\n\nRun 'logwisp help' for usage", cmdName) + } + // It's a flag, let main app handle it + return false, nil + } + + // Execute the command + return true, handler.Execute(args[2:]) +} + +// GetCommand returns a command handler by name +func (r *CommandRouter) GetCommand(name string) (Handler, bool) { + cmd, exists := r.commands[name] + return cmd, exists +} + +// GetCommands returns all registered commands +func (r *CommandRouter) GetCommands() map[string]Handler { + return r.commands +} + +// ShowCommands displays available subcommands +func (r *CommandRouter) ShowCommands() { + for name, handler := range r.commands { + fmt.Fprintf(os.Stderr, " %-10s %s\n", name, handler.Description()) + } + fmt.Fprintln(os.Stderr, "\nUse 'logwisp --help' for command-specific help") +} + +// Helper functions to merge short and long options +func coalesceString(values ...string) string { + for _, v := range values { + if v != "" { + return v + } + } + return "" +} + +func coalesceInt(primary, secondary, defaultVal int) int { + if primary != defaultVal { + return primary + } + if secondary != defaultVal { + return secondary + } + return defaultVal +} + +func coalesceBool(values ...bool) bool { + for _, v := range values { + if v { + return true + } + } + return false +} \ No newline at end of file diff --git a/src/internal/tls/generator.go b/src/cmd/logwisp/commands/tls.go similarity index 64% rename from src/internal/tls/generator.go rename to src/cmd/logwisp/commands/tls.go index 63d91c3..562d4d0 100644 --- a/src/internal/tls/generator.go +++ b/src/cmd/logwisp/commands/tls.go @@ -1,5 +1,5 @@ -// FILE: src/internal/tls/generator.go -package tls +// FILE: src/cmd/logwisp/commands/tls.go +package commands import ( "crypto/rand" @@ -17,40 +17,50 @@ import ( "time" ) -type CertGeneratorCommand struct { +type TLSCommand struct { output io.Writer errOut io.Writer } -func NewCertGeneratorCommand() *CertGeneratorCommand { - return &CertGeneratorCommand{ +func NewTLSCommand() *TLSCommand { + return &TLSCommand{ output: os.Stdout, errOut: os.Stderr, } } -func (cg *CertGeneratorCommand) Execute(args []string) error { +func (tc *TLSCommand) Execute(args []string) error { cmd := flag.NewFlagSet("tls", flag.ContinueOnError) - cmd.SetOutput(cg.errOut) + cmd.SetOutput(tc.errOut) - // Subcommands + // Certificate type flags var ( genCA = cmd.Bool("ca", false, "Generate CA certificate") genServer = cmd.Bool("server", false, "Generate server certificate") genClient = cmd.Bool("client", false, "Generate client certificate") selfSign = cmd.Bool("self-signed", false, "Generate self-signed certificate") - // Common options + // Common options - short forms commonName = cmd.String("cn", "", "Common name (required)") - org = cmd.String("org", "LogWisp", "Organization") - country = cmd.String("country", "US", "Country code") - validDays = cmd.Int("days", 365, "Validity period in days") - keySize = cmd.Int("bits", 2048, "RSA key size") + org = cmd.String("o", "LogWisp", "Organization") + country = cmd.String("c", "US", "Country code") + validDays = cmd.Int("d", 365, "Validity period in days") + keySize = cmd.Int("b", 2048, "RSA key size") - // Server/Client specific - hosts = cmd.String("hosts", "", "Comma-separated hostnames/IPs (server cert)") - caFile = cmd.String("ca-cert", "", "CA certificate file (for signing)") - caKeyFile = cmd.String("ca-key", "", "CA key file (for signing)") + // Common options - long forms + commonNameLong = cmd.String("common-name", "", "Common name (required)") + orgLong = cmd.String("org", "LogWisp", "Organization") + countryLong = cmd.String("country", "US", "Country code") + validDaysLong = cmd.Int("days", 365, "Validity period in days") + keySizeLong = cmd.Int("bits", 2048, "RSA key size") + + // Server/Client specific - short forms + hosts = cmd.String("h", "", "Comma-separated hostnames/IPs") + caFile = cmd.String("ca-cert", "", "CA certificate file") + caKey = cmd.String("ca-key", "", "CA key file") + + // Server/Client specific - long forms + hostsLong = cmd.String("hosts", "", "Comma-separated hostnames/IPs") // Output files certOut = cmd.String("cert-out", "", "Output certificate file") @@ -58,51 +68,135 @@ func (cg *CertGeneratorCommand) Execute(args []string) error { ) cmd.Usage = func() { - fmt.Fprintln(cg.errOut, "Generate TLS certificates for LogWisp") - fmt.Fprintln(cg.errOut, "\nUsage: logwisp tls [options]") - fmt.Fprintln(cg.errOut, "\nExamples:") - fmt.Fprintln(cg.errOut, " # Generate self-signed certificate") - fmt.Fprintln(cg.errOut, " logwisp tls --self-signed --cn localhost --hosts localhost,127.0.0.1") - fmt.Fprintln(cg.errOut, " ") - fmt.Fprintln(cg.errOut, " # Generate CA certificate") - fmt.Fprintln(cg.errOut, " logwisp tls --ca --cn \"LogWisp CA\" --cert-out ca.crt --key-out ca.key") - fmt.Fprintln(cg.errOut, " ") - fmt.Fprintln(cg.errOut, " # Generate server certificate signed by CA") - fmt.Fprintln(cg.errOut, " logwisp tls --server --cn server.example.com --hosts server.example.com \\") - fmt.Fprintln(cg.errOut, " --ca-cert ca.crt --ca-key ca.key") - fmt.Fprintln(cg.errOut, "\nOptions:") + fmt.Fprintln(tc.errOut, "Generate TLS certificates for LogWisp") + fmt.Fprintln(tc.errOut, "\nUsage: logwisp tls [options]") + fmt.Fprintln(tc.errOut, "\nExamples:") + fmt.Fprintln(tc.errOut, " # Generate self-signed certificate") + fmt.Fprintln(tc.errOut, " logwisp tls --self-signed --cn localhost --hosts localhost,127.0.0.1") + fmt.Fprintln(tc.errOut, " ") + fmt.Fprintln(tc.errOut, " # Generate CA certificate") + fmt.Fprintln(tc.errOut, " logwisp tls --ca --cn \"LogWisp CA\" --cert-out ca.crt --key-out ca.key") + fmt.Fprintln(tc.errOut, " ") + fmt.Fprintln(tc.errOut, " # Generate server certificate signed by CA") + fmt.Fprintln(tc.errOut, " logwisp tls --server --cn server.example.com --hosts server.example.com \\") + fmt.Fprintln(tc.errOut, " --ca-cert ca.crt --ca-key ca.key") + fmt.Fprintln(tc.errOut, "\nOptions:") cmd.PrintDefaults() - fmt.Fprintln(cg.errOut) + fmt.Fprintln(tc.errOut) } if err := cmd.Parse(args); err != nil { return err } + // Check for unparsed arguments + if cmd.NArg() > 0 { + return fmt.Errorf("unexpected argument(s): %s", strings.Join(cmd.Args(), " ")) + } + + // Merge short and long options + finalCN := coalesceString(*commonName, *commonNameLong) + finalOrg := coalesceString(*org, *orgLong, "LogWisp") + finalCountry := coalesceString(*country, *countryLong, "US") + finalDays := coalesceInt(*validDays, *validDaysLong, 365) + finalKeySize := coalesceInt(*keySize, *keySizeLong, 2048) + finalHosts := coalesceString(*hosts, *hostsLong) + finalCAFile := *caFile // no short form + finalCAKey := *caKey // no short form + finalCertOut := *certOut // no short form + finalKeyOut := *keyOut // no short form + // Validate common name - if *commonName == "" { + if finalCN == "" { cmd.Usage() return fmt.Errorf("common name (--cn) is required") } + // Validate RSA key size + if finalKeySize != 2048 && finalKeySize != 3072 && finalKeySize != 4096 { + return fmt.Errorf("invalid key size: %d (valid: 2048, 3072, 4096)", finalKeySize) + } + // Route to appropriate generator switch { case *genCA: - return cg.generateCA(*commonName, *org, *country, *validDays, *keySize, *certOut, *keyOut) + return tc.generateCA(finalCN, finalOrg, finalCountry, finalDays, finalKeySize, finalCertOut, finalKeyOut) case *selfSign: - return cg.generateSelfSigned(*commonName, *org, *country, *hosts, *validDays, *keySize, *certOut, *keyOut) + return tc.generateSelfSigned(finalCN, finalOrg, finalCountry, finalHosts, finalDays, finalKeySize, finalCertOut, finalKeyOut) case *genServer: - return cg.generateServerCert(*commonName, *org, *country, *hosts, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut) + return tc.generateServerCert(finalCN, finalOrg, finalCountry, finalHosts, finalCAFile, finalCAKey, finalDays, finalKeySize, finalCertOut, finalKeyOut) case *genClient: - return cg.generateClientCert(*commonName, *org, *country, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut) + return tc.generateClientCert(finalCN, finalOrg, finalCountry, finalCAFile, finalCAKey, finalDays, finalKeySize, finalCertOut, finalKeyOut) default: cmd.Usage() return fmt.Errorf("specify certificate type: --ca, --self-signed, --server, or --client") } } +func (tc *TLSCommand) Description() string { + return "Generate TLS certificates (CA, server, client, self-signed)" +} + +func (tc *TLSCommand) Help() string { + return `TLS Command - Generate TLS certificates for LogWisp + +Usage: + logwisp tls [options] + +Certificate Types: + --ca Generate Certificate Authority (CA) certificate + --server Generate server certificate (requires CA or self-signed) + --client Generate client certificate (for mTLS) + --self-signed Generate self-signed certificate (single cert for testing) + +Common Options: + --cn, --common-name Common Name (required) + -o, --org Organization name (default: "LogWisp") + -c, --country Country code (default: "US") + -d, --days Validity period in days (default: 365) + -b, --bits RSA key size (default: 2048) + +Server Certificate Options: + -h, --hosts Comma-separated hostnames/IPs + Example: "localhost,10.0.0.1,example.com" + --ca-cert CA certificate file (for signing) + --ca-key CA key file (for signing) + +Output Options: + --cert-out Output certificate file (default: stdout) + --key-out Output private key file (default: stdout) + +Examples: + # Generate self-signed certificate for testing + logwisp tls --self-signed --cn localhost --hosts "localhost,127.0.0.1" \ + --cert-out server.crt --key-out server.key + + # Generate CA certificate + logwisp tls --ca --cn "LogWisp CA" --days 3650 \ + --cert-out ca.crt --key-out ca.key + + # Generate server certificate signed by CA + logwisp tls --server --cn "logwisp.example.com" \ + --hosts "logwisp.example.com,10.0.0.100" \ + --ca-cert ca.crt --ca-key ca.key \ + --cert-out server.crt --key-out server.key + + # Generate client certificate for mTLS + logwisp tls --client --cn "client1" \ + --ca-cert ca.crt --ca-key ca.key \ + --cert-out client.crt --key-out client.key + +Security Notes: + - Keep private keys secure and never share them + - Use 2048-bit RSA minimum, 3072 or 4096 for higher security + - For production, use certificates from a trusted CA + - Self-signed certificates are only for development/testing + - Rotate certificates before expiration +` +} + // Create and manage private CA -func (cg *CertGeneratorCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error { +func (tc *TLSCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error { // Generate RSA key priv, err := rsa.GenerateKey(rand.Reader, bits) if err != nil { @@ -178,7 +272,7 @@ func parseHosts(hostList string) ([]string, []net.IP) { } // Generate self-signed certificate -func (cg *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error { +func (tc *TLSCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error { // 1. Generate an RSA private key with the specified bit size priv, err := rsa.GenerateKey(rand.Reader, bits) if err != nil { @@ -245,7 +339,7 @@ func (cg *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts strin } // Generate server cert with CA -func (cg *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { +func (tc *TLSCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { caCert, caKey, err := loadCA(caFile, caKeyFile) if err != nil { return err @@ -308,7 +402,7 @@ func (cg *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFi } // Generate client cert with CA -func (cg *CertGeneratorCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { +func (tc *TLSCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { caCert, caKey, err := loadCA(caFile, caKeyFile) if err != nil { return err diff --git a/src/cmd/logwisp/commands/version.go b/src/cmd/logwisp/commands/version.go new file mode 100644 index 0000000..073e3ff --- /dev/null +++ b/src/cmd/logwisp/commands/version.go @@ -0,0 +1,41 @@ +// FILE: src/cmd/logwisp/commands/version.go +package commands + +import ( + "fmt" + + "logwisp/src/internal/version" +) + +// VersionCommand handles version display +type VersionCommand struct{} + +// NewVersionCommand creates a new version command +func NewVersionCommand() *VersionCommand { + return &VersionCommand{} +} + +func (c *VersionCommand) Execute(args []string) error { + fmt.Println(version.String()) + return nil +} + +func (c *VersionCommand) Description() string { + return "Show version information" +} + +func (c *VersionCommand) Help() string { + return `Version Command - Show LogWisp version information + +Usage: + logwisp version + logwisp -v + logwisp --version + +Output includes: + - Version number + - Build date + - Git commit hash (if available) + - Go version used for compilation +` +} \ No newline at end of file diff --git a/src/cmd/logwisp/help.go b/src/cmd/logwisp/help.go deleted file mode 100644 index 4015060..0000000 --- a/src/cmd/logwisp/help.go +++ /dev/null @@ -1,59 +0,0 @@ -// FILE: logwisp/src/cmd/logwisp/help.go -package main - -import ( - "fmt" - "os" -) - -const helpText = `LogWisp: A flexible log transport and processing tool. - -Usage: - logwisp [command] [options] - logwisp [options] - -Commands: - auth Generate authentication credentials - version Display version information - -Application Control: - -c, --config Path to configuration file (default: logwisp.toml) - -h, --help Display this help message and exit - -v, --version Display version information and exit - -b, --background Run LogWisp in the background as a daemon - -q, --quiet Suppress all console output, including errors - -Runtime Behavior: - --disable-status-reporter Disable the periodic status reporter - --config-auto-reload Enable config reload on file change - -For command-specific help: - logwisp --help - -Configuration Sources (Precedence: CLI > Env > File > Defaults): - - CLI flags override all other settings - - Environment variables override file settings - - TOML configuration file is the primary method - -Examples: - # Generate password for admin user - logwisp auth -u admin - - # Start service with custom config - logwisp -c /etc/logwisp/prod.toml - - # Run in background - logwisp -b --config-auto-reload - -For detailed configuration options, please refer to the documentation. -` - -// Scans arguments for help flags and prints help text if found. -func CheckAndDisplayHelp(args []string) { - for _, arg := range args { - if arg == "-h" || arg == "--help" { - fmt.Fprint(os.Stdout, helpText) - os.Exit(0) - } - } -} \ No newline at end of file diff --git a/src/cmd/logwisp/main.go b/src/cmd/logwisp/main.go index 107c716..58ae93c 100644 --- a/src/cmd/logwisp/main.go +++ b/src/cmd/logwisp/main.go @@ -11,6 +11,7 @@ import ( "syscall" "time" + "logwisp/src/cmd/logwisp/commands" "logwisp/src/internal/config" "logwisp/src/internal/version" @@ -22,12 +23,22 @@ var logger *log.Logger func main() { // Handle subcommands before any config loading // This prevents flag conflicts with lixenwraith/config - router := NewCommandRouter() - if router.Route(os.Args) != nil { - // Subcommand was handled, exit already called - return + router := commands.NewCommandRouter() + handled, err := router.Route(os.Args) + + if err != nil { + // Command execution error + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) } + if handled { + // Command was successfully handled + os.Exit(0) + } + + // No subcommand, continue with main application + // Emulates nohup signal.Ignore(syscall.SIGHUP) @@ -158,8 +169,6 @@ func main() { select { case <-done: - // Save configuration after graceful shutdown (no reload manager in static mode) - saveConfigurationOnExit(cfg, nil, logger) logger.Info("msg", "Shutdown complete") case <-shutdownCtx.Done(): logger.Error("msg", "Shutdown timeout exceeded - forcing exit") @@ -172,9 +181,6 @@ func main() { // Wait for context cancellation <-ctx.Done() - // Save configuration before final shutdown, handled by reloadManager - saveConfigurationOnExit(cfg, reloadManager, logger) - // Shutdown is handled by ReloadManager.Shutdown() in defer logger.Info("msg", "Shutdown complete") } @@ -186,48 +192,4 @@ func shutdownLogger() { Error("Logger shutdown error: %v\n", err) } } -} - -// Saves the configuration to file on exist -func saveConfigurationOnExit(cfg *config.Config, reloadManager *ReloadManager, logger *log.Logger) { - // Only save if explicitly enabled and we have a valid path - if !cfg.ConfigSaveOnExit || cfg.ConfigFile == "" { - return - } - - // Create a context with timeout for save operation - saveCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - // Perform save in goroutine to respect timeout - done := make(chan error, 1) - go func() { - var err error - if reloadManager != nil && reloadManager.lcfg != nil { - // Use existing lconfig instance from reload manager - // This ensures we save through the same configuration system - err = reloadManager.lcfg.Save(cfg.ConfigFile) - } else { - // Static mode: create temporary lconfig for saving - err = cfg.SaveToFile(cfg.ConfigFile) - } - done <- err - }() - - select { - case err := <-done: - if err != nil { - logger.Error("msg", "Failed to save configuration on exit", - "path", cfg.ConfigFile, - "error", err) - // Don't fail the exit on save error - } else { - logger.Info("msg", "Configuration saved successfully", - "path", cfg.ConfigFile) - } - case <-saveCtx.Done(): - logger.Error("msg", "Configuration save timeout exceeded", - "path", cfg.ConfigFile, - "timeout", "5s") - } } \ No newline at end of file diff --git a/src/cmd/logwisp/reload.go b/src/cmd/logwisp/reload.go index 9e9d1e2..9a6df05 100644 --- a/src/cmd/logwisp/reload.go +++ b/src/cmd/logwisp/reload.go @@ -338,14 +338,6 @@ func (rm *ReloadManager) stopStatusReporter() { } } -// Wrapper to save the config -func (rm *ReloadManager) SaveConfig(path string) error { - if rm.lcfg == nil { - return fmt.Errorf("no lconfig instance available") - } - return rm.lcfg.Save(path) -} - // Stops the reload manager func (rm *ReloadManager) Shutdown() { rm.logger.Info("msg", "Shutting down reload manager") diff --git a/src/cmd/logwisp/status.go b/src/cmd/logwisp/status.go index dddd136..faba2c8 100644 --- a/src/cmd/logwisp/status.go +++ b/src/cmd/logwisp/status.go @@ -114,81 +114,76 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) { for i, sinkCfg := range cfg.Sinks { switch sinkCfg.Type { case "tcp": - if port, ok := sinkCfg.Options["port"].(int64); ok { - host := "0.0.0.0" // Get host or default to 0.0.0.0 - if h, ok := sinkCfg.Options["host"].(string); ok && h != "" { - host = h + if sinkCfg.TCP != nil { + host := "0.0.0.0" + if sinkCfg.TCP.Host != "" { + host = sinkCfg.TCP.Host } logger.Info("msg", "TCP endpoint configured", "component", "main", "pipeline", cfg.Name, "sink_index", i, - "listen", fmt.Sprintf("%s:%d", host, port)) + "listen", fmt.Sprintf("%s:%d", host, sinkCfg.TCP.Port)) // Display net limit info if configured - if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok { - if enabled, ok := nl["enabled"].(bool); ok && enabled { - logger.Info("msg", "TCP net limiting enabled", - "pipeline", cfg.Name, - "sink_index", i, - "requests_per_second", nl["requests_per_second"], - "burst_size", nl["burst_size"]) - } + if sinkCfg.TCP.NetLimit != nil && sinkCfg.TCP.NetLimit.Enabled { + logger.Info("msg", "TCP net limiting enabled", + "pipeline", cfg.Name, + "sink_index", i, + "requests_per_second", sinkCfg.TCP.NetLimit.RequestsPerSecond, + "burst_size", sinkCfg.TCP.NetLimit.BurstSize) } } case "http": - if port, ok := sinkCfg.Options["port"].(int64); ok { + if sinkCfg.HTTP != nil { host := "0.0.0.0" - if h, ok := sinkCfg.Options["host"].(string); ok && h != "" { - host = h + if sinkCfg.HTTP.Host != "" { + host = sinkCfg.HTTP.Host } streamPath := "/stream" statusPath := "/status" - if path, ok := sinkCfg.Options["stream_path"].(string); ok { - streamPath = path + if sinkCfg.HTTP.StreamPath != "" { + streamPath = sinkCfg.HTTP.StreamPath } - if path, ok := sinkCfg.Options["status_path"].(string); ok { - statusPath = path + if sinkCfg.HTTP.StatusPath != "" { + statusPath = sinkCfg.HTTP.StatusPath } logger.Info("msg", "HTTP endpoints configured", "pipeline", cfg.Name, "sink_index", i, - "listen", fmt.Sprintf("%s:%d", host, port), - "stream_url", fmt.Sprintf("http://%s:%d%s", host, port, streamPath), - "status_url", fmt.Sprintf("http://%s:%d%s", host, port, statusPath)) + "listen", fmt.Sprintf("%s:%d", host, sinkCfg.HTTP.Port), + "stream_url", fmt.Sprintf("http://%s:%d%s", host, sinkCfg.HTTP.Port, streamPath), + "status_url", fmt.Sprintf("http://%s:%d%s", host, sinkCfg.HTTP.Port, statusPath)) // Display net limit info if configured - if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok { - if enabled, ok := nl["enabled"].(bool); ok && enabled { - logger.Info("msg", "HTTP net limiting enabled", - "pipeline", cfg.Name, - "sink_index", i, - "requests_per_second", nl["requests_per_second"], - "burst_size", nl["burst_size"]) - } + if sinkCfg.HTTP.NetLimit != nil && sinkCfg.HTTP.NetLimit.Enabled { + logger.Info("msg", "HTTP net limiting enabled", + "pipeline", cfg.Name, + "sink_index", i, + "requests_per_second", sinkCfg.HTTP.NetLimit.RequestsPerSecond, + "burst_size", sinkCfg.HTTP.NetLimit.BurstSize) } } case "file": - if dir, ok := sinkCfg.Options["directory"].(string); ok { - name, _ := sinkCfg.Options["name"].(string) + if sinkCfg.File != nil { logger.Info("msg", "File sink configured", "pipeline", cfg.Name, "sink_index", i, - "directory", dir, - "name", name) + "directory", sinkCfg.File.Directory, + "name", sinkCfg.File.Name) } case "console": - if target, ok := sinkCfg.Options["target"].(string); ok { + if sinkCfg.Console != nil { logger.Info("msg", "Console sink configured", "pipeline", cfg.Name, "sink_index", i, - "target", target) + "target", sinkCfg.Console.Target) } } } @@ -197,10 +192,10 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) { for i, sourceCfg := range cfg.Sources { switch sourceCfg.Type { case "http": - if port, ok := sourceCfg.Options["port"].(int64); ok { + if sourceCfg.HTTP != nil { host := "0.0.0.0" - if h, ok := sourceCfg.Options["host"].(string); ok && h != "" { - host = h + if sourceCfg.HTTP.Host != "" { + host = sourceCfg.HTTP.Host } displayHost := host @@ -209,22 +204,22 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) { } ingestPath := "/ingest" - if path, ok := sourceCfg.Options["ingest_path"].(string); ok { - ingestPath = path + if sourceCfg.HTTP.IngestPath != "" { + ingestPath = sourceCfg.HTTP.IngestPath } logger.Info("msg", "HTTP source configured", "pipeline", cfg.Name, "source_index", i, - "listen", fmt.Sprintf("%s:%d", host, port), - "ingest_url", fmt.Sprintf("http://%s:%d%s", displayHost, port, ingestPath)) + "listen", fmt.Sprintf("%s:%d", host, sourceCfg.HTTP.Port), + "ingest_url", fmt.Sprintf("http://%s:%d%s", displayHost, sourceCfg.HTTP.Port, ingestPath)) } case "tcp": - if port, ok := sourceCfg.Options["port"].(int64); ok { + if sourceCfg.TCP != nil { host := "0.0.0.0" - if h, ok := sourceCfg.Options["host"].(string); ok && h != "" { - host = h + if sourceCfg.TCP.Host != "" { + host = sourceCfg.TCP.Host } displayHost := host @@ -235,19 +230,24 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) { logger.Info("msg", "TCP source configured", "pipeline", cfg.Name, "source_index", i, - "listen", fmt.Sprintf("%s:%d", host, port), - "endpoint", fmt.Sprintf("%s:%d", displayHost, port)) + "listen", fmt.Sprintf("%s:%d", host, sourceCfg.TCP.Port), + "endpoint", fmt.Sprintf("%s:%d", displayHost, sourceCfg.TCP.Port)) } - // TODO: missing other types of source, to be added - } - } + case "directory": + if sourceCfg.Directory != nil { + logger.Info("msg", "Directory source configured", + "pipeline", cfg.Name, + "source_index", i, + "path", sourceCfg.Directory.Path, + "pattern", sourceCfg.Directory.Pattern) + } - // Display authentication information - if cfg.Auth != nil && cfg.Auth.Type != "none" { - logger.Info("msg", "Authentication enabled", - "pipeline", cfg.Name, - "auth_type", cfg.Auth.Type) + case "stdin": + logger.Info("msg", "Stdin source configured", + "pipeline", cfg.Name, + "source_index", i) + } } // Display filter information diff --git a/src/internal/auth/authenticator.go b/src/internal/auth/authenticator.go index 1815296..5195c10 100644 --- a/src/internal/auth/authenticator.go +++ b/src/internal/auth/authenticator.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "encoding/base64" "fmt" - "net" "strings" "sync" "time" @@ -13,7 +12,6 @@ import ( "logwisp/src/internal/config" "github.com/lixenwraith/log" - "golang.org/x/time/rate" ) // Prevent unbounded map growth @@ -21,66 +19,50 @@ const maxAuthTrackedIPs = 10000 // Handles all authentication methods for a pipeline type Authenticator struct { - config *config.AuthConfig - logger *log.Logger - bearerTokens map[string]bool // token -> valid - mu sync.RWMutex + config *config.ServerAuthConfig + logger *log.Logger + tokens map[string]bool // token -> valid + mu sync.RWMutex // Session tracking sessions map[string]*Session sessionMu sync.RWMutex - - // Brute-force protection - ipAuthAttempts map[string]*ipAuthState - authMu sync.RWMutex -} - -// Per-IP auth attempt tracking -type ipAuthState struct { - limiter *rate.Limiter - failCount int - lastAttempt time.Time - blockedUntil time.Time } // Represents an authenticated connection type Session struct { ID string Username string - Method string // basic, bearer, mtls + Method string // basic, token, mtls RemoteAddr string CreatedAt time.Time LastActivity time.Time } // Creates a new authenticator from config -func NewAuthenticator(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) { +func NewAuthenticator(cfg *config.ServerAuthConfig, logger *log.Logger) (*Authenticator, error) { // SCRAM is handled by ScramManager in sources if cfg == nil || cfg.Type == "none" || cfg.Type == "scram" { return nil, nil } a := &Authenticator{ - config: cfg, - logger: logger, - bearerTokens: make(map[string]bool), - sessions: make(map[string]*Session), - ipAuthAttempts: make(map[string]*ipAuthState), + config: cfg, + logger: logger, + tokens: make(map[string]bool), + sessions: make(map[string]*Session), } - // Initialize Bearer tokens - if cfg.Type == "bearer" && cfg.BearerAuth != nil { - for _, token := range cfg.BearerAuth.Tokens { - a.bearerTokens[token] = true + // Initialize tokens + if cfg.Type == "token" && cfg.Token != nil { + for _, token := range cfg.Token.Tokens { + a.tokens[token] = true } } // Start session cleanup go a.sessionCleanup() - // Start auth attempt cleanup - go a.authAttemptCleanup() - logger.Info("msg", "Authenticator initialized", "component", "auth", "type", cfg.Type) @@ -88,129 +70,6 @@ func NewAuthenticator(cfg *config.AuthConfig, logger *log.Logger) (*Authenticato return a, nil } -// Check and enforce rate limits -func (a *Authenticator) checkRateLimit(remoteAddr string) error { - ip, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - ip = remoteAddr // Fallback for malformed addresses - } - - a.authMu.Lock() - defer a.authMu.Unlock() - - state, exists := a.ipAuthAttempts[ip] - now := time.Now() - - if !exists { - // Check map size limit before creating new entry - if len(a.ipAuthAttempts) >= maxAuthTrackedIPs { - // Evict an old entry using simplified LRU - // Sample 20 random entries and evict the oldest - const sampleSize = 20 - var oldestIP string - oldestTime := now - - // Build sample - sampled := 0 - for sampledIP, sampledState := range a.ipAuthAttempts { - if sampledState.lastAttempt.Before(oldestTime) { - oldestIP = sampledIP - oldestTime = sampledState.lastAttempt - } - sampled++ - if sampled >= sampleSize { - break - } - } - - // Evict the oldest from our sample - if oldestIP != "" { - delete(a.ipAuthAttempts, oldestIP) - a.logger.Debug("msg", "Evicted old auth attempt state", - "component", "auth", - "evicted_ip", oldestIP, - "last_seen", oldestTime) - } - } - - // Create new state for this IP - // 5 attempts per minute, burst of 3 - state = &ipAuthState{ - limiter: rate.NewLimiter(rate.Every(12*time.Second), 3), - lastAttempt: now, - } - a.ipAuthAttempts[ip] = state - } - - // Check if IP is temporarily blocked - if now.Before(state.blockedUntil) { - remaining := state.blockedUntil.Sub(now) - a.logger.Warn("msg", "IP temporarily blocked", - "component", "auth", - "ip", ip, - "remaining", remaining) - // Sleep to slow down even blocked attempts - time.Sleep(2 * time.Second) - return fmt.Errorf("temporarily blocked, try again in %v", remaining.Round(time.Second)) - } - - // Check rate limit - if !state.limiter.Allow() { - state.failCount++ - - // Only set new blockedUntil if not already blocked - // This prevents indefinite block extension - if state.blockedUntil.IsZero() || now.After(state.blockedUntil) { - // Progressive blocking: 2^failCount minutes - blockMinutes := 1 << min(state.failCount, 6) // Cap at 64 minutes - state.blockedUntil = now.Add(time.Duration(blockMinutes) * time.Minute) - - a.logger.Warn("msg", "Rate limit exceeded, blocking IP", - "component", "auth", - "ip", ip, - "fail_count", state.failCount, - "block_duration", time.Duration(blockMinutes)*time.Minute) - } - - return fmt.Errorf("rate limit exceeded") - } - - state.lastAttempt = now - return nil -} - -// Record failed attempt -func (a *Authenticator) recordFailure(remoteAddr string) { - ip, _, _ := net.SplitHostPort(remoteAddr) - if ip == "" { - ip = remoteAddr - } - - a.authMu.Lock() - defer a.authMu.Unlock() - - if state, exists := a.ipAuthAttempts[ip]; exists { - state.failCount++ - state.lastAttempt = time.Now() - } -} - -// Reset failure count on success -func (a *Authenticator) recordSuccess(remoteAddr string) { - ip, _, _ := net.SplitHostPort(remoteAddr) - if ip == "" { - ip = remoteAddr - } - - a.authMu.Lock() - defer a.authMu.Unlock() - - if state, exists := a.ipAuthAttempts[ip]; exists { - state.failCount = 0 - state.blockedUntil = time.Time{} - } -} - // Handles HTTP authentication headers func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) { if a == nil || a.config.Type == "none" { @@ -222,77 +81,27 @@ func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Sessio }, nil } - // Check rate limit - if err := a.checkRateLimit(remoteAddr); err != nil { - return nil, err - } - var session *Session var err error switch a.config.Type { - case "bearer": - session, err = a.authenticateBearer(authHeader, remoteAddr) + case "token": + session, err = a.authenticateToken(authHeader, remoteAddr) default: err = fmt.Errorf("unsupported auth type: %s", a.config.Type) } if err != nil { - a.recordFailure(remoteAddr) time.Sleep(500 * time.Millisecond) return nil, err } - a.recordSuccess(remoteAddr) return session, nil } -// Handles TCP connection authentication -func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string) (*Session, error) { - if a == nil || a.config.Type == "none" { - return &Session{ - ID: generateSessionID(), - Method: "none", - RemoteAddr: remoteAddr, - CreatedAt: time.Now(), - }, nil - } - - // Check rate limit first - if err := a.checkRateLimit(remoteAddr); err != nil { - return nil, err - } - - var session *Session - var err error - - // TCP auth protocol: AUTH - switch strings.ToLower(method) { - case "token": - if a.config.Type != "bearer" { - err = fmt.Errorf("token auth not configured") - } else { - session, err = a.validateToken(credentials, remoteAddr) - } - - default: - err = fmt.Errorf("unsupported auth method: %s", method) - } - - if err != nil { - a.recordFailure(remoteAddr) - // Add delay on failure - time.Sleep(500 * time.Millisecond) - return nil, err - } - - a.recordSuccess(remoteAddr) - return session, nil -} - -func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) { - if !strings.HasPrefix(authHeader, "Bearer ") { - return nil, fmt.Errorf("invalid bearer auth header") +func (a *Authenticator) authenticateToken(authHeader, remoteAddr string) (*Session, error) { + if !strings.HasPrefix(authHeader, "Token") { + return nil, fmt.Errorf("invalid token auth header") } token := authHeader[7:] @@ -302,7 +111,7 @@ func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Sess func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error) { // Check static tokens first a.mu.RLock() - isValid := a.bearerTokens[token] + isValid := a.tokens[token] a.mu.RUnlock() if !isValid { @@ -311,7 +120,7 @@ func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error session := &Session{ ID: generateSessionID(), - Method: "bearer", + Method: "token", RemoteAddr: remoteAddr, CreatedAt: time.Now(), LastActivity: time.Now(), @@ -352,27 +161,6 @@ func (a *Authenticator) sessionCleanup() { } } -// Cleanup old auth attempts -func (a *Authenticator) authAttemptCleanup() { - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() - - for range ticker.C { - a.authMu.Lock() - now := time.Now() - for ip, state := range a.ipAuthAttempts { - // Remove entries older than 1 hour with no recent activity - if now.Sub(state.lastAttempt) > time.Hour { - delete(a.ipAuthAttempts, ip) - a.logger.Debug("msg", "Cleaned up auth attempt state", - "component", "auth", - "ip", ip) - } - } - a.authMu.Unlock() - } -} - func generateSessionID() string { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { @@ -418,6 +206,6 @@ func (a *Authenticator) GetStats() map[string]any { "enabled": true, "type": a.config.Type, "active_sessions": sessionCount, - "static_tokens": len(a.bearerTokens), + "static_tokens": len(a.tokens), } } \ No newline at end of file diff --git a/src/internal/auth/generator.go b/src/internal/auth/generator.go deleted file mode 100644 index 2daf21a..0000000 --- a/src/internal/auth/generator.go +++ /dev/null @@ -1,207 +0,0 @@ -// FILE: src/internal/auth/generator.go -package auth - -import ( - "crypto/rand" - "encoding/base64" - "flag" - "fmt" - "io" - "os" - "syscall" - - "logwisp/src/internal/scram" - - "golang.org/x/crypto/argon2" - "golang.org/x/term" -) - -// Argon2id parameters -const ( - argon2Time = 3 - argon2Memory = 64 * 1024 // 64 MB - argon2Threads = 4 - argon2SaltLen = 16 - argon2KeyLen = 32 -) - -type AuthGeneratorCommand struct { - output io.Writer - errOut io.Writer -} - -func NewAuthGeneratorCommand() *AuthGeneratorCommand { - return &AuthGeneratorCommand{ - output: os.Stdout, - errOut: os.Stderr, - } -} - -func (ag *AuthGeneratorCommand) Execute(args []string) error { - cmd := flag.NewFlagSet("auth", flag.ContinueOnError) - cmd.SetOutput(ag.errOut) - - var ( - username = cmd.String("u", "", "Username") - password = cmd.String("p", "", "Password (will prompt if not provided)") - authType = cmd.String("type", "basic", "Auth type: basic (HTTP) or scram (TCP)") - genToken = cmd.Bool("t", false, "Generate random bearer token") - tokenLen = cmd.Int("l", 32, "Token length in bytes (min 16, max 512)") - ) - - cmd.Usage = func() { - fmt.Fprintln(ag.errOut, "Generate authentication credentials for LogWisp") - fmt.Fprintln(ag.errOut, "\nUsage: logwisp auth [options]") - fmt.Fprintln(ag.errOut, "\nExamples:") - fmt.Fprintln(ag.errOut, " # Generate basic auth hash for HTTP sources/sinks") - fmt.Fprintln(ag.errOut, " logwisp auth -u admin -type basic") - fmt.Fprintln(ag.errOut, " ") - fmt.Fprintln(ag.errOut, " # Generate SCRAM credentials for TCP sources/sinks") - fmt.Fprintln(ag.errOut, " logwisp auth -u admin -type scram") - fmt.Fprintln(ag.errOut, " ") - fmt.Fprintln(ag.errOut, " # Generate 64-byte bearer token") - fmt.Fprintln(ag.errOut, " logwisp auth -t -l 64") - fmt.Fprintln(ag.errOut, "\nOptions:") - cmd.PrintDefaults() - fmt.Fprintln(ag.errOut) - } - - if err := cmd.Parse(args); err != nil { - return err - } - - if *genToken { - return ag.generateToken(*tokenLen) - } - - if *username == "" { - cmd.Usage() - return fmt.Errorf("username required for credential generation") - } - - switch *authType { - case "basic": - return ag.generateBasicAuth(*username, *password) - case "scram": - return ag.generateScramAuth(*username, *password) - default: - return fmt.Errorf("invalid auth type: %s (use 'basic' or 'scram')", *authType) - } -} - -func (ag *AuthGeneratorCommand) generateBasicAuth(username, password string) error { - // Get password if not provided - if password == "" { - pass1 := ag.promptPassword("Enter password: ") - pass2 := ag.promptPassword("Confirm password: ") - if pass1 != pass2 { - return fmt.Errorf("passwords don't match") - } - password = pass1 - } - - // Generate salt - salt := make([]byte, argon2SaltLen) - if _, err := rand.Read(salt); err != nil { - return fmt.Errorf("failed to generate salt: %w", err) - } - - // Generate Argon2id hash - hash := argon2.IDKey([]byte(password), salt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen) - - // Encode in PHC format: $argon2id$v=19$m=65536,t=3,p=4$salt$hash - saltB64 := base64.RawStdEncoding.EncodeToString(salt) - hashB64 := base64.RawStdEncoding.EncodeToString(hash) - phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", - argon2.Version, argon2Memory, argon2Time, argon2Threads, saltB64, hashB64) - - // Output configuration snippets - fmt.Fprintln(ag.output, "\n# Basic Auth Configuration (HTTP sources/sinks)") - fmt.Fprintln(ag.output, "# REQUIRES HTTPS/TLS for security") - fmt.Fprintln(ag.output, "# Add to logwisp.toml under [[pipelines]]:") - fmt.Fprintln(ag.output, "") - fmt.Fprintln(ag.output, "[pipelines.auth]") - fmt.Fprintln(ag.output, `type = "basic"`) - fmt.Fprintln(ag.output, "") - fmt.Fprintln(ag.output, "[[pipelines.auth.basic_auth.users]]") - fmt.Fprintf(ag.output, "username = %q\n", username) - fmt.Fprintf(ag.output, "password_hash = %q\n\n", phcHash) - - return nil -} - -func (ag *AuthGeneratorCommand) generateScramAuth(username, password string) error { - // Get password if not provided - if password == "" { - pass1 := ag.promptPassword("Enter password: ") - pass2 := ag.promptPassword("Confirm password: ") - if pass1 != pass2 { - return fmt.Errorf("passwords don't match") - } - password = pass1 - } - - // Generate salt - salt := make([]byte, 16) - if _, err := rand.Read(salt); err != nil { - return fmt.Errorf("failed to generate salt: %w", err) - } - - // Derive SCRAM credential - cred, err := scram.DeriveCredential(username, password, salt, 3, 65536, 4) - if err != nil { - return fmt.Errorf("failed to derive SCRAM credential: %w", err) - } - - // Output SCRAM configuration - fmt.Fprintln(ag.output, "\n# SCRAM Auth Configuration (for TCP sources/sinks)") - fmt.Fprintln(ag.output, "# Add to logwisp.toml:") - fmt.Fprintln(ag.output, "[[pipelines.auth.scram_auth.users]]") - fmt.Fprintf(ag.output, "username = %q\n", username) - fmt.Fprintf(ag.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey)) - fmt.Fprintf(ag.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey)) - fmt.Fprintf(ag.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt)) - fmt.Fprintf(ag.output, "argon_time = %d\n", cred.ArgonTime) - fmt.Fprintf(ag.output, "argon_memory = %d\n", cred.ArgonMemory) - fmt.Fprintf(ag.output, "argon_threads = %d\n\n", cred.ArgonThreads) - - return nil -} - -func (ag *AuthGeneratorCommand) generateToken(length int) error { - if length < 16 { - fmt.Fprintln(ag.errOut, "Warning: tokens < 16 bytes are cryptographically weak") - } - if length > 512 { - return fmt.Errorf("token length exceeds maximum (512 bytes)") - } - - token := make([]byte, length) - if _, err := rand.Read(token); err != nil { - return fmt.Errorf("failed to generate random bytes: %w", err) - } - - b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token) - hex := fmt.Sprintf("%x", token) - - fmt.Fprintln(ag.output, "\n# Bearer Token Configuration") - fmt.Fprintln(ag.output, "# Add to logwisp.toml:") - fmt.Fprintf(ag.output, "tokens = [%q]\n\n", b64) - - fmt.Fprintln(ag.output, "# Generated Token:") - fmt.Fprintf(ag.output, "Base64: %s\n", b64) - fmt.Fprintf(ag.output, "Hex: %s\n", hex) - - return nil -} - -func (ag *AuthGeneratorCommand) promptPassword(prompt string) string { - fmt.Fprint(ag.errOut, prompt) - password, err := term.ReadPassword(syscall.Stdin) - fmt.Fprintln(ag.errOut) - if err != nil { - fmt.Fprintf(ag.errOut, "Failed to read password: %v\n", err) - os.Exit(1) - } - return string(password) -} \ No newline at end of file diff --git a/src/internal/scram/client.go b/src/internal/auth/scram_client.go similarity index 85% rename from src/internal/scram/client.go rename to src/internal/auth/scram_client.go index 8e04459..3722269 100644 --- a/src/internal/scram/client.go +++ b/src/internal/auth/scram_client.go @@ -1,5 +1,5 @@ -// FILE: src/internal/scram/client.go -package scram +// FILE: src/internal/auth/scram_client.go +package auth import ( "crypto/rand" @@ -12,7 +12,7 @@ import ( ) // Client handles SCRAM client-side authentication -type Client struct { +type ScramClient struct { Username string Password string @@ -23,16 +23,16 @@ type Client struct { serverKey []byte } -// NewClient creates SCRAM client -func NewClient(username, password string) *Client { - return &Client{ +// NewScramClient creates SCRAM client +func NewScramClient(username, password string) *ScramClient { + return &ScramClient{ Username: username, Password: password, } } // StartAuthentication generates ClientFirst message -func (c *Client) StartAuthentication() (*ClientFirst, error) { +func (c *ScramClient) StartAuthentication() (*ClientFirst, error) { // Generate client nonce nonce := make([]byte, 32) if _, err := rand.Read(nonce); err != nil { @@ -47,7 +47,7 @@ func (c *Client) StartAuthentication() (*ClientFirst, error) { } // ProcessServerFirst handles server challenge -func (c *Client) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) { +func (c *ScramClient) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) { c.serverFirst = msg // Decode salt @@ -83,7 +83,7 @@ func (c *Client) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) { } // VerifyServerFinal validates server signature -func (c *Client) VerifyServerFinal(msg *ServerFinal) error { +func (c *ScramClient) VerifyServerFinal(msg *ServerFinal) error { if c.authMessage == "" || c.serverKey == nil { return fmt.Errorf("invalid handshake state") } diff --git a/src/internal/scram/credential.go b/src/internal/auth/scram_credential.go similarity index 85% rename from src/internal/scram/credential.go rename to src/internal/auth/scram_credential.go index ac77650..1851a40 100644 --- a/src/internal/scram/credential.go +++ b/src/internal/auth/scram_credential.go @@ -1,5 +1,5 @@ -// FILE: src/internal/scram/credential.go -package scram +// FILE: src/internal/auth/scram_credential.go +package auth import ( "crypto/hmac" @@ -9,6 +9,8 @@ import ( "fmt" "strings" + "logwisp/src/internal/core" + "golang.org/x/crypto/argon2" ) @@ -31,7 +33,13 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3 } // Derive salted password using Argon2id - saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, 32) + saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, core.Argon2KeyLen) + + // Construct PHC format for basic auth compatibility + saltB64 := base64.RawStdEncoding.EncodeToString(salt) + hashB64 := base64.RawStdEncoding.EncodeToString(saltedPassword) + phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", + argon2.Version, memory, time, threads, saltB64, hashB64) // Derive keys clientKey := computeHMAC(saltedPassword, []byte("Client Key")) @@ -46,6 +54,7 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3 ArgonThreads: threads, StoredKey: storedKey[:], ServerKey: serverKey, + PHCHash: phcHash, }, nil } diff --git a/src/internal/auth/scram_manager.go b/src/internal/auth/scram_manager.go new file mode 100644 index 0000000..230497a --- /dev/null +++ b/src/internal/auth/scram_manager.go @@ -0,0 +1,83 @@ +// FILE: src/internal/auth/scram_manager.go +package auth + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + + "logwisp/src/internal/config" +) + +// ScramManager provides high-level SCRAM operations with rate limiting +type ScramManager struct { + server *ScramServer +} + +// NewScramManager creates SCRAM manager +func NewScramManager(scramAuthCfg *config.ScramAuthConfig) *ScramManager { + manager := &ScramManager{ + server: NewScramServer(), + } + + // Load users from SCRAM config + for _, user := range scramAuthCfg.Users { + storedKey, err := base64.StdEncoding.DecodeString(user.StoredKey) + if err != nil { + // Skip user with invalid stored key + continue + } + + serverKey, err := base64.StdEncoding.DecodeString(user.ServerKey) + if err != nil { + // Skip user with invalid server key + continue + } + + salt, err := base64.StdEncoding.DecodeString(user.Salt) + if err != nil { + // Skip user with invalid salt + continue + } + + cred := &Credential{ + Username: user.Username, + StoredKey: storedKey, + ServerKey: serverKey, + Salt: salt, + ArgonTime: user.ArgonTime, + ArgonMemory: user.ArgonMemory, + ArgonThreads: user.ArgonThreads, + } + manager.server.AddCredential(cred) + } + + return manager +} + +// RegisterUser creates new user credential +func (sm *ScramManager) RegisterUser(username, password string) error { + salt := make([]byte, 16) + if _, err := rand.Read(salt); err != nil { + return fmt.Errorf("salt generation failed: %w", err) + } + + cred, err := DeriveCredential(username, password, salt, + sm.server.DefaultTime, sm.server.DefaultMemory, sm.server.DefaultThreads) + if err != nil { + return err + } + + sm.server.AddCredential(cred) + return nil +} + +// HandleClientFirst wraps server's HandleClientFirst +func (sm *ScramManager) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) { + return sm.server.HandleClientFirst(msg) +} + +// HandleClientFinal wraps server's HandleClientFinal +func (sm *ScramManager) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) { + return sm.server.HandleClientFinal(msg) +} \ No newline at end of file diff --git a/src/internal/auth/scram_message.go b/src/internal/auth/scram_message.go new file mode 100644 index 0000000..37f0842 --- /dev/null +++ b/src/internal/auth/scram_message.go @@ -0,0 +1,38 @@ +// FILE: src/internal/auth/scram_message.go +package auth + +import ( + "fmt" +) + +// ClientFirst initiates authentication +type ClientFirst struct { + Username string `json:"u"` + ClientNonce string `json:"n"` +} + +// ServerFirst contains server challenge +type ServerFirst struct { + FullNonce string `json:"r"` // client_nonce + server_nonce + Salt string `json:"s"` // base64 + ArgonTime uint32 `json:"t"` + ArgonMemory uint32 `json:"m"` + ArgonThreads uint8 `json:"p"` +} + +// ClientFinal contains client proof +type ClientFinal struct { + FullNonce string `json:"r"` + ClientProof string `json:"p"` // base64 +} + +// ServerFinal contains server signature for mutual auth +type ServerFinal struct { + ServerSignature string `json:"v"` // base64 + SessionID string `json:"sid,omitempty"` +} + +func (sf *ServerFirst) Marshal() string { + return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d", + sf.FullNonce, sf.Salt, sf.ArgonTime, sf.ArgonMemory, sf.ArgonThreads) +} \ No newline at end of file diff --git a/src/internal/auth/scram_protocol.go b/src/internal/auth/scram_protocol.go new file mode 100644 index 0000000..03f384f --- /dev/null +++ b/src/internal/auth/scram_protocol.go @@ -0,0 +1,117 @@ +// FILE: src/internal/auth/scram_protocol.go +package auth + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/lixenwraith/log" + "github.com/panjf2000/gnet/v2" +) + +// ScramProtocolHandler handles SCRAM message exchange for TCP +type ScramProtocolHandler struct { + manager *ScramManager + logger *log.Logger +} + +// NewScramProtocolHandler creates protocol handler +func NewScramProtocolHandler(manager *ScramManager, logger *log.Logger) *ScramProtocolHandler { + return &ScramProtocolHandler{ + manager: manager, + logger: logger, + } +} + +// HandleAuthMessage processes a complete auth line from buffer +func (sph *ScramProtocolHandler) HandleAuthMessage(line []byte, conn gnet.Conn) (authenticated bool, session *Session, err error) { + // Parse SCRAM messages + parts := strings.Fields(string(line)) + if len(parts) < 2 { + conn.AsyncWrite([]byte("SCRAM-FAIL Invalid message format\n"), nil) + return false, nil, fmt.Errorf("invalid message format") + } + + switch parts[0] { + case "SCRAM-FIRST": + // Parse ClientFirst JSON + var clientFirst ClientFirst + if err := json.Unmarshal([]byte(parts[1]), &clientFirst); err != nil { + conn.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil) + return false, nil, fmt.Errorf("invalid JSON") + } + + // Process with SCRAM server + serverFirst, err := sph.manager.HandleClientFirst(&clientFirst) + if err != nil { + // Still send challenge to prevent user enumeration + response, _ := json.Marshal(serverFirst) + conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil) + return false, nil, err + } + + // Send ServerFirst challenge + response, _ := json.Marshal(serverFirst) + conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil) + return false, nil, nil // Not authenticated yet + + case "SCRAM-PROOF": + // Parse ClientFinal JSON + var clientFinal ClientFinal + if err := json.Unmarshal([]byte(parts[1]), &clientFinal); err != nil { + conn.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil) + return false, nil, fmt.Errorf("invalid JSON") + } + + // Verify proof + serverFinal, err := sph.manager.HandleClientFinal(&clientFinal) + if err != nil { + conn.AsyncWrite([]byte("SCRAM-FAIL Authentication failed\n"), nil) + return false, nil, err + } + + // Authentication successful + session = &Session{ + ID: serverFinal.SessionID, + Method: "scram-sha-256", + RemoteAddr: conn.RemoteAddr().String(), + CreatedAt: time.Now(), + } + + // Send ServerFinal with signature + response, _ := json.Marshal(serverFinal) + conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-OK %s\n", response)), nil) + + return true, session, nil + + default: + conn.AsyncWrite([]byte("SCRAM-FAIL Unknown command\n"), nil) + return false, nil, fmt.Errorf("unknown command: %s", parts[0]) + } +} + +// FormatSCRAMRequest formats a SCRAM protocol message for TCP +func FormatSCRAMRequest(command string, data interface{}) (string, error) { + jsonData, err := json.Marshal(data) + if err != nil { + return "", fmt.Errorf("failed to marshal %s: %w", command, err) + } + return fmt.Sprintf("%s %s\n", command, jsonData), nil +} + +// ParseSCRAMResponse parses a SCRAM protocol response from TCP +func ParseSCRAMResponse(response string) (command string, data string, err error) { + response = strings.TrimSpace(response) + parts := strings.SplitN(response, " ", 2) + if len(parts) < 1 { + return "", "", fmt.Errorf("empty response") + } + + command = parts[0] + if len(parts) > 1 { + data = parts[1] + } + return command, data, nil +} \ No newline at end of file diff --git a/src/internal/scram/server.go b/src/internal/auth/scram_server.go similarity index 85% rename from src/internal/scram/server.go rename to src/internal/auth/scram_server.go index ef61a59..b9252c9 100644 --- a/src/internal/scram/server.go +++ b/src/internal/auth/scram_server.go @@ -1,5 +1,5 @@ -// FILE: src/internal/scram/server.go -package scram +// FILE: src/internal/auth/scram_server.go +package auth import ( "crypto/rand" @@ -9,14 +9,17 @@ import ( "fmt" "sync" "time" + + "logwisp/src/internal/core" ) // Server handles SCRAM authentication -type Server struct { +type ScramServer struct { credentials map[string]*Credential handshakes map[string]*HandshakeState mu sync.RWMutex + // TODO: configurability useful? to be included in config or refactor to use core.const directly for simplicity // Default Argon2 params for new registrations DefaultTime uint32 DefaultMemory uint32 @@ -29,32 +32,30 @@ type HandshakeState struct { ClientNonce string ServerNonce string FullNonce string - AuthMessage string Credential *Credential CreatedAt time.Time - ClientProof []byte } -// NewServer creates SCRAM server -func NewServer() *Server { - return &Server{ +// NewScramServer creates SCRAM server +func NewScramServer() *ScramServer { + return &ScramServer{ credentials: make(map[string]*Credential), handshakes: make(map[string]*HandshakeState), - DefaultTime: 3, - DefaultMemory: 64 * 1024, - DefaultThreads: 4, + DefaultTime: core.Argon2Time, + DefaultMemory: core.Argon2Memory, + DefaultThreads: core.Argon2Threads, } } // AddCredential registers user credential -func (s *Server) AddCredential(cred *Credential) { +func (s *ScramServer) AddCredential(cred *Credential) { s.mu.Lock() defer s.mu.Unlock() s.credentials[cred.Username] = cred } // HandleClientFirst processes initial auth request -func (s *Server) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) { +func (s *ScramServer) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) { s.mu.Lock() defer s.mu.Unlock() @@ -103,7 +104,7 @@ func (s *Server) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) { } // HandleClientFinal verifies client proof -func (s *Server) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) { +func (s *ScramServer) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) { s.mu.Lock() defer s.mu.Unlock() @@ -157,7 +158,7 @@ func (s *Server) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) { }, nil } -func (s *Server) cleanupHandshakes() { +func (s *ScramServer) cleanupHandshakes() { cutoff := time.Now().Add(-60 * time.Second) for nonce, state := range s.handshakes { if state.CreatedAt.Before(cutoff) { @@ -170,10 +171,4 @@ func generateNonce() string { b := make([]byte, 32) rand.Read(b) return base64.StdEncoding.EncodeToString(b) -} - -func generateSessionID() string { - b := make([]byte, 24) - rand.Read(b) - return base64.URLEncoding.EncodeToString(b) } \ No newline at end of file diff --git a/src/internal/config/auth.go b/src/internal/config/auth.go deleted file mode 100644 index 8b22cbc..0000000 --- a/src/internal/config/auth.go +++ /dev/null @@ -1,81 +0,0 @@ -// FILE: logwisp/src/internal/config/auth.go -package config - -import ( - "fmt" -) - -type AuthConfig struct { - // Authentication type: "none", "basic", "scram", "bearer", "mtls" - Type string `toml:"type"` - - BasicAuth *BasicAuthConfig `toml:"basic_auth"` - ScramAuth *ScramAuthConfig `toml:"scram_auth"` - BearerAuth *BearerAuthConfig `toml:"bearer_auth"` -} - -type BasicAuthConfig struct { - Users []BasicAuthUser `toml:"users"` - Realm string `toml:"realm"` -} - -type BasicAuthUser struct { - Username string `toml:"username"` - PasswordHash string `toml:"password_hash"` // Argon2 -} - -type ScramAuthConfig struct { - Users []ScramUser `toml:"users"` -} - -type ScramUser struct { - Username string `toml:"username"` - StoredKey string `toml:"stored_key"` // base64 - ServerKey string `toml:"server_key"` // base64 - Salt string `toml:"salt"` // base64 - ArgonTime uint32 `toml:"argon_time"` - ArgonMemory uint32 `toml:"argon_memory"` - ArgonThreads uint8 `toml:"argon_threads"` -} - -type BearerAuthConfig struct { - // Static tokens - Tokens []string `toml:"tokens"` - - // TODO: Maybe future development - // // JWT validation - // JWT *JWTConfig `toml:"jwt"` -} - -// TODO: Maybe future development -// type JWTConfig struct { -// JWKSURL string `toml:"jwks_url"` -// SigningKey string `toml:"signing_key"` -// Issuer string `toml:"issuer"` -// Audience string `toml:"audience"` -// } - -func validateAuth(pipelineName string, auth *AuthConfig) error { - if auth == nil { - return nil - } - - validTypes := map[string]bool{"none": true, "basic": true, "scram": true, "bearer": true, "mtls": true} - if !validTypes[auth.Type] { - return fmt.Errorf("pipeline '%s': invalid auth type: %s", pipelineName, auth.Type) - } - - if auth.Type == "basic" && auth.BasicAuth == nil { - return fmt.Errorf("pipeline '%s': basic auth type specified but config missing", pipelineName) - } - - if auth.Type == "scram" && auth.ScramAuth == nil { - return fmt.Errorf("pipeline '%s': scram auth type specified but config missing", pipelineName) - } - - if auth.Type == "bearer" && auth.BearerAuth == nil { - return fmt.Errorf("pipeline '%s': bearer auth type specified but config missing", pipelineName) - } - - return nil -} \ No newline at end of file diff --git a/src/internal/config/config.go b/src/internal/config/config.go index 37186a4..da6e7b1 100644 --- a/src/internal/config/config.go +++ b/src/internal/config/config.go @@ -1,6 +1,8 @@ // FILE: logwisp/src/internal/config/config.go package config +// --- LogWisp Configuration Options --- + type Config struct { // Top-level flags for application control Background bool `toml:"background"` @@ -10,7 +12,6 @@ type Config struct { // Runtime behavior flags DisableStatusReporter bool `toml:"disable_status_reporter"` ConfigAutoReload bool `toml:"config_auto_reload"` - ConfigSaveOnExit bool `toml:"config_save_on_exit"` // Internal flag indicating demonized child process BackgroundDaemon bool `toml:"background-daemon"` @@ -21,4 +22,365 @@ type Config struct { // Existing fields Logging *LogConfig `toml:"logging"` Pipelines []PipelineConfig `toml:"pipelines"` +} + +// --- Logging Options --- + +// Represents logging configuration for LogWisp +type LogConfig struct { + // Output mode: "file", "stdout", "stderr", "split", "all", "none" + Output string `toml:"output"` + + // Log level: "debug", "info", "warn", "error" + Level string `toml:"level"` + + // File output settings (when Output includes "file" or "all") + File *LogFileConfig `toml:"file"` + + // Console output settings + Console *LogConsoleConfig `toml:"console"` +} + +type LogFileConfig struct { + // Directory for log files + Directory string `toml:"directory"` + + // Base name for log files + Name string `toml:"name"` + + // Maximum size per log file in MB + MaxSizeMB int64 `toml:"max_size_mb"` + + // Maximum total size of all logs in MB + MaxTotalSizeMB int64 `toml:"max_total_size_mb"` + + // Log retention in hours (0 = disabled) + RetentionHours float64 `toml:"retention_hours"` +} + +type LogConsoleConfig struct { + // Target for console output: "stdout", "stderr", "split" + // "split": info/debug to stdout, warn/error to stderr + Target string `toml:"target"` + + // Format: "txt" or "json" + Format string `toml:"format"` +} + +// --- Pipeline Options --- + +type PipelineConfig struct { + Name string `toml:"name"` + Sources []SourceConfig `toml:"sources"` + RateLimit *RateLimitConfig `toml:"rate_limit"` + Filters []FilterConfig `toml:"filters"` + Format *FormatConfig `toml:"format"` + + Sinks []SinkConfig `toml:"sinks"` + // Auth *ServerAuthConfig `toml:"auth"` // Global auth for pipeline +} + +// Common configuration structs used across components + +type NetLimitConfig struct { + Enabled bool `toml:"enabled"` + MaxConnections int64 `toml:"max_connections"` + RequestsPerSecond float64 `toml:"requests_per_second"` + BurstSize int64 `toml:"burst_size"` + ResponseMessage string `toml:"response_message"` + ResponseCode int64 `toml:"response_code"` // Default: 429 + MaxConnectionsPerIP int64 `toml:"max_connections_per_ip"` + MaxConnectionsPerUser int64 `toml:"max_connections_per_user"` + MaxConnectionsPerToken int64 `toml:"max_connections_per_token"` + MaxConnectionsTotal int64 `toml:"max_connections_total"` + IPWhitelist []string `toml:"ip_whitelist"` + IPBlacklist []string `toml:"ip_blacklist"` +} + +type TLSConfig struct { + Enabled bool `toml:"enabled"` + CertFile string `toml:"cert_file"` + KeyFile string `toml:"key_file"` + CAFile string `toml:"ca_file"` + ServerName string `toml:"server_name"` // for client verification + SkipVerify bool `toml:"skip_verify"` + + // Client certificate authentication + ClientAuth bool `toml:"client_auth"` + ClientCAFile string `toml:"client_ca_file"` + VerifyClientCert bool `toml:"verify_client_cert"` + + // TLS version constraints + MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3" + MaxVersion string `toml:"max_version"` + + // Cipher suites (comma-separated list) + CipherSuites string `toml:"cipher_suites"` +} + +type HeartbeatConfig struct { + Enabled bool `toml:"enabled"` + Interval int64 `toml:"interval_ms"` + IncludeTimestamp bool `toml:"include_timestamp"` + IncludeStats bool `toml:"include_stats"` + Format string `toml:"format"` +} + +type ClientAuthConfig struct { + Type string `toml:"type"` // "none", "basic", "token", "scram" + Username string `toml:"username"` + Password string `toml:"password"` + Token string `toml:"token"` +} + +// --- Source Options --- + +type SourceConfig struct { + Type string `toml:"type"` + + // Polymorphic - only one populated based on type + Directory *DirectorySourceOptions `toml:"directory,omitempty"` + Stdin *StdinSourceOptions `toml:"stdin,omitempty"` + HTTP *HTTPSourceOptions `toml:"http,omitempty"` + TCP *TCPSourceOptions `toml:"tcp,omitempty"` +} + +type DirectorySourceOptions struct { + Path string `toml:"path"` + Pattern string `toml:"pattern"` // glob pattern + CheckIntervalMS int64 `toml:"check_interval_ms"` + Recursive bool `toml:"recursive"` + FollowSymlinks bool `toml:"follow_symlinks"` + DeleteAfterRead bool `toml:"delete_after_read"` + MoveToDirectory string `toml:"move_to_directory"` // move after processing +} + +type StdinSourceOptions struct { + BufferSize int64 `toml:"buffer_size"` +} + +type HTTPSourceOptions struct { + Host string `toml:"host"` + Port int64 `toml:"port"` + IngestPath string `toml:"ingest_path"` + BufferSize int64 `toml:"buffer_size"` + MaxRequestBodySize int64 `toml:"max_body_size"` + ReadTimeout int64 `toml:"read_timeout_ms"` + WriteTimeout int64 `toml:"write_timeout_ms"` + NetLimit *NetLimitConfig `toml:"net_limit"` + TLS *TLSConfig `toml:"tls"` + Auth *ServerAuthConfig `toml:"auth"` +} + +type TCPSourceOptions struct { + Host string `toml:"host"` + Port int64 `toml:"port"` + BufferSize int64 `toml:"buffer_size"` + ReadTimeout int64 `toml:"read_timeout_ms"` + KeepAlive bool `toml:"keep_alive"` + KeepAlivePeriod int64 `toml:"keep_alive_period_ms"` + NetLimit *NetLimitConfig `toml:"net_limit"` + Auth *ServerAuthConfig `toml:"auth"` +} + +// --- Sink Options --- + +type SinkConfig struct { + Type string `toml:"type"` + + // Polymorphic - only one populated based on type + Console *ConsoleSinkOptions `toml:"console,omitempty"` + File *FileSinkOptions `toml:"file,omitempty"` + HTTP *HTTPSinkOptions `toml:"http,omitempty"` + TCP *TCPSinkOptions `toml:"tcp,omitempty"` + HTTPClient *HTTPClientSinkOptions `toml:"http_client,omitempty"` + TCPClient *TCPClientSinkOptions `toml:"tcp_client,omitempty"` +} + +type ConsoleSinkOptions struct { + Target string `toml:"target"` // "stdout", "stderr", "split" + Colorize bool `toml:"colorize"` + BufferSize int64 `toml:"buffer_size"` +} + +type FileSinkOptions struct { + Directory string `toml:"directory"` + Name string `toml:"name"` + // Extension string `toml:"extension"` + MaxSizeMB int64 `toml:"max_size_mb"` + MaxTotalSizeMB int64 `toml:"max_total_size_mb"` + MinDiskFreeMB int64 `toml:"min_disk_free_mb"` + RetentionHours float64 `toml:"retention_hours"` + BufferSize int64 `toml:"buffer_size"` + FlushInterval int64 `toml:"flush_interval_ms"` +} + +type HTTPSinkOptions struct { + Host string `toml:"host"` + Port int64 `toml:"port"` + StreamPath string `toml:"stream_path"` + StatusPath string `toml:"status_path"` + BufferSize int64 `toml:"buffer_size"` + WriteTimeout int64 `toml:"write_timeout_ms"` + Heartbeat *HeartbeatConfig `toml:"heartbeat"` + NetLimit *NetLimitConfig `toml:"net_limit"` + TLS *TLSConfig `toml:"tls"` + Auth *ServerAuthConfig `toml:"auth"` +} + +type TCPSinkOptions struct { + Host string `toml:"host"` + Port int64 `toml:"port"` + BufferSize int64 `toml:"buffer_size"` + WriteTimeout int64 `toml:"write_timeout_ms"` + KeepAlive bool `toml:"keep_alive"` + KeepAlivePeriod int64 `toml:"keep_alive_period_ms"` + Heartbeat *HeartbeatConfig `toml:"heartbeat"` + NetLimit *NetLimitConfig `toml:"net_limit"` + Auth *ServerAuthConfig `toml:"auth"` +} + +type HTTPClientSinkOptions struct { + URL string `toml:"url"` + Headers map[string]string `toml:"headers"` + BufferSize int64 `toml:"buffer_size"` + BatchSize int64 `toml:"batch_size"` + BatchDelayMS int64 `toml:"batch_delay_ms"` + Timeout int64 `toml:"timeout_seconds"` + MaxRetries int64 `toml:"max_retries"` + RetryDelayMS int64 `toml:"retry_delay_ms"` + RetryBackoff float64 `toml:"retry_backoff"` + InsecureSkipVerify bool `toml:"insecure_skip_verify"` + TLS *TLSConfig `toml:"tls"` + Auth *ClientAuthConfig `toml:"auth"` +} + +type TCPClientSinkOptions struct { + Host string `toml:"host"` + Port int64 `toml:"port"` + BufferSize int64 `toml:"buffer_size"` + DialTimeout int64 `toml:"dial_timeout_seconds"` + WriteTimeout int64 `toml:"write_timeout_seconds"` + ReadTimeout int64 `toml:"read_timeout_seconds"` + KeepAlive int64 `toml:"keep_alive_seconds"` + ReconnectDelayMS int64 `toml:"reconnect_delay_ms"` + MaxReconnectDelayMS int64 `toml:"max_reconnect_delay_ms"` + ReconnectBackoff float64 `toml:"reconnect_backoff"` + Auth *ClientAuthConfig `toml:"auth"` +} + +// --- Rate Limit Options --- + +// Defines the action to take when a rate limit is exceeded. +type RateLimitPolicy int + +const ( + // PolicyPass allows all logs through, effectively disabling the limiter. + PolicyPass RateLimitPolicy = iota + // PolicyDrop drops logs that exceed the rate limit. + PolicyDrop +) + +// Defines the configuration for pipeline-level rate limiting. +type RateLimitConfig struct { + // Rate is the number of log entries allowed per second. Default: 0 (disabled). + Rate float64 `toml:"rate"` + // Burst is the maximum number of log entries that can be sent in a short burst. Defaults to the Rate. + Burst float64 `toml:"burst"` + // Policy defines the action to take when the limit is exceeded. "pass" or "drop". + Policy string `toml:"policy"` + // MaxEntrySizeBytes is the maximum allowed size for a single log entry. 0 = no limit. + MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"` +} + +// --- Filter Options --- + +// Represents the filter type +type FilterType string + +const ( + FilterTypeInclude FilterType = "include" // Whitelist - only matching logs pass + FilterTypeExclude FilterType = "exclude" // Blacklist - matching logs are dropped +) + +// Represents how multiple patterns are combined +type FilterLogic string + +const ( + FilterLogicOr FilterLogic = "or" // Match any pattern + FilterLogicAnd FilterLogic = "and" // Match all patterns +) + +// Represents filter configuration +type FilterConfig struct { + Type FilterType `toml:"type"` + Logic FilterLogic `toml:"logic"` + Patterns []string `toml:"patterns"` +} + +// --- Formatter Options --- + +type FormatConfig struct { + // Format configuration - polymorphic like sources/sinks + Type string `toml:"type"` // "json", "text", "raw" + + // Only one will be populated based on format type + JSONFormatOptions *JSONFormatterOptions `toml:"json_format,omitempty"` + TextFormatOptions *TextFormatterOptions `toml:"text_format,omitempty"` + RawFormatOptions *RawFormatterOptions `toml:"raw_format,omitempty"` +} + +type JSONFormatterOptions struct { + Pretty bool `toml:"pretty"` + TimestampField string `toml:"timestamp_field"` + LevelField string `toml:"level_field"` + MessageField string `toml:"message_field"` + SourceField string `toml:"source_field"` +} + +type TextFormatterOptions struct { + Template string `toml:"template"` + TimestampFormat string `toml:"timestamp_format"` +} + +type RawFormatterOptions struct { + AddNewLine bool `toml:"add_new_line"` +} + +// --- Server-side Auth (for sources) --- + +type BasicAuthConfig struct { + Users []BasicAuthUser `toml:"users"` + Realm string `toml:"realm"` +} + +type BasicAuthUser struct { + Username string `toml:"username"` + PasswordHash string `toml:"password_hash"` // Argon2 +} + +type ScramAuthConfig struct { + Users []ScramUser `toml:"users"` +} + +type ScramUser struct { + Username string `toml:"username"` + StoredKey string `toml:"stored_key"` // base64 + ServerKey string `toml:"server_key"` // base64 + Salt string `toml:"salt"` // base64 + ArgonTime uint32 `toml:"argon_time"` + ArgonMemory uint32 `toml:"argon_memory"` + ArgonThreads uint8 `toml:"argon_threads"` +} + +type TokenAuthConfig struct { + Tokens []string `toml:"tokens"` +} + +// Server auth wrapper (for sources accepting connections) +type ServerAuthConfig struct { + Type string `toml:"type"` // "none", "basic", "token", "scram" + Basic *BasicAuthConfig `toml:"basic,omitempty"` + Token *TokenAuthConfig `toml:"token,omitempty"` + Scram *ScramAuthConfig `toml:"scram,omitempty"` } \ No newline at end of file diff --git a/src/internal/config/filter.go b/src/internal/config/filter.go deleted file mode 100644 index 09143cf..0000000 --- a/src/internal/config/filter.go +++ /dev/null @@ -1,65 +0,0 @@ -// FILE: logwisp/src/internal/config/filter.go -package config - -import ( - "fmt" - "regexp" -) - -// Represents the filter type -type FilterType string - -const ( - FilterTypeInclude FilterType = "include" // Whitelist - only matching logs pass - FilterTypeExclude FilterType = "exclude" // Blacklist - matching logs are dropped -) - -// Represents how multiple patterns are combined -type FilterLogic string - -const ( - FilterLogicOr FilterLogic = "or" // Match any pattern - FilterLogicAnd FilterLogic = "and" // Match all patterns -) - -// Represents filter configuration -type FilterConfig struct { - Type FilterType `toml:"type"` - Logic FilterLogic `toml:"logic"` - Patterns []string `toml:"patterns"` -} - -func validateFilter(pipelineName string, filterIndex int, cfg *FilterConfig) error { - // Validate filter type - switch cfg.Type { - case FilterTypeInclude, FilterTypeExclude, "": - // Valid types - default: - return fmt.Errorf("pipeline '%s' filter[%d]: invalid type '%s' (must be 'include' or 'exclude')", - pipelineName, filterIndex, cfg.Type) - } - - // Validate filter logic - switch cfg.Logic { - case FilterLogicOr, FilterLogicAnd, "": - // Valid logic - default: - return fmt.Errorf("pipeline '%s' filter[%d]: invalid logic '%s' (must be 'or' or 'and')", - pipelineName, filterIndex, cfg.Logic) - } - - // Empty patterns is valid - passes everything - if len(cfg.Patterns) == 0 { - return nil - } - - // Validate regex patterns - for i, pattern := range cfg.Patterns { - if _, err := regexp.Compile(pattern); err != nil { - return fmt.Errorf("pipeline '%s' filter[%d] pattern[%d] '%s': invalid regex: %w", - pipelineName, filterIndex, i, pattern, err) - } - } - - return nil -} \ No newline at end of file diff --git a/src/internal/config/limit.go b/src/internal/config/limit.go deleted file mode 100644 index 48980d9..0000000 --- a/src/internal/config/limit.go +++ /dev/null @@ -1,58 +0,0 @@ -// FILE: logwisp/src/internal/config/ratelimit.go -package config - -import ( - "fmt" - "strings" -) - -// Defines the action to take when a rate limit is exceeded. -type RateLimitPolicy int - -const ( - // PolicyPass allows all logs through, effectively disabling the limiter. - PolicyPass RateLimitPolicy = iota - // PolicyDrop drops logs that exceed the rate limit. - PolicyDrop -) - -// Defines the configuration for pipeline-level rate limiting. -type RateLimitConfig struct { - // Rate is the number of log entries allowed per second. Default: 0 (disabled). - Rate float64 `toml:"rate"` - // Burst is the maximum number of log entries that can be sent in a short burst. Defaults to the Rate. - Burst float64 `toml:"burst"` - // Policy defines the action to take when the limit is exceeded. "pass" or "drop". - Policy string `toml:"policy"` - // MaxEntrySizeBytes is the maximum allowed size for a single log entry. 0 = no limit. - MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"` -} - -func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error { - if cfg == nil { - return nil - } - - if cfg.Rate < 0 { - return fmt.Errorf("pipeline '%s': rate limit rate cannot be negative", pipelineName) - } - - if cfg.Burst < 0 { - return fmt.Errorf("pipeline '%s': rate limit burst cannot be negative", pipelineName) - } - - if cfg.MaxEntrySizeBytes < 0 { - return fmt.Errorf("pipeline '%s': max entry size bytes cannot be negative", pipelineName) - } - - // Validate policy - switch strings.ToLower(cfg.Policy) { - case "", "pass", "drop": - // Valid policies - default: - return fmt.Errorf("pipeline '%s': invalid rate limit policy '%s' (must be 'pass' or 'drop')", - pipelineName, cfg.Policy) - } - - return nil -} \ No newline at end of file diff --git a/src/internal/config/loader.go b/src/internal/config/loader.go index e1b580b..2e1b75a 100644 --- a/src/internal/config/loader.go +++ b/src/internal/config/loader.go @@ -11,6 +11,8 @@ import ( lconfig "github.com/lixenwraith/config" ) +var configManager *lconfig.Config + func defaults() *Config { return &Config{ // Top-level flag defaults @@ -21,41 +23,46 @@ func defaults() *Config { // Runtime behavior defaults DisableStatusReporter: false, ConfigAutoReload: false, - ConfigSaveOnExit: false, // Child process indicator BackgroundDaemon: false, // Existing defaults - Logging: DefaultLogConfig(), + Logging: &LogConfig{ + Output: "stdout", + Level: "info", + File: &LogFileConfig{ + Directory: "./log", + Name: "logwisp", + MaxSizeMB: 100, + MaxTotalSizeMB: 1000, + RetentionHours: 168, // 7 days + }, + Console: &LogConsoleConfig{ + Target: "stdout", + Format: "txt", + }, + }, Pipelines: []PipelineConfig{ { Name: "default", Sources: []SourceConfig{ { Type: "directory", - Options: map[string]any{ - "path": "./", - "pattern": "*.log", - "check_interval_ms": int64(100), + Directory: &DirectorySourceOptions{ + Path: "./", + Pattern: "*.log", + CheckIntervalMS: int64(100), }, }, }, Sinks: []SinkConfig{ { - Type: "http", - Options: map[string]any{ - "port": int64(8080), - "buffer_size": int64(1000), - "stream_path": "/stream", - "status_path": "/status", - "heartbeat": map[string]any{ - "enabled": true, - "interval_seconds": int64(30), - "include_timestamp": true, - "include_stats": false, - "format": "comment", - }, + Type: "console", + Console: &ConsoleSinkOptions{ + Target: "stdout", + Colorize: false, + BufferSize: 100, }, }, }, @@ -68,18 +75,30 @@ func defaults() *Config { func Load(args []string) (*Config, error) { configPath, isExplicit := resolveConfigPath(args) // Build configuration with all sources + + // Create target config instance that will be populated + finalConfig := &Config{} + + // The builder now handles loading, populating the target struct, and validation cfg, err := lconfig.NewBuilder(). - WithDefaults(defaults()). - WithEnvPrefix("LOGWISP_"). - WithEnvTransform(customEnvTransform). - WithArgs(args). - WithFile(configPath). + WithTarget(finalConfig). // Typed target struct + WithDefaults(defaults()). // Default values WithSources( lconfig.SourceCLI, lconfig.SourceEnv, lconfig.SourceFile, lconfig.SourceDefault, ). + WithEnvTransform(customEnvTransform). // Convert '.' to '_' in env separation + WithEnvPrefix("LOGWISP_"). // Environment variable prefix + WithArgs(args). // Command-line arguments + WithFile(configPath). // TOML config file + WithFileFormat("toml"). // Explicit format + WithTypedValidator(validateConfig). // Centralized validation + WithSecurityOptions(lconfig.SecurityOptions{ + PreventPathTraversal: true, + MaxFileSize: 10 * 1024 * 1024, // 10MB max config + }). Build() if err != nil { @@ -88,42 +107,28 @@ func Load(args []string) (*Config, error) { if isExplicit { return nil, fmt.Errorf("config file not found: %s", configPath) } - // If the default config file is not found, it's not an error + // If the default config file is not found, it's not an error, default/cli/env will be used } else { - return nil, fmt.Errorf("failed to load config: %w", err) + return nil, fmt.Errorf("failed to load or validate config: %w", err) } } - // Scan into final config struct - using new interface - finalConfig := &Config{} - if err := cfg.Scan(finalConfig); err != nil { - return nil, fmt.Errorf("failed to scan config: %w", err) + // Store the config file path for hot reload + finalConfig.ConfigFile = configPath + + // Store the manager for hot reload + if cfg != nil { + configManager = cfg } - // Set config file path if it exists - if _, err := os.Stat(configPath); err == nil { - finalConfig.ConfigFile = configPath - } - - // Ensure critical fields are not nil - if finalConfig.Logging == nil { - finalConfig.Logging = DefaultLogConfig() - } - - // Apply console target overrides if needed - if err := applyConsoleTargetOverrides(finalConfig); err != nil { - return nil, fmt.Errorf("failed to apply console target overrides: %w", err) - } - - // Validate configuration - return finalConfig, finalConfig.validate() + return finalConfig, nil } // Returns the configuration file path func resolveConfigPath(args []string) (path string, isExplicit bool) { // 1. Check for --config flag in command-line arguments (highest precedence) for i, arg := range args { - if (arg == "--config" || arg == "-c") && i+1 < len(args) { + if arg == "-c" { return args[i+1], true } if strings.HasPrefix(arg, "--config=") { @@ -160,38 +165,4 @@ func customEnvTransform(path string) string { env = strings.ToUpper(env) // env = "LOGWISP_" + env // already added by WithEnvPrefix return env -} - -// Centralizes console target configuration -func applyConsoleTargetOverrides(cfg *Config) error { - // Check environment variable for console target override - consoleTarget := os.Getenv("LOGWISP_CONSOLE_TARGET") - if consoleTarget == "" { - return nil - } - - // Validate console target value - validTargets := map[string]bool{ - "stdout": true, - "stderr": true, - "split": true, - } - if !validTargets[consoleTarget] { - return fmt.Errorf("invalid LOGWISP_CONSOLE_TARGET value: %s", consoleTarget) - } - - // Apply to console sinks - for i, pipeline := range cfg.Pipelines { - for j, sink := range pipeline.Sinks { - if sink.Type == "console" { - if sink.Options == nil { - cfg.Pipelines[i].Sinks[j].Options = make(map[string]any) - } - // Set target for split mode handling - cfg.Pipelines[i].Sinks[j].Options["target"] = consoleTarget - } - } - } - - return nil } \ No newline at end of file diff --git a/src/internal/config/logging.go b/src/internal/config/logging.go deleted file mode 100644 index 8e6a15c..0000000 --- a/src/internal/config/logging.go +++ /dev/null @@ -1,99 +0,0 @@ -// FILE: logwisp/src/internal/config/logging.go -package config - -import "fmt" - -// Represents logging configuration for LogWisp -type LogConfig struct { - // Output mode: "file", "stdout", "stderr", "split", "all", "none" - Output string `toml:"output"` - - // Log level: "debug", "info", "warn", "error" - Level string `toml:"level"` - - // File output settings (when Output includes "file" or "all") - File *LogFileConfig `toml:"file"` - - // Console output settings - Console *LogConsoleConfig `toml:"console"` -} - -type LogFileConfig struct { - // Directory for log files - Directory string `toml:"directory"` - - // Base name for log files - Name string `toml:"name"` - - // Maximum size per log file in MB - MaxSizeMB int64 `toml:"max_size_mb"` - - // Maximum total size of all logs in MB - MaxTotalSizeMB int64 `toml:"max_total_size_mb"` - - // Log retention in hours (0 = disabled) - RetentionHours float64 `toml:"retention_hours"` -} - -type LogConsoleConfig struct { - // Target for console output: "stdout", "stderr", "split" - // "split": info/debug to stdout, warn/error to stderr - Target string `toml:"target"` - - // Format: "txt" or "json" - Format string `toml:"format"` -} - -// Returns sensible logging defaults -func DefaultLogConfig() *LogConfig { - return &LogConfig{ - Output: "stdout", - Level: "info", - File: &LogFileConfig{ - Directory: "./log", - Name: "logwisp", - MaxSizeMB: 100, - MaxTotalSizeMB: 1000, - RetentionHours: 168, // 7 days - }, - Console: &LogConsoleConfig{ - Target: "stdout", - Format: "txt", - }, - } -} - -func validateLogConfig(cfg *LogConfig) error { - validOutputs := map[string]bool{ - "file": true, "stdout": true, "stderr": true, - "split": true, "all": true, "none": true, - } - if !validOutputs[cfg.Output] { - return fmt.Errorf("invalid log output mode: %s", cfg.Output) - } - - validLevels := map[string]bool{ - "debug": true, "info": true, "warn": true, "error": true, - } - if !validLevels[cfg.Level] { - return fmt.Errorf("invalid log level: %s", cfg.Level) - } - - if cfg.Console != nil { - validTargets := map[string]bool{ - "stdout": true, "stderr": true, "split": true, - } - if !validTargets[cfg.Console.Target] { - return fmt.Errorf("invalid console target: %s", cfg.Console.Target) - } - - validFormats := map[string]bool{ - "txt": true, "json": true, "": true, - } - if !validFormats[cfg.Console.Format] { - return fmt.Errorf("invalid console format: %s", cfg.Console.Format) - } - } - - return nil -} \ No newline at end of file diff --git a/src/internal/config/pipeline.go b/src/internal/config/pipeline.go deleted file mode 100644 index bcc2b16..0000000 --- a/src/internal/config/pipeline.go +++ /dev/null @@ -1,416 +0,0 @@ -// FILE: logwisp/src/internal/config/pipeline.go -package config - -import ( - "fmt" - "net" - "net/url" - "path/filepath" - "strings" -) - -// Represents a data processing pipeline -type PipelineConfig struct { - // Pipeline identifier (used in logs and metrics) - Name string `toml:"name"` - - // Data sources for this pipeline - Sources []SourceConfig `toml:"sources"` - - // Rate limiting - RateLimit *RateLimitConfig `toml:"rate_limit"` - - // Filter configuration - Filters []FilterConfig `toml:"filters"` - - // Log formatting configuration - Format string `toml:"format"` - FormatOptions map[string]any `toml:"format_options"` - - // Output sinks for this pipeline - Sinks []SinkConfig `toml:"sinks"` - - // Authentication/Authorization (applies to network sinks) - Auth *AuthConfig `toml:"auth"` -} - -// Represents an input data source -type SourceConfig struct { - // Source type - Type string `toml:"type"` - - // Type-specific configuration options - Options map[string]any `toml:"options"` -} - -// Represents an output destination -type SinkConfig struct { - // Sink type - Type string `toml:"type"` - - // Type-specific configuration options - Options map[string]any `toml:"options"` -} - -func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) error { - if cfg.Type == "" { - return fmt.Errorf("pipeline '%s' source[%d]: missing type", pipelineName, sourceIndex) - } - - switch cfg.Type { - case "directory": - // Validate path - path, ok := cfg.Options["path"].(string) - if !ok || path == "" { - return fmt.Errorf("pipeline '%s' source[%d]: directory source requires 'path' option", - pipelineName, sourceIndex) - } - - // Check for directory traversal - if strings.Contains(path, "..") { - return fmt.Errorf("pipeline '%s' source[%d]: path contains directory traversal", - pipelineName, sourceIndex) - } - - // Validate pattern - if pattern, ok := cfg.Options["pattern"].(string); ok && pattern != "" { - // Try to compile as glob pattern (will be converted to regex internally) - if strings.Count(pattern, "*") == 0 && strings.Count(pattern, "?") == 0 { - // If no wildcards, ensure it's a valid filename - if filepath.Base(pattern) != pattern { - return fmt.Errorf("pipeline '%s' source[%d]: pattern contains path separators", - pipelineName, sourceIndex) - } - } - } - - // Validate check interval - if interval, ok := cfg.Options["check_interval_ms"]; ok { - if intVal, ok := interval.(int64); ok { - if intVal < 10 { - return fmt.Errorf("pipeline '%s' source[%d]: check interval too small: %d ms (min: 10ms)", - pipelineName, sourceIndex, intVal) - } - } else { - return fmt.Errorf("pipeline '%s' source[%d]: invalid check_interval_ms type", - pipelineName, sourceIndex) - } - } - - case "stdin": - // Validate buffer size - if bufSize, ok := cfg.Options["buffer_size"].(int64); ok { - if bufSize < 1 { - return fmt.Errorf("pipeline '%s' source[%d]: stdin buffer_size must be positive: %d", - pipelineName, sourceIndex, bufSize) - } - } - - case "http": - // Validate host - if host, ok := cfg.Options["host"].(string); ok && host != "" { - if net.ParseIP(host) == nil { - return fmt.Errorf("pipeline '%s' source[%d]: invalid IP address: %s", - pipelineName, sourceIndex, host) - } - } - - // Validate port - port, ok := cfg.Options["port"].(int64) - if !ok || port < 1 || port > 65535 { - return fmt.Errorf("pipeline '%s' source[%d]: invalid or missing HTTP port", - pipelineName, sourceIndex) - } - - // Validate path - if path, ok := cfg.Options["ingest_path"].(string); ok { - if !strings.HasPrefix(path, "/") { - return fmt.Errorf("pipeline '%s' source[%d]: ingest path must start with /: %s", - pipelineName, sourceIndex, path) - } - } - - // Validate net_limit - if nl, ok := cfg.Options["net_limit"].(map[string]any); ok { - if err := validateNetLimitOptions("HTTP source", pipelineName, sourceIndex, nl); err != nil { - return err - } - } - - // Validate TLS - if tls, ok := cfg.Options["tls"].(map[string]any); ok { - if err := validateTLSOptions("HTTP source", pipelineName, sourceIndex, tls); err != nil { - return err - } - } - - case "tcp": - // Validate host - if host, ok := cfg.Options["host"].(string); ok && host != "" { - if net.ParseIP(host) == nil { - return fmt.Errorf("pipeline '%s' source[%d]: invalid IP address: %s", - pipelineName, sourceIndex, host) - } - } - - // Validate port - port, ok := cfg.Options["port"].(int64) - if !ok || port < 1 || port > 65535 { - return fmt.Errorf("pipeline '%s' source[%d]: invalid or missing TCP port", - pipelineName, sourceIndex) - } - - // Validate net_limit - if nl, ok := cfg.Options["net_limit"].(map[string]any); ok { - if err := validateNetLimitOptions("TCP source", pipelineName, sourceIndex, nl); err != nil { - return err - } - } - - default: - return fmt.Errorf("pipeline '%s' source[%d]: unknown source type '%s'", - pipelineName, sourceIndex, cfg.Type) - } - - return nil -} - -func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts map[int64]string) error { - if cfg.Type == "" { - return fmt.Errorf("pipeline '%s' sink[%d]: missing type", pipelineName, sinkIndex) - } - - switch cfg.Type { - case "http": - // Extract and validate HTTP configuration - port, ok := cfg.Options["port"].(int64) - if !ok || port < 1 || port > 65535 { - return fmt.Errorf("pipeline '%s' sink[%d]: invalid or missing HTTP port", - pipelineName, sinkIndex) - } - - // Validate host - if host, ok := cfg.Options["host"].(string); ok && host != "" { - if net.ParseIP(host) == nil { - return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s", - pipelineName, sinkIndex, host) - } - } - - // Check port conflicts - if existing, exists := allPorts[port]; exists { - return fmt.Errorf("pipeline '%s' sink[%d]: HTTP port %d already used by %s", - pipelineName, sinkIndex, port, existing) - } - allPorts[port] = fmt.Sprintf("%s-http[%d]", pipelineName, sinkIndex) - - // Validate buffer size - if bufSize, ok := cfg.Options["buffer_size"].(int64); ok { - if bufSize < 1 { - return fmt.Errorf("pipeline '%s' sink[%d]: HTTP buffer size must be positive: %d", - pipelineName, sinkIndex, bufSize) - } - } - - // Validate paths - if streamPath, ok := cfg.Options["stream_path"].(string); ok { - if !strings.HasPrefix(streamPath, "/") { - return fmt.Errorf("pipeline '%s' sink[%d]: stream path must start with /: %s", - pipelineName, sinkIndex, streamPath) - } - } - - if statusPath, ok := cfg.Options["status_path"].(string); ok { - if !strings.HasPrefix(statusPath, "/") { - return fmt.Errorf("pipeline '%s' sink[%d]: status path must start with /: %s", - pipelineName, sinkIndex, statusPath) - } - } - - // Validate heartbeat - if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok { - if err := validateHeartbeatOptions("HTTP", pipelineName, sinkIndex, hb); err != nil { - return err - } - } - - // Validate TLS if present - if tls, ok := cfg.Options["tls"].(map[string]any); ok { - if err := validateTLSOptions("HTTP", pipelineName, sinkIndex, tls); err != nil { - return err - } - } - - // Validate net limit - if nl, ok := cfg.Options["net_limit"].(map[string]any); ok { - if err := validateNetLimitOptions("HTTP", pipelineName, sinkIndex, nl); err != nil { - return err - } - } - - case "tcp": - // Extract and validate TCP configuration - port, ok := cfg.Options["port"].(int64) - if !ok || port < 1 || port > 65535 { - return fmt.Errorf("pipeline '%s' sink[%d]: invalid or missing TCP port", - pipelineName, sinkIndex) - } - - // Validate host - if host, ok := cfg.Options["host"].(string); ok && host != "" { - if net.ParseIP(host) == nil { - return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s", - pipelineName, sinkIndex, host) - } - } - - // Check port conflicts - if existing, exists := allPorts[port]; exists { - return fmt.Errorf("pipeline '%s' sink[%d]: TCP port %d already used by %s", - pipelineName, sinkIndex, port, existing) - } - allPorts[port] = fmt.Sprintf("%s-tcp[%d]", pipelineName, sinkIndex) - - // Validate buffer size - if bufSize, ok := cfg.Options["buffer_size"].(int64); ok { - if bufSize < 1 { - return fmt.Errorf("pipeline '%s' sink[%d]: TCP buffer size must be positive: %d", - pipelineName, sinkIndex, bufSize) - } - } - - // Validate heartbeat - if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok { - if err := validateHeartbeatOptions("TCP", pipelineName, sinkIndex, hb); err != nil { - return err - } - } - - // Validate net limit - if nl, ok := cfg.Options["net_limit"].(map[string]any); ok { - if err := validateNetLimitOptions("TCP", pipelineName, sinkIndex, nl); err != nil { - return err - } - } - - case "http_client": - // Validate URL - urlStr, ok := cfg.Options["url"].(string) - if !ok || urlStr == "" { - return fmt.Errorf("pipeline '%s' sink[%d]: http_client sink requires 'url' option", - pipelineName, sinkIndex) - } - - // Validate URL format - parsedURL, err := url.Parse(urlStr) - if err != nil { - return fmt.Errorf("pipeline '%s' sink[%d]: invalid URL: %w", - pipelineName, sinkIndex, err) - } - if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - return fmt.Errorf("pipeline '%s' sink[%d]: URL must use http or https scheme", - pipelineName, sinkIndex) - } - - // Validate batch size - if batchSize, ok := cfg.Options["batch_size"].(int64); ok { - if batchSize < 1 { - return fmt.Errorf("pipeline '%s' sink[%d]: batch_size must be positive: %d", - pipelineName, sinkIndex, batchSize) - } - } - - // Validate timeout - if timeout, ok := cfg.Options["timeout_seconds"].(int64); ok { - if timeout < 1 { - return fmt.Errorf("pipeline '%s' sink[%d]: timeout_seconds must be positive: %d", - pipelineName, sinkIndex, timeout) - } - } - - case "tcp_client": - // Added validation for TCP client sink - // Validate address - address, ok := cfg.Options["address"].(string) - if !ok || address == "" { - return fmt.Errorf("pipeline '%s' sink[%d]: tcp_client sink requires 'address' option", - pipelineName, sinkIndex) - } - - // Validate address format - _, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("pipeline '%s' sink[%d]: invalid address format (expected host:port): %w", - pipelineName, sinkIndex, err) - } - - // Validate timeouts - if dialTimeout, ok := cfg.Options["dial_timeout_seconds"].(int64); ok { - if dialTimeout < 1 { - return fmt.Errorf("pipeline '%s' sink[%d]: dial_timeout_seconds must be positive: %d", - pipelineName, sinkIndex, dialTimeout) - } - } - - if writeTimeout, ok := cfg.Options["write_timeout_seconds"].(int64); ok { - if writeTimeout < 1 { - return fmt.Errorf("pipeline '%s' sink[%d]: write_timeout_seconds must be positive: %d", - pipelineName, sinkIndex, writeTimeout) - } - } - - case "file": - // Validate directory - directory, ok := cfg.Options["directory"].(string) - if !ok || directory == "" { - return fmt.Errorf("pipeline '%s' sink[%d]: file sink requires 'directory' option", - pipelineName, sinkIndex) - } - - // Validate filename - name, ok := cfg.Options["name"].(string) - if !ok || name == "" { - return fmt.Errorf("pipeline '%s' sink[%d]: file sink requires 'name' option", - pipelineName, sinkIndex) - } - - // Validate size options - if maxSize, ok := cfg.Options["max_size_mb"].(int64); ok { - if maxSize < 1 { - return fmt.Errorf("pipeline '%s' sink[%d]: max_size_mb must be positive: %d", - pipelineName, sinkIndex, maxSize) - } - } - - if maxTotalSize, ok := cfg.Options["max_total_size_mb"].(int64); ok { - if maxTotalSize < 0 { - return fmt.Errorf("pipeline '%s' sink[%d]: max_total_size_mb cannot be negative: %d", - pipelineName, sinkIndex, maxTotalSize) - } - } - - if minDiskFree, ok := cfg.Options["min_disk_free_mb"].(int64); ok { - if minDiskFree < 0 { - return fmt.Errorf("pipeline '%s' sink[%d]: min_disk_free_mb cannot be negative: %d", - pipelineName, sinkIndex, minDiskFree) - } - } - - // Validate retention period - if retention, ok := cfg.Options["retention_hours"].(float64); ok { - if retention < 0 { - return fmt.Errorf("pipeline '%s' sink[%d]: retention_hours cannot be negative: %f", - pipelineName, sinkIndex, retention) - } - } - - case "console": - // No specific validation needed for console sinks - - default: - return fmt.Errorf("pipeline '%s' sink[%d]: unknown sink type '%s'", - pipelineName, sinkIndex, cfg.Type) - } - - return nil -} \ No newline at end of file diff --git a/src/internal/config/saver.go b/src/internal/config/saver.go deleted file mode 100644 index 4587ec5..0000000 --- a/src/internal/config/saver.go +++ /dev/null @@ -1,33 +0,0 @@ -// FILE: logwisp/src/internal/config/saver.go -package config - -import ( - "fmt" - - lconfig "github.com/lixenwraith/config" -) - -// Saves the configuration to the specified file path. -func (c *Config) SaveToFile(path string) error { - if path == "" { - return fmt.Errorf("cannot save config: path is empty") - } - - // Create a temporary lconfig instance just for saving - // This avoids the need to track lconfig throughout the application - lcfg, err := lconfig.NewBuilder(). - WithFile(path). - WithTarget(c). - WithFileFormat("toml"). - Build() - if err != nil { - return fmt.Errorf("failed to create config builder: %w", err) - } - - // Use lconfig's Save method which handles atomic writes - if err := lcfg.Save(path); err != nil { - return fmt.Errorf("failed to save config: %w", err) - } - - return nil -} \ No newline at end of file diff --git a/src/internal/config/server.go b/src/internal/config/server.go deleted file mode 100644 index 07e935f..0000000 --- a/src/internal/config/server.go +++ /dev/null @@ -1,203 +0,0 @@ -// FILE: logwisp/src/internal/config/server.go -package config - -import ( - "fmt" - "net" - "strings" -) - -type TCPConfig struct { - Enabled bool `toml:"enabled"` - Port int64 `toml:"port"` - BufferSize int64 `toml:"buffer_size"` - - // Net limiting - NetLimit *NetLimitConfig `toml:"net_limit"` - - // Heartbeat - Heartbeat *HeartbeatConfig `toml:"heartbeat"` -} - -type HTTPConfig struct { - Enabled bool `toml:"enabled"` - Port int64 `toml:"port"` - BufferSize int64 `toml:"buffer_size"` - - // Endpoint paths - StreamPath string `toml:"stream_path"` - StatusPath string `toml:"status_path"` - - // TLS Configuration - TLS *TLSConfig `toml:"tls"` - - // Nate limiting - NetLimit *NetLimitConfig `toml:"net_limit"` - - // Heartbeat - Heartbeat *HeartbeatConfig `toml:"heartbeat"` -} - -type HeartbeatConfig struct { - Enabled bool `toml:"enabled"` - IntervalSeconds int64 `toml:"interval_seconds"` - IncludeTimestamp bool `toml:"include_timestamp"` - IncludeStats bool `toml:"include_stats"` - Format string `toml:"format"` -} - -type NetLimitConfig struct { - // Enable net limiting - Enabled bool `toml:"enabled"` - - // IP Access Control Lists - IPWhitelist []string `toml:"ip_whitelist"` - IPBlacklist []string `toml:"ip_blacklist"` - - // Requests per second per client - RequestsPerSecond float64 `toml:"requests_per_second"` - - // Burst size (token bucket) - BurstSize int64 `toml:"burst_size"` - - // Response when net limited - ResponseCode int64 `toml:"response_code"` // Default: 429 - ResponseMessage string `toml:"response_message"` // Default: "Net limit exceeded" - - // Connection limits - MaxConnectionsPerIP int64 `toml:"max_connections_per_ip"` - MaxConnectionsPerUser int64 `toml:"max_connections_per_user"` - MaxConnectionsPerToken int64 `toml:"max_connections_per_token"` - MaxConnectionsTotal int64 `toml:"max_connections_total"` -} - -func validateHeartbeatOptions(serverType, pipelineName string, sinkIndex int, hb map[string]any) error { - if enabled, ok := hb["enabled"].(bool); ok && enabled { - interval, ok := hb["interval_seconds"].(int64) - if !ok || interval < 1 { - return fmt.Errorf("pipeline '%s' sink[%d] %s: heartbeat interval must be positive", - pipelineName, sinkIndex, serverType) - } - - if format, ok := hb["format"].(string); ok { - if format != "json" && format != "comment" { - return fmt.Errorf("pipeline '%s' sink[%d] %s: heartbeat format must be 'json' or 'comment': %s", - pipelineName, sinkIndex, serverType, format) - } - } - } - return nil -} - -func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, nl map[string]any) error { - if enabled, ok := nl["enabled"].(bool); !ok || !enabled { - return nil - } - - // Validate IP lists if present - if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok { - for i, entry := range ipWhitelist { - entryStr, ok := entry.(string) - if !ok { - continue - } - if err := validateIPv4Entry(entryStr); err != nil { - return fmt.Errorf("pipeline '%s' sink[%d] %s: whitelist[%d] %v", - pipelineName, sinkIndex, serverType, i, err) - } - } - } - - if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok { - for i, entry := range ipBlacklist { - entryStr, ok := entry.(string) - if !ok { - continue - } - if err := validateIPv4Entry(entryStr); err != nil { - return fmt.Errorf("pipeline '%s' sink[%d] %s: blacklist[%d] %v", - pipelineName, sinkIndex, serverType, i, err) - } - } - } - - // Validate requests per second - rps, ok := nl["requests_per_second"].(float64) - if !ok || rps <= 0 { - return fmt.Errorf("pipeline '%s' sink[%d] %s: requests_per_second must be positive", - pipelineName, sinkIndex, serverType) - } - - // Validate burst size - burst, ok := nl["burst_size"].(int64) - if !ok || burst < 1 { - return fmt.Errorf("pipeline '%s' sink[%d] %s: burst_size must be at least 1", - pipelineName, sinkIndex, serverType) - } - - // Validate response code - if respCode, ok := nl["response_code"].(int64); ok { - if respCode > 0 && (respCode < 400 || respCode >= 600) { - return fmt.Errorf("pipeline '%s' sink[%d] %s: response_code must be 4xx or 5xx: %d", - pipelineName, sinkIndex, serverType, respCode) - } - } - - // Validate connection limits - maxPerIP, perIPOk := nl["max_connections_per_ip"].(int64) - maxPerUser, perUserOk := nl["max_connections_per_user"].(int64) - maxPerToken, perTokenOk := nl["max_connections_per_token"].(int64) - maxTotal, totalOk := nl["max_connections_total"].(int64) - - if perIPOk && perUserOk && perTokenOk && totalOk && - maxPerIP > 0 && maxPerUser > 0 && maxPerToken > 0 && maxTotal > 0 { - if maxPerIP > maxTotal { - return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_ip (%d) cannot exceed max_connections_total (%d)", - pipelineName, sinkIndex, serverType, maxPerIP, maxTotal) - } - if maxPerUser > maxTotal { - return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_user (%d) cannot exceed max_connections_total (%d)", - pipelineName, sinkIndex, serverType, maxPerUser, maxTotal) - } - if maxPerToken > maxTotal { - return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_token (%d) cannot exceed max_connections_total (%d)", - pipelineName, sinkIndex, serverType, maxPerToken, maxTotal) - } - } - - return nil -} - -// Ensures an IP or CIDR is IPv4 -func validateIPv4Entry(entry string) error { - // Handle single IP - if !strings.Contains(entry, "/") { - ip := net.ParseIP(entry) - if ip == nil { - return fmt.Errorf("invalid IP address: %s", entry) - } - if ip.To4() == nil { - return fmt.Errorf("IPv6 not supported (IPv4-only): %s", entry) - } - return nil - } - - // Handle CIDR - ipAddr, ipNet, err := net.ParseCIDR(entry) - if err != nil { - return fmt.Errorf("invalid CIDR: %s", entry) - } - - // Check if the IP is IPv4 - if ipAddr.To4() == nil { - return fmt.Errorf("IPv6 CIDR not supported (IPv4-only): %s", entry) - } - - // Verify the network mask is appropriate for IPv4 - _, bits := ipNet.Mask.Size() - if bits != 32 { - return fmt.Errorf("invalid IPv4 CIDR mask (got %d bits, expected 32): %s", bits, entry) - } - - return nil -} \ No newline at end of file diff --git a/src/internal/config/tls.go b/src/internal/config/tls.go deleted file mode 100644 index 9252b45..0000000 --- a/src/internal/config/tls.go +++ /dev/null @@ -1,82 +0,0 @@ -// FILE: logwisp/src/internal/config/tls.go -package config - -import ( - "fmt" - "os" -) - -type TLSConfig struct { - Enabled bool `toml:"enabled"` - CertFile string `toml:"cert_file"` - KeyFile string `toml:"key_file"` - - // Client certificate authentication - ClientAuth bool `toml:"client_auth"` - ClientCAFile string `toml:"client_ca_file"` - VerifyClientCert bool `toml:"verify_client_cert"` - - // Option to skip verification for clients - InsecureSkipVerify bool `toml:"insecure_skip_verify"` - - // CA file for client to trust specific server certificates - CAFile string `toml:"ca_file"` - - // TLS version constraints - MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3" - MaxVersion string `toml:"max_version"` - - // Cipher suites (comma-separated list) - CipherSuites string `toml:"cipher_suites"` -} - -func validateTLSOptions(serverType, pipelineName string, sinkIndex int, tls map[string]any) error { - if enabled, ok := tls["enabled"].(bool); ok && enabled { - certFile, certOk := tls["cert_file"].(string) - keyFile, keyOk := tls["key_file"].(string) - - if !certOk || certFile == "" || !keyOk || keyFile == "" { - return fmt.Errorf("pipeline '%s' sink[%d] %s: TLS enabled but cert/key files not specified", - pipelineName, sinkIndex, serverType) - } - - // Validate that certificate files exist and are readable - if _, err := os.Stat(certFile); err != nil { - return fmt.Errorf("pipeline '%s' sink[%d] %s: cert_file is not accessible: %w", - pipelineName, sinkIndex, serverType, err) - } - if _, err := os.Stat(keyFile); err != nil { - return fmt.Errorf("pipeline '%s' sink[%d] %s: key_file is not accessible: %w", - pipelineName, sinkIndex, serverType, err) - } - - if clientAuth, ok := tls["client_auth"].(bool); ok && clientAuth { - caFile, caOk := tls["client_ca_file"].(string) - if !caOk || caFile == "" { - return fmt.Errorf("pipeline '%s' sink[%d] %s: client auth enabled but CA file not specified", - pipelineName, sinkIndex, serverType) - } - // Validate that the client CA file exists and is readable - if _, err := os.Stat(caFile); err != nil { - return fmt.Errorf("pipeline '%s' sink[%d] %s: client_ca_file is not accessible: %w", - pipelineName, sinkIndex, serverType, err) - } - } - - // Validate TLS versions - validVersions := map[string]bool{"TLS1.0": true, "TLS1.1": true, "TLS1.2": true, "TLS1.3": true} - if minVer, ok := tls["min_version"].(string); ok && minVer != "" { - if !validVersions[minVer] { - return fmt.Errorf("pipeline '%s' sink[%d] %s: invalid min TLS version: %s", - pipelineName, sinkIndex, serverType, minVer) - } - } - if maxVer, ok := tls["max_version"].(string); ok && maxVer != "" { - if !validVersions[maxVer] { - return fmt.Errorf("pipeline '%s' sink[%d] %s: invalid max TLS version: %s", - pipelineName, sinkIndex, serverType, maxVer) - } - } - } - return nil -} \ No newline at end of file diff --git a/src/internal/config/validation.go b/src/internal/config/validation.go index b343683..a7fd98e 100644 --- a/src/internal/config/validation.go +++ b/src/internal/config/validation.go @@ -1,24 +1,28 @@ -// FILE: logwisp/src/internal/config/validation.go package config import ( "fmt" + "net/url" + "path/filepath" + "regexp" + "strings" + "time" + + lconfig "github.com/lixenwraith/config" ) -func (c *Config) validate() error { - if c == nil { +// validateConfig is the centralized validator for the entire configuration +// This replaces the old (c *Config) validate() method +func validateConfig(cfg *Config) error { + if cfg == nil { return fmt.Errorf("config is nil") } - if c.Logging == nil { - c.Logging = DefaultLogConfig() - } - - if len(c.Pipelines) == 0 { + if len(cfg.Pipelines) == 0 { return fmt.Errorf("no pipelines configured") } - if err := validateLogConfig(c.Logging); err != nil { + if err := validateLogConfig(cfg.Logging); err != nil { return fmt.Errorf("logging config: %w", err) } @@ -26,57 +30,924 @@ func (c *Config) validate() error { allPorts := make(map[int64]string) pipelineNames := make(map[string]bool) - for i, pipeline := range c.Pipelines { - if pipeline.Name == "" { - return fmt.Errorf("pipeline %d: missing name", i) - } - - if pipelineNames[pipeline.Name] { - return fmt.Errorf("pipeline %d: duplicate name '%s'", i, pipeline.Name) - } - pipelineNames[pipeline.Name] = true - - // Pipeline must have at least one source - if len(pipeline.Sources) == 0 { - return fmt.Errorf("pipeline '%s': no sources specified", pipeline.Name) - } - - // Validate sources - for j, source := range pipeline.Sources { - if err := validateSource(pipeline.Name, j, &source); err != nil { - return err - } - } - - // Validate rate limit if present - if err := validateRateLimit(pipeline.Name, pipeline.RateLimit); err != nil { - return err - } - - // Validate filters - for j, filterCfg := range pipeline.Filters { - if err := validateFilter(pipeline.Name, j, &filterCfg); err != nil { - return err - } - } - - // Pipeline must have at least one sink - if len(pipeline.Sinks) == 0 { - return fmt.Errorf("pipeline '%s': no sinks specified", pipeline.Name) - } - - // Validate sinks and check for port conflicts - for j, sink := range pipeline.Sinks { - if err := validateSink(pipeline.Name, j, &sink, allPorts); err != nil { - return err - } - } - - // Validate auth if present - if err := validateAuth(pipeline.Name, pipeline.Auth); err != nil { + for i, pipeline := range cfg.Pipelines { + if err := validatePipeline(i, &pipeline, pipelineNames, allPorts); err != nil { return err } } + return nil +} + +func validateLogConfig(cfg *LogConfig) error { + validOutputs := map[string]bool{ + "file": true, "stdout": true, "stderr": true, + "split": true, "all": true, "none": true, + } + if !validOutputs[cfg.Output] { + return fmt.Errorf("invalid log output mode: %s", cfg.Output) + } + + validLevels := map[string]bool{ + "debug": true, "info": true, "warn": true, "error": true, + } + if !validLevels[cfg.Level] { + return fmt.Errorf("invalid log level: %s", cfg.Level) + } + + if cfg.Console != nil { + validTargets := map[string]bool{ + "stdout": true, "stderr": true, "split": true, + } + if !validTargets[cfg.Console.Target] { + return fmt.Errorf("invalid console target: %s", cfg.Console.Target) + } + + validFormats := map[string]bool{ + "txt": true, "json": true, "": true, + } + if !validFormats[cfg.Console.Format] { + return fmt.Errorf("invalid console format: %s", cfg.Console.Format) + } + } + + return nil +} + +func validatePipeline(index int, p *PipelineConfig, pipelineNames map[string]bool, allPorts map[int64]string) error { + // Validate pipeline name + if err := lconfig.NonEmpty(p.Name); err != nil { + return fmt.Errorf("pipeline %d: missing name", index) + } + + if pipelineNames[p.Name] { + return fmt.Errorf("pipeline %d: duplicate name '%s'", index, p.Name) + } + pipelineNames[p.Name] = true + + // Must have at least one source + if len(p.Sources) == 0 { + return fmt.Errorf("pipeline '%s': no sources specified", p.Name) + } + + // Validate each source + for j, source := range p.Sources { + if err := validateSourceConfig(p.Name, j, &source); err != nil { + return err + } + } + + // Validate rate limit if present + if p.RateLimit != nil { + if err := validateRateLimit(p.Name, p.RateLimit); err != nil { + return err + } + } + + // Validate filters + for j, filter := range p.Filters { + if err := validateFilter(p.Name, j, &filter); err != nil { + return err + } + } + + // Validate formatter configuration + if err := validateFormatterConfig(p); err != nil { + return fmt.Errorf("pipeline '%s': %w", p.Name, err) + } + + // Must have at least one sink + if len(p.Sinks) == 0 { + return fmt.Errorf("pipeline '%s': no sinks specified", p.Name) + } + + // Validate each sink + for j, sink := range p.Sinks { + if err := validateSinkConfig(p.Name, j, &sink, allPorts); err != nil { + return err + } + } + + return nil +} + +// validateSourceConfig validates typed source configuration +func validateSourceConfig(pipelineName string, index int, s *SourceConfig) error { + if err := lconfig.NonEmpty(s.Type); err != nil { + return fmt.Errorf("pipeline '%s' source[%d]: missing type", pipelineName, index) + } + + // Count how many source configs are populated + populated := 0 + var populatedType string + + if s.Directory != nil { + populated++ + populatedType = "directory" + } + if s.Stdin != nil { + populated++ + populatedType = "stdin" + } + if s.HTTP != nil { + populated++ + populatedType = "http" + } + if s.TCP != nil { + populated++ + populatedType = "tcp" + } + + if populated == 0 { + return fmt.Errorf("pipeline '%s' source[%d]: no configuration provided for type '%s'", + pipelineName, index, s.Type) + } + if populated > 1 { + return fmt.Errorf("pipeline '%s' source[%d]: multiple configurations provided, only one allowed", + pipelineName, index) + } + if populatedType != s.Type { + return fmt.Errorf("pipeline '%s' source[%d]: type mismatch - type is '%s' but config is for '%s'", + pipelineName, index, s.Type, populatedType) + } + + // Validate specific source type + switch s.Type { + case "directory": + return validateDirectorySource(pipelineName, index, s.Directory) + case "stdin": + return validateStdinSource(pipelineName, index, s.Stdin) + case "http": + return validateHTTPSource(pipelineName, index, s.HTTP) + case "tcp": + return validateTCPSource(pipelineName, index, s.TCP) + default: + return fmt.Errorf("pipeline '%s' source[%d]: unknown type '%s'", pipelineName, index, s.Type) + } +} + +func validateDirectorySource(pipelineName string, index int, opts *DirectorySourceOptions) error { + if err := lconfig.NonEmpty(opts.Path); err != nil { + return fmt.Errorf("pipeline '%s' source[%d]: directory requires 'path'", pipelineName, index) + } else { + absPath, err := filepath.Abs(opts.Path) + if err != nil { + return fmt.Errorf("invalid path %s: %w", opts.Path, err) + } + opts.Path = absPath + } + + // Check for directory traversal + // TODO: traversal check only if optional security settings from cli/env set + if strings.Contains(opts.Path, "..") { + return fmt.Errorf("pipeline '%s' source[%d]: path contains directory traversal", pipelineName, index) + } + + // Validate pattern if provided + if opts.Pattern != "" { + if strings.Count(opts.Pattern, "*") == 0 && strings.Count(opts.Pattern, "?") == 0 { + // If no wildcards, ensure valid filename + if filepath.Base(opts.Pattern) != opts.Pattern { + return fmt.Errorf("pipeline '%s' source[%d]: pattern contains path separators", pipelineName, index) + } + } + } else { + opts.Pattern = "*" + } + + // Validate check interval + if opts.CheckIntervalMS < 10 { + return fmt.Errorf("pipeline '%s' source[%d]: check_interval_ms must be at least 10ms", pipelineName, index) + } + + return nil +} + +func validateStdinSource(pipelineName string, index int, opts *StdinSourceOptions) error { + if opts.BufferSize < 0 { + return fmt.Errorf("pipeline '%s' source[%d]: buffer_size must be positive", pipelineName, index) + } else if opts.BufferSize == 0 { + opts.BufferSize = 1000 + } + return nil +} + +func validateHTTPSource(pipelineName string, index int, opts *HTTPSourceOptions) error { + // Validate port + if err := lconfig.Port(opts.Port); err != nil { + return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) + } + + // Set defaults + if opts.Host == "" { + opts.Host = "0.0.0.0" + } + if opts.IngestPath == "" { + opts.IngestPath = "/ingest" + } + if opts.MaxRequestBodySize <= 0 { + opts.MaxRequestBodySize = 10 * 1024 * 1024 // 10MB default + } + if opts.ReadTimeout <= 0 { + opts.ReadTimeout = 5000 // 5 seconds + } + if opts.WriteTimeout <= 0 { + opts.WriteTimeout = 5000 // 5 seconds + } + + // Validate host if specified + if opts.Host != "" && opts.Host != "0.0.0.0" { + if err := lconfig.IPAddress(opts.Host); err != nil { + return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) + } + } + + // Validate paths + if !strings.HasPrefix(opts.IngestPath, "/") { + return fmt.Errorf("pipeline '%s' source[%d]: ingest_path must start with /", pipelineName, index) + } + + // Validate auth configuration + validHTTPSourceAuthTypes := map[string]bool{"basic": true, "token": true, "mtls": true} + if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" { + if !validHTTPSourceAuthTypes[opts.Auth.Type] { + return fmt.Errorf("pipeline '%s' source[%d]: %s is not a valid auth type", + pipelineName, index, opts.Auth.Type) + } + // All non-none auth types require TLS for HTTP + if opts.TLS == nil || !opts.TLS.Enabled { + return fmt.Errorf("pipeline '%s' source[%d]: %s auth requires TLS to be enabled", + pipelineName, index, opts.Auth.Type) + } + + // Validate specific auth types + if err := validateServerAuth(pipelineName, opts.Auth); err != nil { + return fmt.Errorf("source[%d]: %w", index, err) + } + } + + // Validate nested configs + if opts.NetLimit != nil { + if err := validateNetLimit(pipelineName, fmt.Sprintf("source[%d]", index), opts.NetLimit); err != nil { + return err + } + } + + if opts.TLS != nil { + if err := validateTLS(pipelineName, fmt.Sprintf("source[%d]", index), opts.TLS); err != nil { + return err + } + } + + return nil +} + +func validateTCPSource(pipelineName string, index int, opts *TCPSourceOptions) error { + // Validate port + if err := lconfig.Port(opts.Port); err != nil { + return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) + } + + // Set defaults + if opts.Host == "" { + opts.Host = "0.0.0.0" + } + if opts.ReadTimeout <= 0 { + opts.ReadTimeout = 5000 // 5 seconds + } + if !opts.KeepAlive { + opts.KeepAlive = true // Default enabled + } + if opts.KeepAlivePeriod <= 0 { + opts.KeepAlivePeriod = 30000 // 30 seconds + } + + // Validate host if specified + if opts.Host != "" && opts.Host != "0.0.0.0" { + if err := lconfig.IPAddress(opts.Host); err != nil { + return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) + } + } + + // TCP source does NOT support TLS + // Validate auth configuration - only none and scram are allowed + if opts.Auth != nil { + switch opts.Auth.Type { + case "", "none": + // OK + case "scram": + // SCRAM doesn't require TLS + if err := validateServerAuth(pipelineName, opts.Auth); err != nil { + return fmt.Errorf("source[%d]: %w", index, err) + } + default: + return fmt.Errorf("pipeline '%s' source[%d]: TCP source only supports 'none' or 'scram' auth (got '%s')", + pipelineName, index, opts.Auth.Type) + } + } + + // Validate NetLimit if present + if opts.NetLimit != nil { + if err := validateNetLimit(pipelineName, fmt.Sprintf("source[%d]", index), opts.NetLimit); err != nil { + return err + } + } + + return nil +} + +// validateSinkConfig validates typed sink configuration +func validateSinkConfig(pipelineName string, index int, s *SinkConfig, allPorts map[int64]string) error { + if err := lconfig.NonEmpty(s.Type); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d]: missing type", pipelineName, index) + } + + // Count populated sink configs + populated := 0 + var populatedType string + + if s.Console != nil { + populated++ + populatedType = "console" + } + if s.File != nil { + populated++ + populatedType = "file" + } + if s.HTTP != nil { + populated++ + populatedType = "http" + } + if s.TCP != nil { + populated++ + populatedType = "tcp" + } + if s.HTTPClient != nil { + populated++ + populatedType = "http_client" + } + if s.TCPClient != nil { + populated++ + populatedType = "tcp_client" + } + + if populated == 0 { + return fmt.Errorf("pipeline '%s' sink[%d]: no configuration provided for type '%s'", + pipelineName, index, s.Type) + } + if populated > 1 { + return fmt.Errorf("pipeline '%s' sink[%d]: multiple configurations provided, only one allowed", + pipelineName, index) + } + if populatedType != s.Type { + return fmt.Errorf("pipeline '%s' sink[%d]: type mismatch - type is '%s' but config is for '%s'", + pipelineName, index, s.Type, populatedType) + } + + // Validate specific sink type + switch s.Type { + case "console": + return validateConsoleSink(pipelineName, index, s.Console) + case "file": + return validateFileSink(pipelineName, index, s.File) + case "http": + return validateHTTPSink(pipelineName, index, s.HTTP, allPorts) + case "tcp": + return validateTCPSink(pipelineName, index, s.TCP, allPorts) + case "http_client": + return validateHTTPClientSink(pipelineName, index, s.HTTPClient) + case "tcp_client": + return validateTCPClientSink(pipelineName, index, s.TCPClient) + default: + return fmt.Errorf("pipeline '%s' sink[%d]: unknown type '%s'", pipelineName, index, s.Type) + } +} + +func validateConsoleSink(pipelineName string, index int, opts *ConsoleSinkOptions) error { + if opts.BufferSize < 1 { + return fmt.Errorf("pipeline '%s' sink[%d]: buffer_size must be positive", pipelineName, index) + } + return nil +} + +func validateFileSink(pipelineName string, index int, opts *FileSinkOptions) error { + if err := lconfig.NonEmpty(opts.Directory); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d]: file requires 'directory'", pipelineName, index) + } + + if err := lconfig.NonEmpty(opts.Name); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d]: file requires 'name'", pipelineName, index) + } + + if opts.BufferSize <= 0 { + return fmt.Errorf("pipeline '%s' sink[%d]: max_size_mb must be positive", pipelineName, index) + } + + // Validate sizes + if opts.MaxSizeMB < 0 { + return fmt.Errorf("pipeline '%s' sink[%d]: max_size_mb must be positive", pipelineName, index) + } + + if opts.MaxTotalSizeMB <= 0 { + return fmt.Errorf("pipeline '%s' sink[%d]: max_total_size_mb cannot be negative", pipelineName, index) + } + + if opts.MinDiskFreeMB < 0 { + return fmt.Errorf("pipeline '%s' sink[%d]: min_disk_free_mb must be positive", pipelineName, index) + } + + if opts.RetentionHours <= 0 { + return fmt.Errorf("pipeline '%s' sink[%d]: retention_hours cannot be negative", pipelineName, index) + } + + return nil +} + +func validateHTTPSink(pipelineName string, index int, opts *HTTPSinkOptions, allPorts map[int64]string) error { + // Validate port + if err := lconfig.Port(opts.Port); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d]: %w", pipelineName, index, err) + } + + // Check port conflicts + if existing, exists := allPorts[opts.Port]; exists { + return fmt.Errorf("pipeline '%s' sink[%d]: port %d already used by %s", + pipelineName, index, opts.Port, existing) + } + allPorts[opts.Port] = fmt.Sprintf("%s-http[%d]", pipelineName, index) + + // Validate host if specified + if opts.Host != "" { + if err := lconfig.IPAddress(opts.Host); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d]: %w", pipelineName, index, err) + } + } + + // Validate paths + if !strings.HasPrefix(opts.StreamPath, "/") { + return fmt.Errorf("pipeline '%s' sink[%d]: stream_path must start with /", pipelineName, index) + } + + if !strings.HasPrefix(opts.StatusPath, "/") { + return fmt.Errorf("pipeline '%s' sink[%d]: status_path must start with /", pipelineName, index) + } + + // Validate buffer + if opts.BufferSize < 1 { + return fmt.Errorf("pipeline '%s' sink[%d]: buffer_size must be positive", pipelineName, index) + } + + // Validate nested configs + if opts.Heartbeat != nil { + if err := validateHeartbeat(pipelineName, fmt.Sprintf("sink[%d]", index), opts.Heartbeat); err != nil { + return err + } + } + + if opts.NetLimit != nil { + if err := validateNetLimit(pipelineName, fmt.Sprintf("sink[%d]", index), opts.NetLimit); err != nil { + return err + } + } + + if opts.TLS != nil { + if err := validateTLS(pipelineName, fmt.Sprintf("sink[%d]", index), opts.TLS); err != nil { + return err + } + } + + return nil +} + +func validateTCPSink(pipelineName string, index int, opts *TCPSinkOptions, allPorts map[int64]string) error { + // Validate port + if err := lconfig.Port(opts.Port); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d]: %w", pipelineName, index, err) + } + + // Check port conflicts + if existing, exists := allPorts[opts.Port]; exists { + return fmt.Errorf("pipeline '%s' sink[%d]: port %d already used by %s", + pipelineName, index, opts.Port, existing) + } + allPorts[opts.Port] = fmt.Sprintf("%s-tcp[%d]", pipelineName, index) + + // Validate host if specified + if opts.Host != "" { + if err := lconfig.IPAddress(opts.Host); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d]: %w", pipelineName, index, err) + } + } + + // Validate buffer + if opts.BufferSize < 1 { + return fmt.Errorf("pipeline '%s' sink[%d]: buffer_size must be positive", pipelineName, index) + } + + // Validate nested configs + if opts.Heartbeat != nil { + if err := validateHeartbeat(pipelineName, fmt.Sprintf("sink[%d]", index), opts.Heartbeat); err != nil { + return err + } + } + + if opts.NetLimit != nil { + if err := validateNetLimit(pipelineName, fmt.Sprintf("sink[%d]", index), opts.NetLimit); err != nil { + return err + } + } + + return nil +} + +func validateHTTPClientSink(pipelineName string, index int, opts *HTTPClientSinkOptions) error { + // Validate URL + if err := lconfig.NonEmpty(opts.URL); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d]: http_client requires 'url'", pipelineName, index) + } + + parsedURL, err := url.Parse(opts.URL) + if err != nil { + return fmt.Errorf("pipeline '%s' sink[%d]: invalid URL: %w", pipelineName, index, err) + } + + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return fmt.Errorf("pipeline '%s' sink[%d]: URL must use http or https scheme", pipelineName, index) + } + + isHTTPS := parsedURL.Scheme == "https" + + // Set defaults for unspecified fields + if opts.BufferSize <= 0 { + opts.BufferSize = 1000 + } + if opts.BatchSize <= 0 { + opts.BatchSize = 100 + } + if opts.BatchDelayMS <= 0 { + opts.BatchDelayMS = 1000 // 1 second in ms + } + if opts.Timeout <= 0 { + opts.Timeout = 30 // 30 seconds + } + if opts.MaxRetries < 0 { + opts.MaxRetries = 3 + } + if opts.RetryDelayMS <= 0 { + opts.RetryDelayMS = 1000 // 1 second in ms + } + if opts.RetryBackoff < 1.0 { + opts.RetryBackoff = 2.0 + } + if opts.Headers == nil { + opts.Headers = make(map[string]string) + } + + // Set default Content-Type if not specified + if _, exists := opts.Headers["Content-Type"]; !exists { + opts.Headers["Content-Type"] = "application/json" + } + + // Validate auth configuration + if opts.Auth != nil { + switch opts.Auth.Type { + case "basic": + if opts.Auth.Username == "" || opts.Auth.Password == "" { + return fmt.Errorf("pipeline '%s' sink[%d]: username and password required for basic auth", + pipelineName, index) + } + if !isHTTPS && !opts.InsecureSkipVerify { + return fmt.Errorf("pipeline '%s' sink[%d]: basic auth requires HTTPS (security: credentials would be sent in plaintext)", + pipelineName, index) + } + + case "token": + if opts.Auth.Token == "" { + return fmt.Errorf("pipeline '%s' sink[%d]: token required for %s auth", + pipelineName, index, opts.Auth.Type) + } + if !isHTTPS && !opts.InsecureSkipVerify { + return fmt.Errorf("pipeline '%s' sink[%d]: %s auth requires HTTPS (security: token would be sent in plaintext)", + pipelineName, index, opts.Auth.Type) + } + + case "mtls": + if !isHTTPS { + return fmt.Errorf("pipeline '%s' sink[%d]: mTLS requires HTTPS", + pipelineName, index) + } + // mTLS certs should be in TLS config, not auth config + if opts.TLS == nil || opts.TLS.CertFile == "" || opts.TLS.KeyFile == "" { + return fmt.Errorf("pipeline '%s' sink[%d]: cert_file and key_file required in TLS config for mTLS auth", + pipelineName, index) + } + + case "none", "": + // Clear any credentials if auth is "none" or empty + if opts.Auth != nil { + opts.Auth.Username = "" + opts.Auth.Password = "" + opts.Auth.Token = "" + } + + default: + return fmt.Errorf("pipeline '%s' sink[%d]: invalid auth type '%s' (valid: none, basic, token, mtls)", + pipelineName, index, opts.Auth.Type) + } + } + + // Validate TLS config if present + if opts.TLS != nil { + if err := validateTLS(pipelineName, fmt.Sprintf("sink[%d]", index), opts.TLS); err != nil { + return err + } + } + + return nil +} + +func validateTCPClientSink(pipelineName string, index int, opts *TCPClientSinkOptions) error { + // Validate host and port + if err := lconfig.NonEmpty(opts.Host); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d]: tcp_client requires 'host'", pipelineName, index) + } + + if err := lconfig.Port(opts.Port); err != nil { + return fmt.Errorf("pipeline '%s' sink[%d]: %w", pipelineName, index, err) + } + + // Set defaults + if opts.BufferSize <= 0 { + opts.BufferSize = 1000 + } + if opts.DialTimeout <= 0 { + opts.DialTimeout = 10 // 10 seconds + } + if opts.WriteTimeout <= 0 { + opts.WriteTimeout = 30 // 30 seconds + } + if opts.ReadTimeout <= 0 { + opts.ReadTimeout = 10 // 10 seconds + } + if opts.KeepAlive <= 0 { + opts.KeepAlive = 30 // 30 seconds + } + if opts.ReconnectDelayMS <= 0 { + opts.ReconnectDelayMS = 1000 // 1 second in ms + } + if opts.MaxReconnectDelayMS <= 0 { + opts.MaxReconnectDelayMS = 30000 // 30 seconds in ms + } + if opts.ReconnectBackoff < 1.0 { + opts.ReconnectBackoff = 1.5 + } + + // Validate auth configuration + if opts.Auth != nil { + switch opts.Auth.Type { + + case "scram": + if opts.Auth.Username == "" || opts.Auth.Password == "" { + return fmt.Errorf("pipeline '%s' sink[%d]: username and password required for SCRAM auth", + pipelineName, index) + } + // SCRAM doesn't require TLS as it uses challenge-response + + case "none", "": + // Clear credentials + if opts.Auth != nil { + opts.Auth.Username = "" + opts.Auth.Password = "" + opts.Auth.Token = "" + } + + default: + return fmt.Errorf("pipeline '%s' sink[%d]: invalid auth type '%s' (valid: none, basic, token, scram, mtls)", + pipelineName, index, opts.Auth.Type) + } + } + + return nil +} + +// validateFormatterConfig validates formatter configuration +func validateFormatterConfig(p *PipelineConfig) error { + if p.Format == nil { + p.Format = &FormatConfig{ + Type: "raw", + } + } else if p.Format.Type == "" { + p.Format.Type = "raw" // Default + } + + switch p.Format.Type { + + case "raw": + if p.Format.RawFormatOptions == nil { + p.Format.RawFormatOptions = &RawFormatterOptions{} + } + + case "txt": + if p.Format.TextFormatOptions == nil { + p.Format.TextFormatOptions = &TextFormatterOptions{} + } + + // Default template format + templateStr := "[{{.Timestamp | FmtTime}}] [{{.Level | ToUpper}}] {{.Source}} - {{.Message}}{{ if .Fields }} {{.Fields}}{{ end }}" + if p.Format.TextFormatOptions.Template != "" { + p.Format.TextFormatOptions.Template = templateStr + } + + // Default timestamp format + timestampFormat := time.RFC3339 + if p.Format.TextFormatOptions.TimestampFormat != "" { + p.Format.TextFormatOptions.TimestampFormat = timestampFormat + } + + case "json": + if p.Format.JSONFormatOptions == nil { + p.Format.JSONFormatOptions = &JSONFormatterOptions{} + } + } + + return nil +} + +// Helper validation functions for nested configs +func validateNetLimit(pipelineName, location string, nl *NetLimitConfig) error { + if !nl.Enabled { + return nil // Skip validation if disabled + } + + if nl.MaxConnections < 0 { + return fmt.Errorf("pipeline '%s' %s: max_connections cannot be negative", pipelineName, location) + } + + if nl.BurstSize < 0 { + return fmt.Errorf("pipeline '%s' %s: burst_size cannot be negative", pipelineName, location) + } + + return nil +} + +func validateTLS(pipelineName, location string, tls *TLSConfig) error { + if !tls.Enabled { + return nil // Skip validation if disabled + } + + // If TLS enabled, cert and key files required (unless skip verify) + if !tls.SkipVerify { + if tls.CertFile == "" || tls.KeyFile == "" { + return fmt.Errorf("pipeline '%s' %s: TLS enabled requires cert_file and key_file", pipelineName, location) + } + } + + return nil +} + +func validateHeartbeat(pipelineName, location string, hb *HeartbeatConfig) error { + if !hb.Enabled { + return nil // Skip validation if disabled + } + + if hb.Interval < 1000 { // At least 1 second + return fmt.Errorf("pipeline '%s' %s: heartbeat interval must be at least 1000ms", pipelineName, location) + } + + return nil +} + +func validateServerAuth(pipelineName string, auth *ServerAuthConfig) error { + if auth.Type == "" || auth.Type == "none" { + return nil + } + + // Count populated auth configs + populated := 0 + var populatedType string + + if auth.Basic != nil { + populated++ + populatedType = "basic" + } + if auth.Token != nil { + populated++ + populatedType = "token" + } + if auth.Scram != nil { + populated++ + populatedType = "scram" + } + + if populated == 0 { + return fmt.Errorf("pipeline '%s': auth type '%s' specified but config missing", pipelineName, auth.Type) + } + if populated > 1 { + return fmt.Errorf("pipeline '%s': multiple auth configurations provided", pipelineName) + } + if populatedType != auth.Type { + return fmt.Errorf("pipeline '%s': auth type mismatch - type is '%s' but config is for '%s'", + pipelineName, auth.Type, populatedType) + } + + // Validate specific auth type + switch auth.Type { + case "basic": + if len(auth.Basic.Users) == 0 { + return fmt.Errorf("pipeline '%s': basic auth requires at least one user", pipelineName) + } + for i, user := range auth.Basic.Users { + if err := lconfig.NonEmpty(user.Username); err != nil { + return fmt.Errorf("pipeline '%s': basic auth user[%d] missing username", pipelineName, i) + } + if err := lconfig.NonEmpty(user.PasswordHash); err != nil { + return fmt.Errorf("pipeline '%s': basic auth user[%d] missing password_hash", pipelineName, i) + } + } + case "token": + if len(auth.Token.Tokens) == 0 { + return fmt.Errorf("pipeline '%s': token auth requires at least one token", pipelineName) + } + case "scram": + if len(auth.Scram.Users) == 0 { + return fmt.Errorf("pipeline '%s': scram auth requires at least one user", pipelineName) + } + for i, user := range auth.Scram.Users { + if err := lconfig.NonEmpty(user.Username); err != nil { + return fmt.Errorf("pipeline '%s': scram auth user[%d] missing username", pipelineName, i) + } + // Validate required SCRAM fields + if user.StoredKey == "" || user.ServerKey == "" || user.Salt == "" { + return fmt.Errorf("pipeline '%s': scram auth user[%d] missing required fields", pipelineName, i) + } + } + default: + return fmt.Errorf("pipeline '%s': unknown auth type '%s'", pipelineName, auth.Type) + } + + return nil +} + +func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error { + if cfg == nil { + return nil + } + + if cfg.Rate < 0 { + return fmt.Errorf("pipeline '%s': rate limit rate cannot be negative", pipelineName) + } + + if cfg.Burst < 0 { + return fmt.Errorf("pipeline '%s': rate limit burst cannot be negative", pipelineName) + } + + if cfg.MaxEntrySizeBytes < 0 { + return fmt.Errorf("pipeline '%s': max entry size bytes cannot be negative", pipelineName) + } + + // Validate policy + switch strings.ToLower(cfg.Policy) { + case "", "pass", "drop": + // Valid policies + default: + return fmt.Errorf("pipeline '%s': invalid rate limit policy '%s' (must be 'pass' or 'drop')", + pipelineName, cfg.Policy) + } + + return nil +} + +func validateFilter(pipelineName string, filterIndex int, cfg *FilterConfig) error { + // Validate filter type + switch cfg.Type { + case FilterTypeInclude, FilterTypeExclude, "": + // Valid types + default: + return fmt.Errorf("pipeline '%s' filter[%d]: invalid type '%s' (must be 'include' or 'exclude')", + pipelineName, filterIndex, cfg.Type) + } + + // Validate filter logic + switch cfg.Logic { + case FilterLogicOr, FilterLogicAnd, "": + // Valid logic + default: + return fmt.Errorf("pipeline '%s' filter[%d]: invalid logic '%s' (must be 'or' or 'and')", + pipelineName, filterIndex, cfg.Logic) + } + + // Empty patterns is valid - passes everything + if len(cfg.Patterns) == 0 { + return nil + } + + // Validate regex patterns + for i, pattern := range cfg.Patterns { + if _, err := regexp.Compile(pattern); err != nil { + return fmt.Errorf("pipeline '%s' filter[%d] pattern[%d] '%s': invalid regex: %w", + pipelineName, filterIndex, i, pattern, err) + } + } + return nil } \ No newline at end of file diff --git a/src/internal/core/const.go b/src/internal/core/const.go new file mode 100644 index 0000000..e8b4a64 --- /dev/null +++ b/src/internal/core/const.go @@ -0,0 +1,13 @@ +// FILE: logwisp/src/internal/core/const.go +package core + +// Argon2id parameters +const ( + Argon2Time = 3 + Argon2Memory = 64 * 1024 // 64 MB + Argon2Threads = 4 + Argon2SaltLen = 16 + Argon2KeyLen = 32 +) + +const DefaultTokenLength = 32 \ No newline at end of file diff --git a/src/internal/filter/chain_test.go b/src/internal/filter/chain_test.go deleted file mode 100644 index 91a67cd..0000000 --- a/src/internal/filter/chain_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// FILE: logwisp/src/internal/filter/chain_test.go -package filter - -import ( - "testing" - - "logwisp/src/internal/config" - "logwisp/src/internal/core" - - "github.com/stretchr/testify/assert" -) - -func TestNewChain(t *testing.T) { - logger := newTestLogger() - - t.Run("Success", func(t *testing.T) { - configs := []config.FilterConfig{ - {Type: config.FilterTypeInclude, Patterns: []string{"apple"}}, - {Type: config.FilterTypeExclude, Patterns: []string{"banana"}}, - } - chain, err := NewChain(configs, logger) - assert.NoError(t, err) - assert.NotNil(t, chain) - assert.Len(t, chain.filters, 2) - }) - - t.Run("ErrorInvalidRegexInChain", func(t *testing.T) { - configs := []config.FilterConfig{ - {Patterns: []string{"apple"}}, - {Patterns: []string{"["}}, - } - chain, err := NewChain(configs, logger) - assert.Error(t, err) - assert.Nil(t, chain) - assert.Contains(t, err.Error(), "filter[1]") - }) -} - -func TestChain_Apply(t *testing.T) { - logger := newTestLogger() - entry := core.LogEntry{Message: "an apple a day"} - - t.Run("EmptyChain", func(t *testing.T) { - chain, err := NewChain([]config.FilterConfig{}, logger) - assert.NoError(t, err) - assert.True(t, chain.Apply(entry)) - }) - - t.Run("AllFiltersPass", func(t *testing.T) { - configs := []config.FilterConfig{ - {Type: config.FilterTypeInclude, Patterns: []string{"apple"}}, - {Type: config.FilterTypeInclude, Patterns: []string{"day"}}, - {Type: config.FilterTypeExclude, Patterns: []string{"banana"}}, - } - chain, err := NewChain(configs, logger) - assert.NoError(t, err) - assert.True(t, chain.Apply(entry)) - }) - - t.Run("OneFilterFails", func(t *testing.T) { - configs := []config.FilterConfig{ - {Type: config.FilterTypeInclude, Patterns: []string{"apple"}}, - {Type: config.FilterTypeExclude, Patterns: []string{"day"}}, // This one will fail - {Type: config.FilterTypeInclude, Patterns: []string{"a"}}, - } - chain, err := NewChain(configs, logger) - assert.NoError(t, err) - assert.False(t, chain.Apply(entry)) - }) - - t.Run("FirstFilterFails", func(t *testing.T) { - configs := []config.FilterConfig{ - {Type: config.FilterTypeInclude, Patterns: []string{"banana"}}, // This one will fail - {Type: config.FilterTypeInclude, Patterns: []string{"apple"}}, - } - chain, err := NewChain(configs, logger) - assert.NoError(t, err) - assert.False(t, chain.Apply(entry)) - }) -} \ No newline at end of file diff --git a/src/internal/filter/filter_test.go b/src/internal/filter/filter_test.go deleted file mode 100644 index b561b15..0000000 --- a/src/internal/filter/filter_test.go +++ /dev/null @@ -1,159 +0,0 @@ -// FILE: logwisp/src/internal/filter/filter_test.go -package filter - -import ( - "logwisp/src/internal/config" - "logwisp/src/internal/core" - "testing" - - "github.com/lixenwraith/log" - "github.com/stretchr/testify/assert" -) - -func newTestLogger() *log.Logger { - return log.NewLogger() -} - -func TestNewFilter(t *testing.T) { - logger := newTestLogger() - - t.Run("SuccessWithDefaults", func(t *testing.T) { - cfg := config.FilterConfig{Patterns: []string{"test"}} - f, err := NewFilter(cfg, logger) - assert.NoError(t, err) - assert.NotNil(t, f) - assert.Equal(t, config.FilterTypeInclude, f.config.Type) - assert.Equal(t, config.FilterLogicOr, f.config.Logic) - }) - - t.Run("SuccessWithCustomConfig", func(t *testing.T) { - cfg := config.FilterConfig{ - Type: config.FilterTypeExclude, - Logic: config.FilterLogicAnd, - Patterns: []string{"test", "pattern"}, - } - f, err := NewFilter(cfg, logger) - assert.NoError(t, err) - assert.NotNil(t, f) - assert.Equal(t, config.FilterTypeExclude, f.config.Type) - assert.Equal(t, config.FilterLogicAnd, f.config.Logic) - assert.Len(t, f.patterns, 2) - }) - - t.Run("ErrorInvalidRegex", func(t *testing.T) { - cfg := config.FilterConfig{Patterns: []string{"["}} - f, err := NewFilter(cfg, logger) - assert.Error(t, err) - assert.Nil(t, f) - assert.Contains(t, err.Error(), "invalid regex pattern") - }) -} - -func TestFilter_Apply(t *testing.T) { - logger := newTestLogger() - - testCases := []struct { - name string - cfg config.FilterConfig - entry core.LogEntry - expected bool - }{ - // Include OR logic - { - name: "IncludeOR_MatchOne", - cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicOr, Patterns: []string{"apple", "banana"}}, - entry: core.LogEntry{Message: "this is an apple"}, - expected: true, - }, - { - name: "IncludeOR_NoMatch", - cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicOr, Patterns: []string{"apple", "banana"}}, - entry: core.LogEntry{Message: "this is a pear"}, - expected: false, - }, - // Include AND logic - { - name: "IncludeAND_MatchAll", - cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicAnd, Patterns: []string{"apple", "doctor"}}, - entry: core.LogEntry{Message: "an apple keeps the doctor away"}, - expected: true, - }, - { - name: "IncludeAND_MatchOne", - cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicAnd, Patterns: []string{"apple", "doctor"}}, - entry: core.LogEntry{Message: "this is an apple"}, - expected: false, - }, - // Exclude OR logic - { - name: "ExcludeOR_MatchOne", - cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicOr, Patterns: []string{"error", "fatal"}}, - entry: core.LogEntry{Message: "this is an error"}, - expected: false, - }, - { - name: "ExcludeOR_NoMatch", - cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicOr, Patterns: []string{"error", "fatal"}}, - entry: core.LogEntry{Message: "this is a warning"}, - expected: true, - }, - // Exclude AND logic - { - name: "ExcludeAND_MatchAll", - cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicAnd, Patterns: []string{"critical", "database"}}, - entry: core.LogEntry{Message: "critical error in database"}, - expected: false, - }, - { - name: "ExcludeAND_MatchOne", - cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicAnd, Patterns: []string{"critical", "database"}}, - entry: core.LogEntry{Message: "critical error in app"}, - expected: true, - }, - // Edge Cases - { - name: "NoPatterns", - cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{}}, - entry: core.LogEntry{Message: "any message"}, - expected: true, - }, - { - name: "EmptyEntry_NoPatterns", - cfg: config.FilterConfig{Patterns: []string{}}, - entry: core.LogEntry{}, - expected: true, - }, - { - name: "EmptyEntry_DoesNotMatchSpace", - cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{" "}}, - entry: core.LogEntry{Level: "", Source: "", Message: ""}, - expected: false, // CORRECTED: An empty entry results in an empty string, which doesn't match a space. - }, - { - name: "MatchOnLevel", - cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"ERROR"}}, - entry: core.LogEntry{Level: "ERROR", Message: "A message"}, - expected: true, - }, - { - name: "MatchOnSource", - cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"database"}}, - entry: core.LogEntry{Source: "database", Message: "A message"}, - expected: true, - }, - { - name: "MatchOnCombinedFields", - cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"^app ERROR"}}, - entry: core.LogEntry{Source: "app", Level: "ERROR", Message: "A message"}, - expected: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - f, err := NewFilter(tc.cfg, logger) - assert.NoError(t, err) - assert.Equal(t, tc.expected, f.Apply(tc.entry)) - }) - } -} \ No newline at end of file diff --git a/src/internal/format/format.go b/src/internal/format/format.go index d62967f..c39a5ad 100644 --- a/src/internal/format/format.go +++ b/src/internal/format/format.go @@ -3,6 +3,7 @@ package format import ( "fmt" + "logwisp/src/internal/config" "logwisp/src/internal/core" @@ -19,20 +20,15 @@ type Formatter interface { } // Creates a new Formatter based on the provided configuration. -func NewFormatter(name string, options map[string]any, logger *log.Logger) (Formatter, error) { - // Default to raw if no format specified - if name == "" { - name = "raw" - } - - switch name { +func NewFormatter(cfg *config.FormatConfig, logger *log.Logger) (Formatter, error) { + switch cfg.Type { case "json": - return NewJSONFormatter(options, logger) + return NewJSONFormatter(cfg.JSONFormatOptions, logger) case "txt": - return NewTextFormatter(options, logger) - case "raw": - return NewRawFormatter(options, logger) + return NewTextFormatter(cfg.TextFormatOptions, logger) + case "raw", "": + return NewRawFormatter(cfg.RawFormatOptions, logger) default: - return nil, fmt.Errorf("unknown formatter type: %s", name) + return nil, fmt.Errorf("unknown formatter type: %s", cfg.Type) } } \ No newline at end of file diff --git a/src/internal/format/format_test.go b/src/internal/format/format_test.go deleted file mode 100644 index 2d803f0..0000000 --- a/src/internal/format/format_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// FILE: logwisp/src/internal/format/format_test.go -package format - -import ( - "testing" - - "github.com/lixenwraith/log" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func newTestLogger() *log.Logger { - return log.NewLogger() -} - -func TestNewFormatter(t *testing.T) { - logger := newTestLogger() - - testCases := []struct { - name string - formatName string - expected string - expectError bool - }{ - { - name: "JSONFormatter", - formatName: "json", - expected: "json", - }, - { - name: "TextFormatter", - formatName: "txt", - expected: "txt", - }, - { - name: "RawFormatter", - formatName: "raw", - expected: "raw", - }, - { - name: "DefaultToRaw", - formatName: "", - expected: "raw", - }, - { - name: "UnknownFormatter", - formatName: "xml", - expectError: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - formatter, err := NewFormatter(tc.formatName, nil, logger) - if tc.expectError { - assert.Error(t, err) - assert.Nil(t, formatter) - } else { - require.NoError(t, err) - require.NotNil(t, formatter) - assert.Equal(t, tc.expected, formatter.Name()) - } - }) - } -} \ No newline at end of file diff --git a/src/internal/format/json.go b/src/internal/format/json.go index a670472..282310c 100644 --- a/src/internal/format/json.go +++ b/src/internal/format/json.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + "logwisp/src/internal/config" "logwisp/src/internal/core" "github.com/lixenwraith/log" @@ -13,39 +14,15 @@ import ( // Produces structured JSON logs type JSONFormatter struct { - pretty bool - timestampField string - levelField string - messageField string - sourceField string - logger *log.Logger + config *config.JSONFormatterOptions + logger *log.Logger } // Creates a new JSON formatter -func NewJSONFormatter(options map[string]any, logger *log.Logger) (*JSONFormatter, error) { +func NewJSONFormatter(opts *config.JSONFormatterOptions, logger *log.Logger) (*JSONFormatter, error) { f := &JSONFormatter{ - timestampField: "timestamp", - levelField: "level", - messageField: "message", - sourceField: "source", - logger: logger, - } - - // Extract options - if pretty, ok := options["pretty"].(bool); ok { - f.pretty = pretty - } - if field, ok := options["timestamp_field"].(string); ok && field != "" { - f.timestampField = field - } - if field, ok := options["level_field"].(string); ok && field != "" { - f.levelField = field - } - if field, ok := options["message_field"].(string); ok && field != "" { - f.messageField = field - } - if field, ok := options["source_field"].(string); ok && field != "" { - f.sourceField = field + config: opts, + logger: logger, } return f, nil @@ -57,9 +34,9 @@ func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) { output := make(map[string]any) // First, populate with LogWisp metadata - output[f.timestampField] = entry.Time.Format(time.RFC3339Nano) - output[f.levelField] = entry.Level - output[f.sourceField] = entry.Source + output[f.config.TimestampField] = entry.Time.Format(time.RFC3339Nano) + output[f.config.LevelField] = entry.Level + output[f.config.SourceField] = entry.Source // Try to parse the message as JSON var msgData map[string]any @@ -68,21 +45,21 @@ func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) { // LogWisp metadata takes precedence for k, v := range msgData { // Don't overwrite our standard fields - if k != f.timestampField && k != f.levelField && k != f.sourceField { + if k != f.config.TimestampField && k != f.config.LevelField && k != f.config.SourceField { output[k] = v } } // If the original JSON had these fields, log that we're overriding - if _, hasTime := msgData[f.timestampField]; hasTime { + if _, hasTime := msgData[f.config.TimestampField]; hasTime { f.logger.Debug("msg", "Overriding timestamp from JSON message", "component", "json_formatter", - "original", msgData[f.timestampField], - "logwisp", output[f.timestampField]) + "original", msgData[f.config.TimestampField], + "logwisp", output[f.config.TimestampField]) } } else { // Message is not valid JSON - add as message field - output[f.messageField] = entry.Message + output[f.config.MessageField] = entry.Message } // Add any additional fields from LogEntry.Fields @@ -101,7 +78,7 @@ func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) { // Marshal to JSON var result []byte var err error - if f.pretty { + if f.config.Pretty { result, err = json.MarshalIndent(output, "", " ") } else { result, err = json.Marshal(output) @@ -147,7 +124,7 @@ func (f *JSONFormatter) FormatBatch(entries []core.LogEntry) ([]byte, error) { // Marshal the entire batch as an array var result []byte var err error - if f.pretty { + if f.config.Pretty { result, err = json.MarshalIndent(batch, "", " ") } else { result, err = json.Marshal(batch) diff --git a/src/internal/format/json_test.go b/src/internal/format/json_test.go deleted file mode 100644 index 0e448b2..0000000 --- a/src/internal/format/json_test.go +++ /dev/null @@ -1,129 +0,0 @@ -// FILE: logwisp/src/internal/format/json_test.go -package format - -import ( - "encoding/json" - "strings" - "testing" - "time" - - "logwisp/src/internal/core" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestJSONFormatter_Format(t *testing.T) { - logger := newTestLogger() - testTime := time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC) - entry := core.LogEntry{ - Time: testTime, - Source: "test-app", - Level: "INFO", - Message: "this is a test", - } - - t.Run("BasicFormatting", func(t *testing.T) { - formatter, err := NewJSONFormatter(nil, logger) - require.NoError(t, err) - - output, err := formatter.Format(entry) - require.NoError(t, err) - - var result map[string]interface{} - err = json.Unmarshal(output, &result) - require.NoError(t, err, "Output should be valid JSON") - - assert.Equal(t, testTime.Format(time.RFC3339Nano), result["timestamp"]) - assert.Equal(t, "INFO", result["level"]) - assert.Equal(t, "test-app", result["source"]) - assert.Equal(t, "this is a test", result["message"]) - assert.True(t, strings.HasSuffix(string(output), "\n"), "Output should end with a newline") - }) - - t.Run("PrettyFormatting", func(t *testing.T) { - formatter, err := NewJSONFormatter(map[string]any{"pretty": true}, logger) - require.NoError(t, err) - - output, err := formatter.Format(entry) - require.NoError(t, err) - - assert.Contains(t, string(output), ` "level": "INFO"`) - assert.True(t, strings.HasSuffix(string(output), "\n")) - }) - - t.Run("MessageIsJSON", func(t *testing.T) { - jsonMessageEntry := entry - jsonMessageEntry.Message = `{"user":"test","request_id":"abc-123"}` - formatter, err := NewJSONFormatter(nil, logger) - require.NoError(t, err) - - output, err := formatter.Format(jsonMessageEntry) - require.NoError(t, err) - - var result map[string]interface{} - err = json.Unmarshal(output, &result) - require.NoError(t, err) - - assert.Equal(t, "test", result["user"]) - assert.Equal(t, "abc-123", result["request_id"]) - _, messageExists := result["message"] - assert.False(t, messageExists, "message field should not exist when message is merged JSON") - }) - - t.Run("MessageIsJSONWithConflicts", func(t *testing.T) { - jsonMessageEntry := entry - jsonMessageEntry.Level = "INFO" // top-level - jsonMessageEntry.Message = `{"level":"DEBUG","msg":"hello"}` - formatter, err := NewJSONFormatter(nil, logger) - require.NoError(t, err) - - output, err := formatter.Format(jsonMessageEntry) - require.NoError(t, err) - - var result map[string]interface{} - err = json.Unmarshal(output, &result) - require.NoError(t, err) - - assert.Equal(t, "INFO", result["level"], "Top-level LogEntry field should take precedence") - }) - - t.Run("CustomFieldNames", func(t *testing.T) { - options := map[string]any{"timestamp_field": "@timestamp"} - formatter, err := NewJSONFormatter(options, logger) - require.NoError(t, err) - - output, err := formatter.Format(entry) - require.NoError(t, err) - - var result map[string]interface{} - err = json.Unmarshal(output, &result) - require.NoError(t, err) - - _, defaultExists := result["timestamp"] - assert.False(t, defaultExists) - assert.Equal(t, testTime.Format(time.RFC3339Nano), result["@timestamp"]) - }) -} - -func TestJSONFormatter_FormatBatch(t *testing.T) { - logger := newTestLogger() - formatter, err := NewJSONFormatter(nil, logger) - require.NoError(t, err) - - entries := []core.LogEntry{ - {Time: time.Now(), Level: "INFO", Message: "First message"}, - {Time: time.Now(), Level: "WARN", Message: "Second message"}, - } - - output, err := formatter.FormatBatch(entries) - require.NoError(t, err) - - var result []map[string]interface{} - err = json.Unmarshal(output, &result) - require.NoError(t, err, "Batch output should be a valid JSON array") - require.Len(t, result, 2) - - assert.Equal(t, "First message", result[0]["message"]) - assert.Equal(t, "WARN", result[1]["level"]) -} \ No newline at end of file diff --git a/src/internal/format/raw.go b/src/internal/format/raw.go index 923acda..6c3f745 100644 --- a/src/internal/format/raw.go +++ b/src/internal/format/raw.go @@ -2,6 +2,7 @@ package format import ( + "logwisp/src/internal/config" "logwisp/src/internal/core" "github.com/lixenwraith/log" @@ -9,20 +10,26 @@ import ( // Outputs the log message as-is with a newline type RawFormatter struct { + config *config.RawFormatterOptions logger *log.Logger } // Creates a new raw formatter -func NewRawFormatter(options map[string]any, logger *log.Logger) (*RawFormatter, error) { +func NewRawFormatter(cfg *config.RawFormatterOptions, logger *log.Logger) (*RawFormatter, error) { return &RawFormatter{ + config: cfg, logger: logger, }, nil } // Returns the message with a newline appended func (f *RawFormatter) Format(entry core.LogEntry) ([]byte, error) { - // Simply return the message with newline - return append([]byte(entry.Message), '\n'), nil + // TODO: Standardize not to add "\n" when processing raw, check lixenwraith/log for consistency + if f.config.AddNewLine { + return append([]byte(entry.Message), '\n'), nil + } else { + return []byte(entry.Message), nil + } } // Returns the formatter name diff --git a/src/internal/format/raw_test.go b/src/internal/format/raw_test.go deleted file mode 100644 index 84c8b98..0000000 --- a/src/internal/format/raw_test.go +++ /dev/null @@ -1,29 +0,0 @@ -// FILE: logwisp/src/internal/format/raw_test.go -package format - -import ( - "testing" - "time" - - "logwisp/src/internal/core" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestRawFormatter_Format(t *testing.T) { - logger := newTestLogger() - formatter, err := NewRawFormatter(nil, logger) - require.NoError(t, err) - - entry := core.LogEntry{ - Time: time.Now(), - Message: "This is a raw log line.", - } - - output, err := formatter.Format(entry) - require.NoError(t, err) - - expected := "This is a raw log line.\n" - assert.Equal(t, expected, string(output)) -} \ No newline at end of file diff --git a/src/internal/format/text.go b/src/internal/format/text.go index a7c6d4e..67a8129 100644 --- a/src/internal/format/text.go +++ b/src/internal/format/text.go @@ -4,6 +4,7 @@ package format import ( "bytes" "fmt" + "logwisp/src/internal/config" "strings" "text/template" "time" @@ -15,41 +16,29 @@ import ( // Produces human-readable text logs using templates type TextFormatter struct { - template *template.Template - timestampFormat string - logger *log.Logger + config *config.TextFormatterOptions + template *template.Template + logger *log.Logger } // Creates a new text formatter -func NewTextFormatter(options map[string]any, logger *log.Logger) (*TextFormatter, error) { - // Default template - templateStr := "[{{.Timestamp | FmtTime}}] [{{.Level | ToUpper}}] {{.Source}} - {{.Message}}{{ if .Fields }} {{.Fields}}{{ end }}" - if tmpl, ok := options["template"].(string); ok && tmpl != "" { - templateStr = tmpl - } - - // Default timestamp format - timestampFormat := time.RFC3339 - if tsFormat, ok := options["timestamp_format"].(string); ok && tsFormat != "" { - timestampFormat = tsFormat - } - +func NewTextFormatter(opts *config.TextFormatterOptions, logger *log.Logger) (*TextFormatter, error) { f := &TextFormatter{ - timestampFormat: timestampFormat, - logger: logger, + config: opts, + logger: logger, } // Create template with helper functions funcMap := template.FuncMap{ "FmtTime": func(t time.Time) string { - return t.Format(f.timestampFormat) + return t.Format(f.config.TimestampFormat) }, "ToUpper": strings.ToUpper, "ToLower": strings.ToLower, "TrimSpace": strings.TrimSpace, } - tmpl, err := template.New("log").Funcs(funcMap).Parse(templateStr) + tmpl, err := template.New("log").Funcs(funcMap).Parse(f.config.Template) if err != nil { return nil, fmt.Errorf("invalid template: %w", err) } @@ -86,7 +75,7 @@ func (f *TextFormatter) Format(entry core.LogEntry) ([]byte, error) { "error", err) fallback := fmt.Sprintf("[%s] [%s] %s - %s\n", - entry.Time.Format(f.timestampFormat), + entry.Time.Format(f.config.TimestampFormat), strings.ToUpper(entry.Level), entry.Source, entry.Message) diff --git a/src/internal/format/text_test.go b/src/internal/format/text_test.go deleted file mode 100644 index 464441f..0000000 --- a/src/internal/format/text_test.go +++ /dev/null @@ -1,81 +0,0 @@ -// FILE: logwisp/src/internal/format/text_test.go -package format - -import ( - "fmt" - "strings" - "testing" - "time" - - "logwisp/src/internal/core" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewTextFormatter(t *testing.T) { - logger := newTestLogger() - t.Run("InvalidTemplate", func(t *testing.T) { - options := map[string]any{"template": "{{ .Timestamp | InvalidFunc }}"} - _, err := NewTextFormatter(options, logger) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid template") - }) -} - -func TestTextFormatter_Format(t *testing.T) { - logger := newTestLogger() - testTime := time.Date(2023, 10, 27, 10, 30, 0, 0, time.UTC) - entry := core.LogEntry{ - Time: testTime, - Source: "api", - Level: "WARN", - Message: "rate limit exceeded", - } - - t.Run("DefaultTemplate", func(t *testing.T) { - formatter, err := NewTextFormatter(nil, logger) - require.NoError(t, err) - - output, err := formatter.Format(entry) - require.NoError(t, err) - - expected := fmt.Sprintf("[%s] [WARN] api - rate limit exceeded\n", testTime.Format(time.RFC3339)) - assert.Equal(t, expected, string(output)) - }) - - t.Run("CustomTemplate", func(t *testing.T) { - options := map[string]any{"template": "{{.Level}}:{{.Source}}:{{.Message}}"} - formatter, err := NewTextFormatter(options, logger) - require.NoError(t, err) - - output, err := formatter.Format(entry) - require.NoError(t, err) - - expected := "WARN:api:rate limit exceeded\n" - assert.Equal(t, expected, string(output)) - }) - - t.Run("CustomTimestampFormat", func(t *testing.T) { - options := map[string]any{"timestamp_format": "2006-01-02"} - formatter, err := NewTextFormatter(options, logger) - require.NoError(t, err) - - output, err := formatter.Format(entry) - require.NoError(t, err) - - assert.True(t, strings.HasPrefix(string(output), "[2023-10-27]")) - }) - - t.Run("EmptyLevelDefaultsToInfo", func(t *testing.T) { - emptyLevelEntry := entry - emptyLevelEntry.Level = "" - formatter, err := NewTextFormatter(nil, logger) - require.NoError(t, err) - - output, err := formatter.Format(emptyLevelEntry) - require.NoError(t, err) - - assert.Contains(t, string(output), "[INFO]") - }) -} \ No newline at end of file diff --git a/src/internal/limit/net.go b/src/internal/limit/net.go index f24df56..3bcc2ba 100644 --- a/src/internal/limit/net.go +++ b/src/internal/limit/net.go @@ -34,7 +34,7 @@ const ( // NetLimiter manages net limiting for a transport type NetLimiter struct { - config config.NetLimitConfig + config *config.NetLimitConfig logger *log.Logger // IP Access Control Lists @@ -89,7 +89,11 @@ type connTracker struct { } // Creates a new net limiter -func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter { +func NewNetLimiter(cfg *config.NetLimitConfig, logger *log.Logger) *NetLimiter { + if cfg == nil { + return nil + } + // Return nil only if nothing is configured hasACL := len(cfg.IPWhitelist) > 0 || len(cfg.IPBlacklist) > 0 hasRateLimit := cfg.Enabled @@ -120,7 +124,7 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter { } // Parse IP lists - l.parseIPLists(cfg) + l.parseIPLists() // Start cleanup goroutine only if rate limiting is enabled if cfg.Enabled { @@ -144,16 +148,16 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter { } // parseIPLists parses and validates IP whitelist/blacklist -func (l *NetLimiter) parseIPLists(cfg config.NetLimitConfig) { +func (l *NetLimiter) parseIPLists() { // Parse whitelist - for _, entry := range cfg.IPWhitelist { + for _, entry := range l.config.IPWhitelist { if ipNet := l.parseIPEntry(entry, "whitelist"); ipNet != nil { l.ipWhitelist = append(l.ipWhitelist, ipNet) } } // Parse blacklist - for _, entry := range cfg.IPBlacklist { + for _, entry := range l.config.IPBlacklist { if ipNet := l.parseIPEntry(entry, "blacklist"); ipNet != nil { l.ipBlacklist = append(l.ipBlacklist, ipNet) } diff --git a/src/internal/scram/integration.go b/src/internal/scram/integration.go deleted file mode 100644 index be190d9..0000000 --- a/src/internal/scram/integration.go +++ /dev/null @@ -1,117 +0,0 @@ -// FILE: src/internal/scram/integration.go -package scram - -import ( - "crypto/rand" - "fmt" - "sync" - "time" - - "golang.org/x/time/rate" -) - -// ScramManager provides high-level SCRAM operations with rate limiting -type ScramManager struct { - server *Server - sessions map[string]*SessionInfo - limiter map[string]*rate.Limiter - mu sync.RWMutex -} - -// SessionInfo tracks authenticated sessions -type SessionInfo struct { - Username string - RemoteAddr string - SessionID string - CreatedAt time.Time - LastActivity time.Time - Method string // "scram-sha-256" -} - -// NewScramManager creates SCRAM manager -func NewScramManager() *ScramManager { - m := &ScramManager{ - server: NewServer(), - sessions: make(map[string]*SessionInfo), - limiter: make(map[string]*rate.Limiter), - } - - // Start cleanup goroutine - go m.cleanupLoop() - return m -} - -// RegisterUser creates new user credential -func (sm *ScramManager) RegisterUser(username, password string) error { - salt := make([]byte, 16) - if _, err := rand.Read(salt); err != nil { - return fmt.Errorf("salt generation failed: %w", err) - } - - cred, err := DeriveCredential(username, password, salt, - sm.server.DefaultTime, sm.server.DefaultMemory, sm.server.DefaultThreads) - if err != nil { - return err - } - - sm.server.AddCredential(cred) - return nil -} - -// GetRateLimiter returns per-IP rate limiter -func (sm *ScramManager) GetRateLimiter(remoteAddr string) *rate.Limiter { - sm.mu.Lock() - defer sm.mu.Unlock() - - if limiter, exists := sm.limiter[remoteAddr]; exists { - return limiter - } - - // 10 attempts per minute, burst of 3 - limiter := rate.NewLimiter(rate.Every(6*time.Second), 3) - sm.limiter[remoteAddr] = limiter - - // Prevent unbounded growth - if len(sm.limiter) > 10000 { - // Remove oldest entries - for addr := range sm.limiter { - delete(sm.limiter, addr) - if len(sm.limiter) < 8000 { - break - } - } - } - - return limiter -} - -func (sm *ScramManager) cleanupLoop() { - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() - - for range ticker.C { - sm.mu.Lock() - cutoff := time.Now().Add(-30 * time.Minute) - for sid, session := range sm.sessions { - if session.LastActivity.Before(cutoff) { - delete(sm.sessions, sid) - } - } - sm.mu.Unlock() - } -} - -// HandleClientFirst wraps server's HandleClientFirst -func (sm *ScramManager) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) { - return sm.server.HandleClientFirst(msg) -} - -// HandleClientFinal wraps server's HandleClientFinal -func (sm *ScramManager) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) { - return sm.server.HandleClientFinal(msg) -} - -// AddCredential wraps server's AddCredential -func (sm *ScramManager) AddCredential(cred *Credential) { - sm.server.AddCredential(cred) -} \ No newline at end of file diff --git a/src/internal/scram/message.go b/src/internal/scram/message.go deleted file mode 100644 index 6b97a5b..0000000 --- a/src/internal/scram/message.go +++ /dev/null @@ -1,101 +0,0 @@ -// FILE: src/internal/scram/message.go -package scram - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "strings" -) - -// ClientFirst initiates authentication -type ClientFirst struct { - Username string `json:"u"` - ClientNonce string `json:"n"` -} - -// ServerFirst contains server challenge -type ServerFirst struct { - FullNonce string `json:"r"` // client_nonce + server_nonce - Salt string `json:"s"` // base64 - ArgonTime uint32 `json:"t"` - ArgonMemory uint32 `json:"m"` - ArgonThreads uint8 `json:"p"` -} - -// ClientFinal contains client proof -type ClientFinal struct { - FullNonce string `json:"r"` - ClientProof string `json:"p"` // base64 -} - -// ServerFinal contains server signature for mutual auth -type ServerFinal struct { - ServerSignature string `json:"v"` // base64 - SessionID string `json:"sid,omitempty"` -} - -// Marshal/Unmarshal helpers for TCP protocol (line-based) -func (cf *ClientFirst) Marshal() string { - return fmt.Sprintf("u=%s,n=%s", cf.Username, cf.ClientNonce) -} - -func ParseClientFirst(data string) (*ClientFirst, error) { - parts := strings.Split(data, ",") - msg := &ClientFirst{} - for _, part := range parts { - kv := strings.SplitN(part, "=", 2) - if len(kv) != 2 { - continue - } - switch kv[0] { - case "u": - msg.Username = kv[1] - case "n": - msg.ClientNonce = kv[1] - } - } - if msg.Username == "" || msg.ClientNonce == "" { - return nil, fmt.Errorf("missing required fields") - } - return msg, nil -} - -func (sf *ServerFirst) Marshal() string { - return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d", - sf.FullNonce, sf.Salt, sf.ArgonTime, sf.ArgonMemory, sf.ArgonThreads) -} - -func ParseServerFirst(data string) (*ServerFirst, error) { - parts := strings.Split(data, ",") - msg := &ServerFirst{} - for _, part := range parts { - kv := strings.SplitN(part, "=", 2) - if len(kv) != 2 { - continue - } - switch kv[0] { - case "r": - msg.FullNonce = kv[1] - case "s": - msg.Salt = kv[1] - case "t": - fmt.Sscanf(kv[1], "%d", &msg.ArgonTime) - case "m": - fmt.Sscanf(kv[1], "%d", &msg.ArgonMemory) - case "p": - fmt.Sscanf(kv[1], "%d", &msg.ArgonThreads) - } - } - return msg, nil -} - -// JSON variants for HTTP -func (cf *ClientFirst) MarshalJSON() ([]byte, error) { - return json.Marshal(*cf) -} - -func (sf *ServerFirst) MarshalJSON() ([]byte, error) { - sf.Salt = base64.StdEncoding.EncodeToString([]byte(sf.Salt)) - return json.Marshal(*sf) -} \ No newline at end of file diff --git a/src/internal/scram/scram_test.go b/src/internal/scram/scram_test.go deleted file mode 100644 index 0b05fda..0000000 --- a/src/internal/scram/scram_test.go +++ /dev/null @@ -1,228 +0,0 @@ -// FILE: src/internal/scram/scram_test.go -package scram - -import ( - "crypto/rand" - "encoding/base64" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/crypto/argon2" -) - -func TestCredentialDerivation(t *testing.T) { - salt := make([]byte, 16) - _, err := rand.Read(salt) - require.NoError(t, err) - - cred, err := DeriveCredential("testuser", "testpass123", salt, 3, 64*1024, 4) - require.NoError(t, err) - - assert.Equal(t, "testuser", cred.Username) - assert.Equal(t, uint32(3), cred.ArgonTime) - assert.Equal(t, uint32(64*1024), cred.ArgonMemory) - assert.Equal(t, uint8(4), cred.ArgonThreads) - assert.Len(t, cred.StoredKey, 32) - assert.Len(t, cred.ServerKey, 32) -} - -func TestFullHandshake(t *testing.T) { - // Setup server with credential - server := NewServer() - salt := make([]byte, 16) - _, err := rand.Read(salt) - require.NoError(t, err) - - cred, err := DeriveCredential("alice", "secret123", salt, 3, 64*1024, 4) - require.NoError(t, err) - server.AddCredential(cred) - - // Setup client - client := NewClient("alice", "secret123") - - // Step 1: Client starts - clientFirst, err := client.StartAuthentication() - require.NoError(t, err) - assert.Equal(t, "alice", clientFirst.Username) - assert.NotEmpty(t, clientFirst.ClientNonce) - - // Step 2: Server responds - serverFirst, err := server.HandleClientFirst(clientFirst) - require.NoError(t, err) - assert.Contains(t, serverFirst.FullNonce, clientFirst.ClientNonce) - assert.Equal(t, base64.StdEncoding.EncodeToString(salt), serverFirst.Salt) - - // Step 3: Client proves - clientFinal, err := client.ProcessServerFirst(serverFirst) - require.NoError(t, err) - assert.Equal(t, serverFirst.FullNonce, clientFinal.FullNonce) - assert.NotEmpty(t, clientFinal.ClientProof) - - // Step 4: Server verifies and signs - serverFinal, err := server.HandleClientFinal(clientFinal) - require.NoError(t, err) - assert.NotEmpty(t, serverFinal.ServerSignature) - assert.NotEmpty(t, serverFinal.SessionID) - - // Step 5: Client verifies server - err = client.VerifyServerFinal(serverFinal) - assert.NoError(t, err) -} - -func TestInvalidPassword(t *testing.T) { - server := NewServer() - salt := make([]byte, 16) - rand.Read(salt) - - cred, _ := DeriveCredential("bob", "correct", salt, 3, 64*1024, 4) - server.AddCredential(cred) - - // Client with wrong password - client := NewClient("bob", "wrong") - - clientFirst, _ := client.StartAuthentication() - serverFirst, _ := server.HandleClientFirst(clientFirst) - clientFinal, _ := client.ProcessServerFirst(serverFirst) - - clientFinal, err := client.ProcessServerFirst(serverFirst) - require.NoError(t, err) // Check error to prevent panic - - // Server should reject - _, err = server.HandleClientFinal(clientFinal) - assert.Error(t, err) - assert.Contains(t, err.Error(), "authentication failed") -} - -func TestHandshakeTimeout(t *testing.T) { - server := NewServer() - salt := make([]byte, 16) - rand.Read(salt) - - cred, _ := DeriveCredential("charlie", "pass", salt, 3, 64*1024, 4) - server.AddCredential(cred) - - client := NewClient("charlie", "pass") - clientFirst, _ := client.StartAuthentication() - serverFirst, _ := server.HandleClientFirst(clientFirst) - - // Manipulate handshake timestamp - server.mu.Lock() - if state, exists := server.handshakes[serverFirst.FullNonce]; exists { - state.CreatedAt = time.Now().Add(-61 * time.Second) - } - server.mu.Unlock() - - clientFinal, _ := client.ProcessServerFirst(serverFirst) - _, err := server.HandleClientFinal(clientFinal) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "timeout") -} - -func TestMessageParsing(t *testing.T) { - // Test ClientFirst parsing - cf, err := ParseClientFirst("u=alice,n=abc123") - require.NoError(t, err) - assert.Equal(t, "alice", cf.Username) - assert.Equal(t, "abc123", cf.ClientNonce) - - // Test marshal round-trip - marshaled := cf.Marshal() - parsed, err := ParseClientFirst(marshaled) - require.NoError(t, err) - assert.Equal(t, cf.Username, parsed.Username) - assert.Equal(t, cf.ClientNonce, parsed.ClientNonce) -} - -func TestPHCMigration(t *testing.T) { - // Create PHC hash using known values - password := "testpass" - salt := []byte("saltsaltsaltsalt") - saltB64 := base64.RawStdEncoding.EncodeToString(salt) - - // This would be generated by the existing auth system - phcHash := "$argon2id$v=19$m=65536,t=3,p=4$" + saltB64 + "$" + "dGVzdGhhc2g" // dummy hash - - // For testing, we need a valid hash - generate it - hash := argon2.IDKey([]byte(password), salt, 3, 65536, 4, 32) - hashB64 := base64.RawStdEncoding.EncodeToString(hash) - phcHash = "$argon2id$v=19$m=65536,t=3,p=4$" + saltB64 + "$" + hashB64 - - cred, err := MigrateFromPHC("alice", password, phcHash) - require.NoError(t, err) - assert.Equal(t, "alice", cred.Username) - assert.Equal(t, salt, cred.Salt) -} - -func TestRateLimiting(t *testing.T) { - manager := NewScramManager() - - limiter := manager.GetRateLimiter("192.168.1.1") - - // Should allow burst - for i := 0; i < 3; i++ { - assert.True(t, limiter.Allow()) - } - - // 4th should be rate limited - assert.False(t, limiter.Allow()) -} - -func TestNonceUniqueness(t *testing.T) { - nonces := make(map[string]bool) - - for i := 0; i < 1000; i++ { - nonce := generateNonce() - assert.False(t, nonces[nonce], "Duplicate nonce generated") - nonces[nonce] = true - } -} - -func TestConstantTimeComparison(t *testing.T) { - server := NewServer() - salt := make([]byte, 16) - rand.Read(salt) - - // Add real user - cred, _ := DeriveCredential("realuser", "realpass", salt, 3, 64*1024, 4) - server.AddCredential(cred) - - // Test timing independence for invalid user - fakeClient := NewClient("fakeuser", "fakepass") - clientFirst, _ := fakeClient.StartAuthentication() - - // Should still return ServerFirst (no user enumeration) - serverFirst, err := server.HandleClientFirst(clientFirst) - assert.NotNil(t, serverFirst) - assert.Error(t, err) // Internal error, not exposed to client -} - -func BenchmarkArgon2Derivation(b *testing.B) { - salt := make([]byte, 16) - rand.Read(salt) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - DeriveCredential("user", "password", salt, 3, 64*1024, 4) - } -} - -func BenchmarkFullHandshake(b *testing.B) { - server := NewServer() - salt := make([]byte, 16) - rand.Read(salt) - cred, _ := DeriveCredential("bench", "pass", salt, 3, 64*1024, 4) - server.AddCredential(cred) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - client := NewClient("bench", "pass") - clientFirst, _ := client.StartAuthentication() - serverFirst, _ := server.HandleClientFirst(clientFirst) - clientFinal, _ := client.ProcessServerFirst(serverFirst) - serverFinal, _ := server.HandleClientFinal(clientFinal) - client.VerifyServerFinal(serverFinal) - } -} \ No newline at end of file diff --git a/src/internal/service/pipeline.go b/src/internal/service/pipeline.go index e60cddf..af6de71 100644 --- a/src/internal/service/pipeline.go +++ b/src/internal/service/pipeline.go @@ -3,12 +3,14 @@ package service import ( "context" + "fmt" "sync" "sync/atomic" "time" "logwisp/src/internal/config" "logwisp/src/internal/filter" + "logwisp/src/internal/format" "logwisp/src/internal/limit" "logwisp/src/internal/sink" "logwisp/src/internal/source" @@ -18,8 +20,7 @@ import ( // Manages the flow of data from sources through filters to sinks type Pipeline struct { - Name string - Config config.PipelineConfig + Config *config.PipelineConfig Sources []source.Source RateLimiter *limit.RateLimiter FilterChain *filter.Chain @@ -43,11 +44,116 @@ type PipelineStats struct { FilterStats map[string]any } +// Creates and starts a new pipeline +func (s *Service) NewPipeline(cfg *config.PipelineConfig) error { + s.mu.Lock() + defer s.mu.Unlock() + + if _, exists := s.pipelines[cfg.Name]; exists { + err := fmt.Errorf("pipeline '%s' already exists", cfg.Name) + s.logger.Error("msg", "Failed to create pipeline - duplicate name", + "component", "service", + "pipeline", cfg.Name, + "error", err) + return err + } + + s.logger.Debug("msg", "Creating pipeline", "pipeline", cfg.Name) + + // Create pipeline context + pipelineCtx, pipelineCancel := context.WithCancel(s.ctx) + + // Create pipeline instance + pipeline := &Pipeline{ + Config: cfg, + Stats: &PipelineStats{ + StartTime: time.Now(), + }, + ctx: pipelineCtx, + cancel: pipelineCancel, + logger: s.logger, + } + + // Create sources + for i, srcCfg := range cfg.Sources { + src, err := s.createSource(&srcCfg) + if err != nil { + pipelineCancel() + return fmt.Errorf("failed to create source[%d]: %w", i, err) + } + pipeline.Sources = append(pipeline.Sources, src) + } + + // Create pipeline rate limiter + if cfg.RateLimit != nil { + limiter, err := limit.NewRateLimiter(*cfg.RateLimit, s.logger) + if err != nil { + pipelineCancel() + return fmt.Errorf("failed to create pipeline rate limiter: %w", err) + } + pipeline.RateLimiter = limiter + } + + // Create filter chain + if len(cfg.Filters) > 0 { + chain, err := filter.NewChain(cfg.Filters, s.logger) + if err != nil { + pipelineCancel() + return fmt.Errorf("failed to create filter chain: %w", err) + } + pipeline.FilterChain = chain + } + + // Create formatter for the pipeline + formatter, err := format.NewFormatter(cfg.Format, s.logger) + if err != nil { + pipelineCancel() + return fmt.Errorf("failed to create formatter: %w", err) + } + + // Create sinks + for i, sinkCfg := range cfg.Sinks { + sinkInst, err := s.createSink(sinkCfg, formatter) + if err != nil { + pipelineCancel() + return fmt.Errorf("failed to create sink[%d]: %w", i, err) + } + pipeline.Sinks = append(pipeline.Sinks, sinkInst) + } + + // Start all sources + for i, src := range pipeline.Sources { + if err := src.Start(); err != nil { + pipeline.Shutdown() + return fmt.Errorf("failed to start source[%d]: %w", i, err) + } + } + + // Start all sinks + for i, sinkInst := range pipeline.Sinks { + if err := sinkInst.Start(pipelineCtx); err != nil { + pipeline.Shutdown() + return fmt.Errorf("failed to start sink[%d]: %w", i, err) + } + } + + // Wire sources to sinks through filters + s.wirePipeline(pipeline) + + // Start stats updater + pipeline.startStatsUpdater(pipelineCtx) + + s.pipelines[cfg.Name] = pipeline + s.logger.Info("msg", "Pipeline created successfully", + "pipeline", cfg.Name) + return nil +} + // Gracefully stops the pipeline func (p *Pipeline) Shutdown() { p.logger.Info("msg", "Shutting down pipeline", "component", "pipeline", - "pipeline", p.Name) + "pipeline", p.Config.Name) // Cancel context to stop processing p.cancel() @@ -78,7 +184,7 @@ func (p *Pipeline) Shutdown() { p.logger.Info("msg", "Pipeline shutdown complete", "component", "pipeline", - "pipeline", p.Name) + "pipeline", p.Config.Name) } // Returns pipeline statistics @@ -88,7 +194,7 @@ func (p *Pipeline) GetStats() map[string]any { defer func() { if r := recover(); r != nil { p.logger.Error("msg", "Panic getting pipeline stats", - "pipeline", p.Name, + "pipeline", p.Config.Name, "panic", r) } }() @@ -142,7 +248,7 @@ func (p *Pipeline) GetStats() map[string]any { } return map[string]any{ - "name": p.Name, + "name": p.Config.Name, "uptime_seconds": int(time.Since(p.Stats.StartTime).Seconds()), "total_processed": p.Stats.TotalEntriesProcessed.Load(), "total_dropped_rate_limit": p.Stats.TotalEntriesDroppedByRateLimit.Load(), diff --git a/src/internal/service/service.go b/src/internal/service/service.go index 8fc5e78..a2ba640 100644 --- a/src/internal/service/service.go +++ b/src/internal/service/service.go @@ -5,13 +5,10 @@ import ( "context" "fmt" "sync" - "time" "logwisp/src/internal/config" "logwisp/src/internal/core" - "logwisp/src/internal/filter" "logwisp/src/internal/format" - "logwisp/src/internal/limit" "logwisp/src/internal/sink" "logwisp/src/internal/source" @@ -39,127 +36,6 @@ func NewService(ctx context.Context, logger *log.Logger) *Service { } } -// Creates and starts a new pipeline -func (s *Service) NewPipeline(cfg config.PipelineConfig) error { - s.mu.Lock() - defer s.mu.Unlock() - - if _, exists := s.pipelines[cfg.Name]; exists { - err := fmt.Errorf("pipeline '%s' already exists", cfg.Name) - s.logger.Error("msg", "Failed to create pipeline - duplicate name", - "component", "service", - "pipeline", cfg.Name, - "error", err) - return err - } - - s.logger.Debug("msg", "Creating pipeline", "pipeline", cfg.Name) - - // Create pipeline context - pipelineCtx, pipelineCancel := context.WithCancel(s.ctx) - - // Create pipeline instance - pipeline := &Pipeline{ - Name: cfg.Name, - Config: cfg, - Stats: &PipelineStats{ - StartTime: time.Now(), - }, - ctx: pipelineCtx, - cancel: pipelineCancel, - logger: s.logger, - } - - // Create sources - for i, srcCfg := range cfg.Sources { - src, err := s.createSource(srcCfg) - if err != nil { - pipelineCancel() - return fmt.Errorf("failed to create source[%d]: %w", i, err) - } - pipeline.Sources = append(pipeline.Sources, src) - } - - // Create pipeline rate limiter - if cfg.RateLimit != nil { - limiter, err := limit.NewRateLimiter(*cfg.RateLimit, s.logger) - if err != nil { - pipelineCancel() - return fmt.Errorf("failed to create pipeline rate limiter: %w", err) - } - pipeline.RateLimiter = limiter - } - - // Create filter chain - if len(cfg.Filters) > 0 { - chain, err := filter.NewChain(cfg.Filters, s.logger) - if err != nil { - pipelineCancel() - return fmt.Errorf("failed to create filter chain: %w", err) - } - pipeline.FilterChain = chain - } - - // Create formatter for the pipeline - var formatter format.Formatter - var err error - if cfg.Format != "" || len(cfg.FormatOptions) > 0 { - formatter, err = format.NewFormatter(cfg.Format, cfg.FormatOptions, s.logger) - if err != nil { - pipelineCancel() - return fmt.Errorf("failed to create formatter: %w", err) - } - } - - // Create sinks - for i, sinkCfg := range cfg.Sinks { - sinkInst, err := s.createSink(sinkCfg, formatter) - if err != nil { - pipelineCancel() - return fmt.Errorf("failed to create sink[%d]: %w", i, err) - } - pipeline.Sinks = append(pipeline.Sinks, sinkInst) - } - - // Configure authentication for sources that support it before starting them - for _, sourceInst := range pipeline.Sources { - sourceInst.SetAuth(cfg.Auth) - } - - // Start all sources - for i, src := range pipeline.Sources { - if err := src.Start(); err != nil { - pipeline.Shutdown() - return fmt.Errorf("failed to start source[%d]: %w", i, err) - } - } - - // Configure authentication for sinks that support it before starting them - for _, sinkInst := range pipeline.Sinks { - sinkInst.SetAuth(cfg.Auth) - } - - // Start all sinks - for i, sinkInst := range pipeline.Sinks { - if err := sinkInst.Start(pipelineCtx); err != nil { - pipeline.Shutdown() - return fmt.Errorf("failed to start sink[%d]: %w", i, err) - } - } - - // Wire sources to sinks through filters - s.wirePipeline(pipeline) - - // Start stats updater - pipeline.startStatsUpdater(pipelineCtx) - - s.pipelines[cfg.Name] = pipeline - s.logger.Info("msg", "Pipeline created successfully", - "pipeline", cfg.Name, - "auth_enabled", cfg.Auth != nil && cfg.Auth.Type != "none") - return nil -} - // Connects sources to sinks through filters func (s *Service) wirePipeline(p *Pipeline) { // For each source, subscribe and process entries @@ -175,17 +51,17 @@ func (s *Service) wirePipeline(p *Pipeline) { defer func() { if r := recover(); r != nil { s.logger.Error("msg", "Panic in pipeline processing", - "pipeline", p.Name, + "pipeline", p.Config.Name, "source", source.GetStats().Type, "panic", r) // Ensure failed pipelines don't leave resources hanging go func() { s.logger.Warn("msg", "Shutting down pipeline due to panic", - "pipeline", p.Name) - if err := s.RemovePipeline(p.Name); err != nil { + "pipeline", p.Config.Name) + if err := s.RemovePipeline(p.Config.Name); err != nil { s.logger.Error("msg", "Failed to remove panicked pipeline", - "pipeline", p.Name, + "pipeline", p.Config.Name, "error", err) } }() @@ -228,7 +104,7 @@ func (s *Service) wirePipeline(p *Pipeline) { default: // Drop if sink buffer is full, may flood logging for slow client s.logger.Debug("msg", "Dropped log entry - sink buffer full", - "pipeline", p.Name) + "pipeline", p.Config.Name) } } } @@ -238,16 +114,16 @@ func (s *Service) wirePipeline(p *Pipeline) { } // Creates a source instance based on configuration -func (s *Service) createSource(cfg config.SourceConfig) (source.Source, error) { +func (s *Service) createSource(cfg *config.SourceConfig) (source.Source, error) { switch cfg.Type { case "directory": - return source.NewDirectorySource(cfg.Options, s.logger) + return source.NewDirectorySource(cfg.Directory, s.logger) case "stdin": - return source.NewStdinSource(cfg.Options, s.logger) + return source.NewStdinSource(cfg.Stdin, s.logger) case "http": - return source.NewHTTPSource(cfg.Options, s.logger) + return source.NewHTTPSource(cfg.HTTP, s.logger) case "tcp": - return source.NewTCPSource(cfg.Options, s.logger) + return source.NewTCPSource(cfg.TCP, s.logger) default: return nil, fmt.Errorf("unknown source type: %s", cfg.Type) } @@ -255,34 +131,28 @@ func (s *Service) createSource(cfg config.SourceConfig) (source.Source, error) { // Creates a sink instance based on configuration func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter) (sink.Sink, error) { - if formatter == nil { - // Default formatters for different sink types - defaultFormat := "raw" - switch cfg.Type { - case "http", "tcp", "http_client", "tcp_client": - defaultFormat = "json" - } - - var err error - formatter, err = format.NewFormatter(defaultFormat, nil, s.logger) - if err != nil { - return nil, fmt.Errorf("failed to create default formatter: %w", err) - } - } switch cfg.Type { case "http": - return sink.NewHTTPSink(cfg.Options, s.logger, formatter) + if cfg.HTTP == nil { + return nil, fmt.Errorf("HTTP sink configuration missing") + } + return sink.NewHTTPSink(cfg.HTTP, s.logger, formatter) + case "tcp": - return sink.NewTCPSink(cfg.Options, s.logger, formatter) + if cfg.TCP == nil { + return nil, fmt.Errorf("TCP sink configuration missing") + } + return sink.NewTCPSink(cfg.TCP, s.logger, formatter) + case "http_client": - return sink.NewHTTPClientSink(cfg.Options, s.logger, formatter) + return sink.NewHTTPClientSink(cfg.HTTPClient, s.logger, formatter) case "tcp_client": - return sink.NewTCPClientSink(cfg.Options, s.logger, formatter) + return sink.NewTCPClientSink(cfg.TCPClient, s.logger, formatter) case "file": - return sink.NewFileSink(cfg.Options, s.logger, formatter) + return sink.NewFileSink(cfg.File, s.logger, formatter) case "console": - return sink.NewConsoleSink(cfg.Options, s.logger, formatter) + return sink.NewConsoleSink(cfg.Console, s.logger, formatter) default: return nil, fmt.Errorf("unknown sink type: %s", cfg.Type) } diff --git a/src/internal/sink/console.go b/src/internal/sink/console.go index 39c61a3..6ac6139 100644 --- a/src/internal/sink/console.go +++ b/src/internal/sink/console.go @@ -18,6 +18,7 @@ import ( // ConsoleSink writes log entries to the console (stdout/stderr) using an dedicated logger instance type ConsoleSink struct { + config *config.ConsoleSinkOptions input chan core.LogEntry writer *log.Logger // Dedicated internal logger instance for console writing done chan struct{} @@ -31,22 +32,24 @@ type ConsoleSink struct { } // Creates a new console sink -func NewConsoleSink(options map[string]any, appLogger *log.Logger, formatter format.Formatter) (*ConsoleSink, error) { - target := "stdout" - if t, ok := options["target"].(string); ok { - target = t +func NewConsoleSink(opts *config.ConsoleSinkOptions, appLogger *log.Logger, formatter format.Formatter) (*ConsoleSink, error) { + if opts == nil { + return nil, fmt.Errorf("console sink options cannot be nil") } - bufferSize := int64(1000) - if buf, ok := options["buffer_size"].(int64); ok && buf > 0 { - bufferSize = buf + // Set defaults if not configured + if opts.Target == "" { + opts.Target = "stdout" + } + if opts.BufferSize <= 0 { + opts.BufferSize = 1000 } // Dedicated logger instance as console writer writer, err := log.NewBuilder(). EnableFile(false). EnableConsole(true). - ConsoleTarget(target). + ConsoleTarget(opts.Target). Format("raw"). // Passthrough pre-formatted messages ShowTimestamp(false). // Disable writer's own timestamp ShowLevel(false). // Disable writer's own level prefix @@ -57,7 +60,8 @@ func NewConsoleSink(options map[string]any, appLogger *log.Logger, formatter for } s := &ConsoleSink{ - input: make(chan core.LogEntry, bufferSize), + config: opts, + input: make(chan core.LogEntry, opts.BufferSize), writer: writer, done: make(chan struct{}), startTime: time.Now(), @@ -156,8 +160,4 @@ func (s *ConsoleSink) processLoop(ctx context.Context) { return } } -} - -func (s *ConsoleSink) SetAuth(auth *config.AuthConfig) { - // Authentication does not apply to the console sink. } \ No newline at end of file diff --git a/src/internal/sink/file.go b/src/internal/sink/file.go index 28815f6..e8d801c 100644 --- a/src/internal/sink/file.go +++ b/src/internal/sink/file.go @@ -5,10 +5,10 @@ import ( "bytes" "context" "fmt" - "logwisp/src/internal/config" "sync/atomic" "time" + "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/format" @@ -17,6 +17,7 @@ import ( // Writes log entries to files with rotation type FileSink struct { + config *config.FileSinkOptions input chan core.LogEntry writer *log.Logger // Internal logger instance for file writing done chan struct{} @@ -30,64 +31,27 @@ type FileSink struct { } // Creates a new file sink -func NewFileSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*FileSink, error) { - directory, ok := options["directory"].(string) - if !ok || directory == "" { - directory = "./" - logger.Warn("No directory or invalid directory provided, current directory will be used") - } - - name, ok := options["name"].(string) - if !ok || name == "" { - name = "logwisp.output" - logger.Warn(fmt.Sprintf("No filename provided, %s will be used", name)) +func NewFileSink(opts *config.FileSinkOptions, logger *log.Logger, formatter format.Formatter) (*FileSink, error) { + if opts == nil { + return nil, fmt.Errorf("file sink options cannot be nil") } // Create configuration for the internal log writer writerConfig := log.DefaultConfig() - writerConfig.Directory = directory - writerConfig.Name = name + writerConfig.Directory = opts.Directory + writerConfig.Name = opts.Name writerConfig.EnableConsole = false // File only writerConfig.ShowTimestamp = false // We already have timestamps in entries writerConfig.ShowLevel = false // We already have levels in entries - // Add optional configurations - if maxSize, ok := options["max_size_mb"].(int64); ok && maxSize > 0 { - writerConfig.MaxSizeKB = maxSize * 1000 - } - - if maxTotalSize, ok := options["max_total_size_mb"].(int64); ok && maxTotalSize >= 0 { - writerConfig.MaxTotalSizeKB = maxTotalSize * 1000 - } - - if retention, ok := options["retention_hours"].(int64); ok && retention > 0 { - writerConfig.RetentionPeriodHrs = float64(retention) - } - - if minDiskFree, ok := options["min_disk_free_mb"].(int64); ok && minDiskFree > 0 { - writerConfig.MinDiskFreeKB = minDiskFree * 1000 - } - // Create internal logger for file writing writer := log.NewLogger() if err := writer.ApplyConfig(writerConfig); err != nil { return nil, fmt.Errorf("failed to initialize file writer: %w", err) } - // Start the internal file writer - if err := writer.Start(); err != nil { - return nil, fmt.Errorf("failed to start file writer: %w", err) - } - - // Buffer size for input channel - // TODO: Centralized constant file in core package - bufferSize := int64(1000) - if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { - bufferSize = bufSize - } - fs := &FileSink{ - input: make(chan core.LogEntry, bufferSize), + input: make(chan core.LogEntry, opts.BufferSize), writer: writer, done: make(chan struct{}), startTime: time.Now(), @@ -104,6 +68,11 @@ func (fs *FileSink) Input() chan<- core.LogEntry { } func (fs *FileSink) Start(ctx context.Context) error { + // Start the internal file writer + if err := fs.writer.Start(); err != nil { + return fmt.Errorf("failed to start sink file writer: %w", err) + } + go fs.processLoop(ctx) fs.logger.Info("msg", "File sink started", "component", "file_sink") return nil @@ -166,8 +135,4 @@ func (fs *FileSink) processLoop(ctx context.Context) { return } } -} - -func (fs *FileSink) SetAuth(auth *config.AuthConfig) { - // Authentication does not apply to file sink } \ No newline at end of file diff --git a/src/internal/sink/http.go b/src/internal/sink/http.go index d9c3f25..8df9161 100644 --- a/src/internal/sink/http.go +++ b/src/internal/sink/http.go @@ -26,8 +26,11 @@ import ( // Streams log entries via Server-Sent Events type HTTPSink struct { + // Configuration reference (NOT a copy) + config *config.HTTPSinkOptions + + // Runtime input chan core.LogEntry - config HTTPConfig server *fasthttp.Server activeClients atomic.Int64 mu sync.RWMutex @@ -46,11 +49,7 @@ type HTTPSink struct { // Security components authenticator *auth.Authenticator tlsManager *tls.Manager - authConfig *config.AuthConfig - - // Path configuration - streamPath string - statusPath string + authConfig *config.ServerAuthConfig // Net limiting netLimiter *limit.NetLimiter @@ -62,151 +61,58 @@ type HTTPSink struct { authSuccesses atomic.Uint64 } -// Holds HTTP sink configuration -type HTTPConfig struct { - Host string - Port int64 - BufferSize int64 - StreamPath string - StatusPath string - Heartbeat *config.HeartbeatConfig - TLS *config.TLSConfig - NetLimit *config.NetLimitConfig -} - // Creates a new HTTP streaming sink -func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*HTTPSink, error) { - cfg := HTTPConfig{ - Host: "0.0.0.0", - Port: 8080, - BufferSize: 1000, - StreamPath: "/stream", - StatusPath: "/status", - } - - // Extract configuration from options - if host, ok := options["host"].(string); ok && host != "" { - cfg.Host = host - } - if port, ok := options["port"].(int64); ok { - cfg.Port = port - } - if bufSize, ok := options["buffer_size"].(int64); ok { - cfg.BufferSize = bufSize - } - if path, ok := options["stream_path"].(string); ok { - cfg.StreamPath = path - } - if path, ok := options["status_path"].(string); ok { - cfg.StatusPath = path - } - - // Extract heartbeat config - if hb, ok := options["heartbeat"].(map[string]any); ok { - cfg.Heartbeat = &config.HeartbeatConfig{} - cfg.Heartbeat.Enabled, _ = hb["enabled"].(bool) - if interval, ok := hb["interval_seconds"].(int64); ok { - cfg.Heartbeat.IntervalSeconds = interval - } - cfg.Heartbeat.IncludeTimestamp, _ = hb["include_timestamp"].(bool) - cfg.Heartbeat.IncludeStats, _ = hb["include_stats"].(bool) - if hbFormat, ok := hb["format"].(string); ok { - cfg.Heartbeat.Format = hbFormat - } - } - - // Extract TLS config - if tc, ok := options["tls"].(map[string]any); ok { - cfg.TLS = &config.TLSConfig{} - cfg.TLS.Enabled, _ = tc["enabled"].(bool) - if certFile, ok := tc["cert_file"].(string); ok { - cfg.TLS.CertFile = certFile - } - if keyFile, ok := tc["key_file"].(string); ok { - cfg.TLS.KeyFile = keyFile - } - cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool) - if caFile, ok := tc["client_ca_file"].(string); ok { - cfg.TLS.ClientCAFile = caFile - } - cfg.TLS.VerifyClientCert, _ = tc["verify_client_cert"].(bool) - if minVer, ok := tc["min_version"].(string); ok { - cfg.TLS.MinVersion = minVer - } - if maxVer, ok := tc["max_version"].(string); ok { - cfg.TLS.MaxVersion = maxVer - } - if ciphers, ok := tc["cipher_suites"].(string); ok { - cfg.TLS.CipherSuites = ciphers - } - } - - // Extract net limit config - if nl, ok := options["net_limit"].(map[string]any); ok { - cfg.NetLimit = &config.NetLimitConfig{} - cfg.NetLimit.Enabled, _ = nl["enabled"].(bool) - if rps, ok := nl["requests_per_second"].(float64); ok { - cfg.NetLimit.RequestsPerSecond = rps - } - if burst, ok := nl["burst_size"].(int64); ok { - cfg.NetLimit.BurstSize = burst - } - if respCode, ok := nl["response_code"].(int64); ok { - cfg.NetLimit.ResponseCode = respCode - } - if msg, ok := nl["response_message"].(string); ok { - cfg.NetLimit.ResponseMessage = msg - } - if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok { - cfg.NetLimit.MaxConnectionsPerIP = maxPerIP - } - if maxTotal, ok := nl["max_connections_total"].(int64); ok { - cfg.NetLimit.MaxConnectionsTotal = maxTotal - } - if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok { - cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist)) - for _, entry := range ipWhitelist { - if str, ok := entry.(string); ok { - cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str) - } - } - } - if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok { - cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist)) - for _, entry := range ipBlacklist { - if str, ok := entry.(string); ok { - cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str) - } - } - } +func NewHTTPSink(opts *config.HTTPSinkOptions, logger *log.Logger, formatter format.Formatter) (*HTTPSink, error) { + if opts == nil { + return nil, fmt.Errorf("HTTP sink options cannot be nil") } h := &HTTPSink{ - input: make(chan core.LogEntry, cfg.BufferSize), - config: cfg, - startTime: time.Now(), - done: make(chan struct{}), - streamPath: cfg.StreamPath, - statusPath: cfg.StatusPath, - logger: logger, - formatter: formatter, - clients: make(map[uint64]chan core.LogEntry), - unregister: make(chan uint64, 10), // Buffered for non-blocking + config: opts, // Direct reference to config struct + input: make(chan core.LogEntry, opts.BufferSize), + startTime: time.Now(), + done: make(chan struct{}), + logger: logger, + formatter: formatter, + clients: make(map[uint64]chan core.LogEntry), } + h.lastProcessed.Store(time.Time{}) - // Initialize TLS manager - if cfg.TLS != nil && cfg.TLS.Enabled { - tlsManager, err := tls.NewManager(cfg.TLS, logger) + // Initialize TLS manager if configured + if opts.TLS != nil && opts.TLS.Enabled { + tlsManager, err := tls.NewManager(opts.TLS, logger) if err != nil { return nil, fmt.Errorf("failed to create TLS manager: %w", err) } h.tlsManager = tlsManager + logger.Info("msg", "TLS enabled", + "component", "http_sink") } // Initialize net limiter if configured - if cfg.NetLimit != nil && cfg.NetLimit.Enabled { - h.netLimiter = limit.NewNetLimiter(*cfg.NetLimit, logger) + if opts.NetLimit != nil && (opts.NetLimit.Enabled || + len(opts.NetLimit.IPWhitelist) > 0 || + len(opts.NetLimit.IPBlacklist) > 0) { + h.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger) + } + + // Initialize authenticator if auth is not "none" + if opts.Auth != nil && opts.Auth.Type != "none" { + // Only "basic" and "token" are valid for HTTP sink + if opts.Auth.Type != "basic" && opts.Auth.Type != "token" { + return nil, fmt.Errorf("invalid auth type '%s' for HTTP sink (valid: none, basic, token)", opts.Auth.Type) + } + + authenticator, err := auth.NewAuthenticator(opts.Auth, logger) + if err != nil { + return nil, fmt.Errorf("failed to create authenticator: %w", err) + } + h.authenticator = authenticator + h.authConfig = opts.Auth + logger.Info("msg", "Authentication enabled", + "component", "http_sink", + "type", opts.Auth.Type) } return h, nil @@ -230,6 +136,9 @@ func (h *HTTPSink) Start(ctx context.Context) error { DisableKeepalive: false, StreamRequestBody: true, Logger: fasthttpLogger, + // ReadTimeout: time.Duration(h.config.ReadTimeout) * time.Millisecond, + WriteTimeout: time.Duration(h.config.WriteTimeout) * time.Millisecond, + // MaxRequestBodySize: int(h.config.MaxBodySize), } // Configure TLS if enabled @@ -250,8 +159,8 @@ func (h *HTTPSink) Start(ctx context.Context) error { "component", "http_sink", "host", h.config.Host, "port", h.config.Port, - "stream_path", h.streamPath, - "status_path", h.statusPath, + "stream_path", h.config.StreamPath, + "status_path", h.config.StatusPath, "tls_enabled", h.tlsManager != nil) var err error @@ -296,7 +205,7 @@ func (h *HTTPSink) brokerLoop(ctx context.Context) { var tickerChan <-chan time.Time if h.config.Heartbeat != nil && h.config.Heartbeat.Enabled { - ticker = time.NewTicker(time.Duration(h.config.Heartbeat.IntervalSeconds) * time.Second) + ticker = time.NewTicker(time.Duration(h.config.Heartbeat.Interval) * time.Second) tickerChan = ticker.C defer ticker.Stop() } @@ -441,8 +350,8 @@ func (h *HTTPSink) GetStats() SinkStats { "port": h.config.Port, "buffer_size": h.config.BufferSize, "endpoints": map[string]string{ - "stream": h.streamPath, - "status": h.statusPath, + "stream": h.config.StreamPath, + "status": h.config.StatusPath, }, "net_limit": netLimitStats, "auth": authStats, @@ -489,7 +398,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { path := string(ctx.Path()) // Status endpoint doesn't require auth - if path == h.statusPath { + if path == h.config.StatusPath { h.handleStatus(ctx) return } @@ -509,14 +418,14 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { // Return 401 with WWW-Authenticate header ctx.SetStatusCode(fasthttp.StatusUnauthorized) - if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil { - realm := h.authConfig.BasicAuth.Realm + if h.authConfig.Type == "basic" && h.authConfig.Basic != nil { + realm := h.authConfig.Basic.Realm if realm == "" { realm = "Restricted" } ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm)) - } else if h.authConfig.Type == "bearer" { - ctx.Response.Header.Set("WWW-Authenticate", "Bearer") + } else if h.authConfig.Type == "token" { + ctx.Response.Header.Set("WWW-Authenticate", "Token") } ctx.SetContentType("application/json") @@ -538,7 +447,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { } switch path { - case h.streamPath: + case h.config.StreamPath: h.handleStream(ctx, session) default: ctx.SetStatusCode(fasthttp.StatusNotFound) @@ -547,6 +456,15 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { "error": "Not Found", }) } + // Handle stream endpoint + // if path == h.config.StreamPath { + // h.handleStream(ctx, session) + // return + // } + // + // // Unknown path + // ctx.SetStatusCode(fasthttp.StatusNotFound) + // ctx.SetBody([]byte("Not Found")) } func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) { @@ -611,8 +529,8 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) "client_id": fmt.Sprintf("%d", clientID), "username": session.Username, "auth_method": session.Method, - "stream_path": h.streamPath, - "status_path": h.statusPath, + "stream_path": h.config.StreamPath, + "status_path": h.config.StatusPath, "buffer_size": h.config.BufferSize, "tls": h.tlsManager != nil, } @@ -627,7 +545,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) var tickerChan <-chan time.Time if h.config.Heartbeat != nil && h.config.Heartbeat.Enabled { - ticker = time.NewTicker(time.Duration(h.config.Heartbeat.IntervalSeconds) * time.Second) + ticker = time.NewTicker(time.Duration(h.config.Heartbeat.Interval) * time.Second) tickerChan = ticker.C defer ticker.Stop() } @@ -716,7 +634,7 @@ func (h *HTTPSink) createHeartbeatEntry() core.LogEntry { fields := make(map[string]any) fields["type"] = "heartbeat" - if h.config.Heartbeat.IncludeStats { + if h.config.Heartbeat.Enabled { fields["active_clients"] = h.activeClients.Load() fields["uptime_seconds"] = int(time.Since(h.startTime).Seconds()) } @@ -775,13 +693,13 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) { "uptime_seconds": int(time.Since(h.startTime).Seconds()), }, "endpoints": map[string]string{ - "transport": h.streamPath, - "status": h.statusPath, + "transport": h.config.StreamPath, + "status": h.config.StatusPath, }, "features": map[string]any{ "heartbeat": map[string]any{ "enabled": h.config.Heartbeat.Enabled, - "interval": h.config.Heartbeat.IntervalSeconds, + "interval": h.config.Heartbeat.Interval, "format": h.config.Heartbeat.Format, }, "tls": tlsStats, @@ -806,37 +724,15 @@ func (h *HTTPSink) GetActiveConnections() int64 { // Returns the configured transport endpoint path func (h *HTTPSink) GetStreamPath() string { - return h.streamPath + return h.config.StreamPath } // Returns the configured status endpoint path func (h *HTTPSink) GetStatusPath() string { - return h.statusPath + return h.config.StatusPath } // Returns the configured host func (h *HTTPSink) GetHost() string { return h.config.Host -} - -// Configures http sink auth -func (h *HTTPSink) SetAuth(authCfg *config.AuthConfig) { - if authCfg == nil || authCfg.Type == "none" { - return - } - - h.authConfig = authCfg - authenticator, err := auth.NewAuthenticator(authCfg, h.logger) - if err != nil { - h.logger.Error("msg", "Failed to initialize authenticator for HTTP sink", - "component", "http_sink", - "error", err) - // Continue without auth - return - } - h.authenticator = authenticator - - h.logger.Info("msg", "Authentication configured for HTTP sink", - "component", "http_sink", - "auth_type", authCfg.Type) } \ No newline at end of file diff --git a/src/internal/sink/http_client.go b/src/internal/sink/http_client.go index fbfae32..183befc 100644 --- a/src/internal/sink/http_client.go +++ b/src/internal/sink/http_client.go @@ -8,7 +8,6 @@ import ( "crypto/x509" "encoding/base64" "fmt" - "net/url" "os" "strings" "sync" @@ -28,7 +27,7 @@ import ( // Forwards log entries to a remote HTTP endpoint type HTTPClientSink struct { input chan core.LogEntry - config HTTPClientConfig + config *config.HTTPClientSinkOptions client *fasthttp.Client batch []core.LogEntry batchMu sync.Mutex @@ -48,195 +47,16 @@ type HTTPClientSink struct { activeConnections atomic.Int64 } -// Holds HTTP client sink configuration -// TODO: missing toml tags -type HTTPClientConfig struct { - // Config - URL string `toml:"url"` - BufferSize int64 `toml:"buffer_size"` - BatchSize int64 `toml:"batch_size"` - BatchDelay time.Duration `toml:"batch_delay_ms"` - Timeout time.Duration `toml:"timeout_seconds"` - Headers map[string]string `toml:"headers"` - - // Retry configuration - MaxRetries int64 `toml:"max_retries"` - RetryDelay time.Duration `toml:"retry_delay"` - RetryBackoff float64 `toml:"retry_backoff"` // Multiplier for exponential backoff - - // Security - AuthType string `toml:"auth_type"` // "none", "basic", "bearer", "mtls" - Username string `toml:"username"` // For basic auth - Password string `toml:"password"` // For basic auth - BearerToken string `toml:"bearer_token"` // For bearer auth - - // TLS configuration - InsecureSkipVerify bool `toml:"insecure_skip_verify"` - CAFile string `toml:"ca_file"` - CertFile string `toml:"cert_file"` - KeyFile string `toml:"key_file"` -} - // Creates a new HTTP client sink -func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*HTTPClientSink, error) { - cfg := HTTPClientConfig{ - BufferSize: int64(1000), - BatchSize: int64(100), - BatchDelay: time.Second, - Timeout: 30 * time.Second, - MaxRetries: int64(3), - RetryDelay: time.Second, - RetryBackoff: float64(2.0), - Headers: make(map[string]string), - } - - // Extract URL - urlStr, ok := options["url"].(string) - if !ok || urlStr == "" { - return nil, fmt.Errorf("http_client sink requires 'url' option") - } - - // Validate URL - parsedURL, err := url.Parse(urlStr) - if err != nil { - return nil, fmt.Errorf("invalid URL: %w", err) - } - if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - return nil, fmt.Errorf("URL must use http or https scheme") - } - cfg.URL = urlStr - - // Extract other options - if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { - cfg.BufferSize = bufSize - } - if batchSize, ok := options["batch_size"].(int64); ok && batchSize > 0 { - cfg.BatchSize = batchSize - } - if delayMs, ok := options["batch_delay_ms"].(int64); ok && delayMs > 0 { - cfg.BatchDelay = time.Duration(delayMs) * time.Millisecond - } - if timeoutSec, ok := options["timeout_seconds"].(int64); ok && timeoutSec > 0 { - cfg.Timeout = time.Duration(timeoutSec) * time.Second - } - if maxRetries, ok := options["max_retries"].(int64); ok && maxRetries >= 0 { - cfg.MaxRetries = maxRetries - } - if retryDelayMs, ok := options["retry_delay_ms"].(int64); ok && retryDelayMs > 0 { - cfg.RetryDelay = time.Duration(retryDelayMs) * time.Millisecond - } - if backoff, ok := options["retry_backoff"].(float64); ok && backoff >= 1.0 { - cfg.RetryBackoff = backoff - } - if insecure, ok := options["insecure_skip_verify"].(bool); ok { - cfg.InsecureSkipVerify = insecure - } - if authType, ok := options["auth_type"].(string); ok { - switch authType { - case "none", "basic", "bearer", "mtls": - cfg.AuthType = authType - default: - return nil, fmt.Errorf("http_client sink: invalid auth_type '%s'", authType) - } - } else { - cfg.AuthType = "none" - } - if username, ok := options["username"].(string); ok { - cfg.Username = username - } - if password, ok := options["password"].(string); ok { - cfg.Password = password // TODO: change to Argon2 hashed password - } - if token, ok := options["bearer_token"].(string); ok { - cfg.BearerToken = token - } - - // Validate auth configuration and TLS enforcement - isHTTPS := strings.HasPrefix(cfg.URL, "https://") - - switch cfg.AuthType { - case "basic": - if cfg.Username == "" || cfg.Password == "" { - return nil, fmt.Errorf("http_client sink: username and password required for basic auth") - } - if !isHTTPS { - return nil, fmt.Errorf("http_client sink: basic auth requires HTTPS (security: credentials would be sent in plaintext)") - } - - case "bearer": - if cfg.BearerToken == "" { - return nil, fmt.Errorf("http_client sink: bearer_token required for bearer auth") - } - if !isHTTPS { - return nil, fmt.Errorf("http_client sink: bearer auth requires HTTPS (security: token would be sent in plaintext)") - } - - case "mtls": - if !isHTTPS { - return nil, fmt.Errorf("http_client sink: mTLS requires HTTPS") - } - if cfg.CertFile == "" || cfg.KeyFile == "" { - return nil, fmt.Errorf("http_client sink: cert_file and key_file required for mTLS") - } - - case "none": - // Clear any credentials if auth is "none" - if cfg.Username != "" || cfg.Password != "" || cfg.BearerToken != "" { - logger.Warn("msg", "Credentials provided but auth_type is 'none', ignoring", - "component", "http_client_sink") - cfg.Username = "" - cfg.Password = "" - cfg.BearerToken = "" - } - } - - // Extract headers - if headers, ok := options["headers"].(map[string]any); ok { - for k, v := range headers { - if strVal, ok := v.(string); ok { - cfg.Headers[k] = strVal - } - } - } - - // Set default Content-Type if not specified - if _, exists := cfg.Headers["Content-Type"]; !exists { - cfg.Headers["Content-Type"] = "application/json" - } - - // Extract TLS options - if caFile, ok := options["ca_file"].(string); ok && caFile != "" { - cfg.CAFile = caFile - } - - // Extract client certificate options from TLS config - if tc, ok := options["tls"].(map[string]any); ok { - if enabled, _ := tc["enabled"].(bool); enabled { - // Extract client certificate files for mTLS - if certFile, ok := tc["cert_file"].(string); ok && certFile != "" { - if keyFile, ok := tc["key_file"].(string); ok && keyFile != "" { - // These will be used below when configuring TLS - cfg.CertFile = certFile // Need to add these fields to HTTPClientConfig - cfg.KeyFile = keyFile - } - } - // Extract CA file from TLS config if not already set - if cfg.CAFile == "" { - if caFile, ok := tc["ca_file"].(string); ok { - cfg.CAFile = caFile - } - } - // Extract insecure skip verify from TLS config - if insecure, ok := tc["insecure_skip_verify"].(bool); ok { - cfg.InsecureSkipVerify = insecure - } - } +func NewHTTPClientSink(opts *config.HTTPClientSinkOptions, logger *log.Logger, formatter format.Formatter) (*HTTPClientSink, error) { + if opts == nil { + return nil, fmt.Errorf("HTTP client sink options cannot be nil") } h := &HTTPClientSink{ - input: make(chan core.LogEntry, cfg.BufferSize), - config: cfg, - batch: make([]core.LogEntry, 0, cfg.BatchSize), + config: opts, + input: make(chan core.LogEntry, opts.BufferSize), + batch: make([]core.LogEntry, 0, opts.BatchSize), done: make(chan struct{}), startTime: time.Now(), logger: logger, @@ -249,46 +69,48 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for h.client = &fasthttp.Client{ MaxConnsPerHost: 10, MaxIdleConnDuration: 10 * time.Second, - ReadTimeout: cfg.Timeout, - WriteTimeout: cfg.Timeout, + ReadTimeout: time.Duration(opts.Timeout) * time.Second, + WriteTimeout: time.Duration(opts.Timeout) * time.Second, DisableHeaderNamesNormalizing: true, } // Configure TLS if using HTTPS - if strings.HasPrefix(cfg.URL, "https://") { + if strings.HasPrefix(opts.URL, "https://") { tlsConfig := &tls.Config{ - InsecureSkipVerify: cfg.InsecureSkipVerify, + InsecureSkipVerify: opts.InsecureSkipVerify, } - // Load custom CA for server verification if provided - if cfg.CAFile != "" { - caCert, err := os.ReadFile(cfg.CAFile) - if err != nil { - return nil, fmt.Errorf("failed to read CA file '%s': %w", cfg.CAFile, err) + // Use TLS config if provided + if opts.TLS != nil { + // Load custom CA for server verification + if opts.TLS.CAFile != "" { + caCert, err := os.ReadFile(opts.TLS.CAFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA file '%s': %w", opts.TLS.CAFile, err) + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse CA certificate from '%s'", opts.TLS.CAFile) + } + tlsConfig.RootCAs = caCertPool + logger.Debug("msg", "Custom CA loaded for server verification", + "component", "http_client_sink", + "ca_file", opts.TLS.CAFile) } - caCertPool := x509.NewCertPool() - if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, fmt.Errorf("failed to parse CA certificate from '%s'", cfg.CAFile) + + // Load client certificate for mTLS if provided + if opts.TLS.CertFile != "" && opts.TLS.KeyFile != "" { + cert, err := tls.LoadX509KeyPair(opts.TLS.CertFile, opts.TLS.KeyFile) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + logger.Info("msg", "Client certificate loaded for mTLS", + "component", "http_client_sink", + "cert_file", opts.TLS.CertFile) } - tlsConfig.RootCAs = caCertPool - logger.Debug("msg", "Custom CA loaded for server verification", - "component", "http_client_sink", - "ca_file", cfg.CAFile) } - // Load client certificate for mTLS if provided - if cfg.CertFile != "" && cfg.KeyFile != "" { - cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) - if err != nil { - return nil, fmt.Errorf("failed to load client certificate: %w", err) - } - tlsConfig.Certificates = []tls.Certificate{cert} - logger.Info("msg", "Client certificate loaded for mTLS", - "component", "http_client_sink", - "cert_file", cfg.CertFile) - } - - // Set TLS config directly on the client h.client.TLSConfig = tlsConfig } @@ -308,7 +130,7 @@ func (h *HTTPClientSink) Start(ctx context.Context) error { "component", "http_client_sink", "url", h.config.URL, "batch_size", h.config.BatchSize, - "batch_delay", h.config.BatchDelay) + "batch_delay_ms", h.config.BatchDelayMS) return nil } @@ -399,7 +221,7 @@ func (h *HTTPClientSink) processLoop(ctx context.Context) { func (h *HTTPClientSink) batchTimer(ctx context.Context) { defer h.wg.Done() - ticker := time.NewTicker(h.config.BatchDelay) + ticker := time.NewTicker(time.Duration(h.config.BatchDelayMS) * time.Millisecond) defer ticker.Stop() for { @@ -468,7 +290,7 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { // Retry logic var lastErr error - retryDelay := h.config.RetryDelay + retryDelay := time.Duration(h.config.RetryDelayMS) * time.Millisecond // TODO: verify retry loop placement is correct or should it be after acquiring resources (req :=....) for attempt := int64(0); attempt <= h.config.MaxRetries; attempt++ { @@ -480,9 +302,10 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { newDelay := time.Duration(float64(retryDelay) * h.config.RetryBackoff) // Cap at maximum to prevent integer overflow - if newDelay > h.config.Timeout || newDelay < retryDelay { + timeout := time.Duration(h.config.Timeout) * time.Second + if newDelay > timeout || newDelay < retryDelay { // Either exceeded max or overflowed (negative/wrapped) - retryDelay = h.config.Timeout + retryDelay = timeout } else { retryDelay = newDelay } @@ -500,14 +323,14 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { req.Header.Set("User-Agent", fmt.Sprintf("LogWisp/%s", version.Short())) // Add authentication based on auth type - switch h.config.AuthType { + switch h.config.Auth.Type { case "basic": - creds := h.config.Username + ":" + h.config.Password + creds := h.config.Auth.Username + ":" + h.config.Auth.Password encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds)) req.Header.Set("Authorization", "Basic "+encodedCreds) - case "bearer": - req.Header.Set("Authorization", "Bearer "+h.config.BearerToken) + case "token": + req.Header.Set("Authorization", "Token "+h.config.Auth.Token) case "mtls": // mTLS auth is handled at TLS layer via client certificates @@ -523,7 +346,7 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { } // Send request - err := h.client.DoTimeout(req, resp, h.config.Timeout) + err := h.client.DoTimeout(req, resp, time.Duration(h.config.Timeout)*time.Second) // Capture response before releasing statusCode := resp.StatusCode() @@ -587,10 +410,4 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { "retries", h.config.MaxRetries, "last_error", lastErr) h.failedBatches.Add(1) -} - -// Not applicable, Clients authenticate to remote servers using Username/Password in config -func (h *HTTPClientSink) SetAuth(authCfg *config.AuthConfig) { - // No-op: client sinks don't validate incoming connections - // They authenticate to remote servers using Username/Password fields } \ No newline at end of file diff --git a/src/internal/sink/sink.go b/src/internal/sink/sink.go index 51d6d45..140b8b4 100644 --- a/src/internal/sink/sink.go +++ b/src/internal/sink/sink.go @@ -5,7 +5,6 @@ import ( "context" "time" - "logwisp/src/internal/config" "logwisp/src/internal/core" ) @@ -22,9 +21,6 @@ type Sink interface { // Returns sink statistics GetStats() SinkStats - - // Configure authentication - SetAuth(auth *config.AuthConfig) } // Contains statistics about a sink diff --git a/src/internal/sink/tcp.go b/src/internal/sink/tcp.go index 5d8085c..22f34eb 100644 --- a/src/internal/sink/tcp.go +++ b/src/internal/sink/tcp.go @@ -7,7 +7,6 @@ import ( "encoding/json" "fmt" "net" - "strings" "sync" "sync/atomic" "time" @@ -25,26 +24,22 @@ import ( // Streams log entries via TCP type TCPSink struct { - // C - input chan core.LogEntry - config TCPConfig - server *tcpServer - done chan struct{} - activeConns atomic.Int64 - startTime time.Time - engine *gnet.Engine - engineMu sync.Mutex - wg sync.WaitGroup - netLimiter *limit.NetLimiter - logger *log.Logger - formatter format.Formatter - authenticator *auth.Authenticator + input chan core.LogEntry + config *config.TCPSinkOptions + server *tcpServer + done chan struct{} + activeConns atomic.Int64 + startTime time.Time + engine *gnet.Engine + engineMu sync.Mutex + wg sync.WaitGroup + netLimiter *limit.NetLimiter + logger *log.Logger + formatter format.Formatter // Statistics totalProcessed atomic.Uint64 lastProcessed atomic.Value // time.Time - authFailures atomic.Uint64 - authSuccesses atomic.Uint64 // Write error tracking writeErrors atomic.Uint64 @@ -62,87 +57,14 @@ type TCPConfig struct { } // Creates a new TCP streaming sink -func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*TCPSink, error) { - cfg := TCPConfig{ - Host: "0.0.0.0", - Port: int64(9090), - BufferSize: int64(1000), - } - - // Extract configuration from options - if host, ok := options["host"].(string); ok && host != "" { - cfg.Host = host - } - if port, ok := options["port"].(int64); ok { - cfg.Port = port - } - if bufSize, ok := options["buffer_size"].(int64); ok { - cfg.BufferSize = bufSize - } - - // Extract heartbeat config - if hb, ok := options["heartbeat"].(map[string]any); ok { - cfg.Heartbeat = &config.HeartbeatConfig{} - cfg.Heartbeat.Enabled, _ = hb["enabled"].(bool) - if interval, ok := hb["interval_seconds"].(int64); ok { - cfg.Heartbeat.IntervalSeconds = interval - } - cfg.Heartbeat.IncludeTimestamp, _ = hb["include_timestamp"].(bool) - cfg.Heartbeat.IncludeStats, _ = hb["include_stats"].(bool) - if hbFormat, ok := hb["format"].(string); ok { - cfg.Heartbeat.Format = hbFormat - } - } - - // Extract net limit config - if nl, ok := options["net_limit"].(map[string]any); ok { - cfg.NetLimit = &config.NetLimitConfig{} - cfg.NetLimit.Enabled, _ = nl["enabled"].(bool) - if rps, ok := nl["requests_per_second"].(float64); ok { - cfg.NetLimit.RequestsPerSecond = rps - } - if burst, ok := nl["burst_size"].(int64); ok { - cfg.NetLimit.BurstSize = burst - } - if respCode, ok := nl["response_code"].(int64); ok { - cfg.NetLimit.ResponseCode = respCode - } - if msg, ok := nl["response_message"].(string); ok { - cfg.NetLimit.ResponseMessage = msg - } - if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok { - cfg.NetLimit.MaxConnectionsPerIP = maxPerIP - } - if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok { - cfg.NetLimit.MaxConnectionsPerUser = maxPerUser - } - if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok { - cfg.NetLimit.MaxConnectionsPerToken = maxPerToken - } - if maxTotal, ok := nl["max_connections_total"].(int64); ok { - cfg.NetLimit.MaxConnectionsTotal = maxTotal - } - if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok { - cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist)) - for _, entry := range ipWhitelist { - if str, ok := entry.(string); ok { - cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str) - } - } - } - if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok { - cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist)) - for _, entry := range ipBlacklist { - if str, ok := entry.(string); ok { - cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str) - } - } - } +func NewTCPSink(opts *config.TCPSinkOptions, logger *log.Logger, formatter format.Formatter) (*TCPSink, error) { + if opts == nil { + return nil, fmt.Errorf("TCP sink options cannot be nil") } t := &TCPSink{ - input: make(chan core.LogEntry, cfg.BufferSize), - config: cfg, + config: opts, // Direct reference to config + input: make(chan core.LogEntry, opts.BufferSize), done: make(chan struct{}), startTime: time.Now(), logger: logger, @@ -150,9 +72,11 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For } t.lastProcessed.Store(time.Time{}) - // Initialize net limiter - if cfg.NetLimit != nil && cfg.NetLimit.Enabled { - t.netLimiter = limit.NewNetLimiter(*cfg.NetLimit, logger) + // Initialize net limiter with pointer + if opts.NetLimit != nil && (opts.NetLimit.Enabled || + len(opts.NetLimit.IPWhitelist) > 0 || + len(opts.NetLimit.IPBlacklist) > 0) { + t.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger) } return t, nil @@ -193,8 +117,7 @@ func (t *TCPSink) Start(ctx context.Context) error { go func() { t.logger.Info("msg", "Starting TCP server", "component", "tcp_sink", - "port", t.config.Port, - "auth", t.authenticator != nil) + "port", t.config.Port) err := gnet.Run(t.server, addr, opts...) if err != nil { @@ -282,7 +205,7 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) { var tickerChan <-chan time.Time if t.config.Heartbeat != nil && t.config.Heartbeat.Enabled { - ticker = time.NewTicker(time.Duration(t.config.Heartbeat.IntervalSeconds) * time.Second) + ticker = time.NewTicker(time.Duration(t.config.Heartbeat.Interval) * time.Second) tickerChan = ticker.C defer ticker.Stop() } @@ -329,21 +252,19 @@ func (t *TCPSink) broadcastData(data []byte) { t.server.mu.RLock() defer t.server.mu.RUnlock() - for conn, client := range t.server.clients { - if client.authenticated { - conn.AsyncWrite(data, func(c gnet.Conn, err error) error { - if err != nil { - t.writeErrors.Add(1) - t.handleWriteError(c, err) - } else { - // Reset consecutive error count on success - t.errorMu.Lock() - delete(t.consecutiveWriteErrors, c) - t.errorMu.Unlock() - } - return nil - }) - } + for conn, _ := range t.server.clients { + conn.AsyncWrite(data, func(c gnet.Conn, err error) error { + if err != nil { + t.writeErrors.Add(1) + t.handleWriteError(c, err) + } else { + // Reset consecutive error count on success + t.errorMu.Lock() + delete(t.consecutiveWriteErrors, c) + t.errorMu.Unlock() + } + return nil + }) } } @@ -408,11 +329,10 @@ func (t *TCPSink) GetActiveConnections() int64 { // Represents a connected TCP client with auth state type tcpClient struct { - conn gnet.Conn - buffer bytes.Buffer - authenticated bool - authTimeout time.Time - session *auth.Session + conn gnet.Conn + buffer bytes.Buffer + authTimeout time.Time + session *auth.Session } // Handles gnet events with authentication @@ -439,7 +359,7 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { remoteAddr := c.RemoteAddr() s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr) - // Reject IPv6 connections immediately + // Reject IPv6 connections if tcpAddr, ok := remoteAddr.(*net.TCPAddr); ok { if tcpAddr.IP.To4() == nil { return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close @@ -467,14 +387,10 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { s.sink.netLimiter.AddConnection(remoteStr) } - // Create client state without auth timeout initially + // TCP Sink accepts all connections without authentication client := &tcpClient{ - conn: c, - authenticated: s.sink.authenticator == nil, - } - - if s.sink.authenticator != nil { - client.authTimeout = time.Now().Add(30 * time.Second) + conn: c, + buffer: bytes.Buffer{}, } s.mu.Lock() @@ -484,13 +400,7 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { newCount := s.sink.activeConns.Add(1) s.sink.logger.Debug("msg", "TCP connection opened", "remote_addr", remoteAddr, - "active_connections", newCount, - "auth_enabled", s.sink.authenticator != nil) - - // Send auth prompt if authentication is required - if s.sink.authenticator != nil { - return []byte("AUTH_REQUIRED\n"), gnet.None - } + "active_connections", newCount) return nil, gnet.None } @@ -522,96 +432,7 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action { } func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action { - s.mu.RLock() - client, exists := s.clients[c] - s.mu.RUnlock() - - if !exists { - return gnet.Close - } - - // Authentication phase - if !client.authenticated { - // Check auth timeout - if time.Now().After(client.authTimeout) { - s.sink.logger.Warn("msg", "Authentication timeout", - "component", "tcp_sink", - "remote_addr", c.RemoteAddr().String()) - return gnet.Close - } - - // Read auth data - data, _ := c.Next(-1) - if len(data) == 0 { - return gnet.None - } - - client.buffer.Write(data) - - // Look for complete auth line - if idx := bytes.IndexByte(client.buffer.Bytes(), '\n'); idx >= 0 { - line := client.buffer.Bytes()[:idx] - client.buffer.Next(idx + 1) - - // Parse AUTH command: AUTH - parts := strings.SplitN(string(line), " ", 3) - if len(parts) != 3 || parts[0] != "AUTH" { - c.AsyncWrite([]byte("AUTH_FAIL\n"), nil) - return gnet.Close - } - - // Authenticate - session, err := s.sink.authenticator.AuthenticateTCP(parts[1], parts[2], c.RemoteAddr().String()) - if err != nil { - s.sink.authFailures.Add(1) - s.sink.logger.Warn("msg", "TCP authentication failed", - "remote_addr", c.RemoteAddr().String(), - "method", parts[1], - "error", err) - c.AsyncWrite([]byte("AUTH_FAIL\n"), nil) - return gnet.Close - } - - // Authentication successful - s.sink.authSuccesses.Add(1) - s.mu.Lock() - client.authenticated = true - client.session = session - s.mu.Unlock() - - s.sink.logger.Info("msg", "TCP client authenticated", - "component", "tcp_sink", - "remote_addr", c.RemoteAddr().String(), - "username", session.Username, - "method", session.Method) - - c.AsyncWrite([]byte("AUTH_OK\n"), nil) - client.buffer.Reset() - } - return gnet.None - } - - // Clients shouldn't send data, just discard + // TCP Sink doesn't expect any data from clients, discard all c.Discard(-1) return gnet.None -} - -// Configures tcp sink auth -func (t *TCPSink) SetAuth(authCfg *config.AuthConfig) { - if authCfg == nil || authCfg.Type == "none" { - return - } - - authenticator, err := auth.NewAuthenticator(authCfg, t.logger) - if err != nil { - t.logger.Error("msg", "Failed to initialize authenticator for TCP sink", - "component", "tcp_sink", - "error", err) - return - } - t.authenticator = authenticator - - t.logger.Info("msg", "Authentication configured for TCP sink", - "component", "tcp_sink", - "auth_type", authCfg.Type) } \ No newline at end of file diff --git a/src/internal/sink/tcp_client.go b/src/internal/sink/tcp_client.go index 657e55c..c49a468 100644 --- a/src/internal/sink/tcp_client.go +++ b/src/internal/sink/tcp_client.go @@ -7,7 +7,9 @@ import ( "encoding/json" "errors" "fmt" + "logwisp/src/internal/auth" "net" + "strconv" "strings" "sync" "sync/atomic" @@ -16,7 +18,6 @@ import ( "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/format" - "logwisp/src/internal/scram" "github.com/lixenwraith/log" ) @@ -24,7 +25,8 @@ import ( // Forwards log entries to a remote TCP endpoint type TCPClientSink struct { input chan core.LogEntry - config TCPClientConfig + config *config.TCPClientSinkOptions + address string conn net.Conn connMu sync.RWMutex done chan struct{} @@ -46,101 +48,17 @@ type TCPClientSink struct { connectionUptime atomic.Value // time.Duration } -// Holds TCP client sink configuration -type TCPClientConfig struct { - Address string `toml:"address"` - BufferSize int64 `toml:"buffer_size"` - DialTimeout time.Duration `toml:"dial_timeout_seconds"` - WriteTimeout time.Duration `toml:"write_timeout_seconds"` - ReadTimeout time.Duration `toml:"read_timeout_seconds"` - KeepAlive time.Duration `toml:"keep_alive_seconds"` - - // Security - AuthType string `toml:"auth_type"` - Username string `toml:"username"` - Password string `toml:"password"` - - // Reconnection settings - ReconnectDelay time.Duration `toml:"reconnect_delay_ms"` - MaxReconnectDelay time.Duration `toml:"max_reconnect_delay_seconds"` - ReconnectBackoff float64 `toml:"reconnect_backoff"` -} - // Creates a new TCP client sink -func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*TCPClientSink, error) { - cfg := TCPClientConfig{ - BufferSize: int64(1000), - DialTimeout: 10 * time.Second, - WriteTimeout: 30 * time.Second, - ReadTimeout: 10 * time.Second, - KeepAlive: 30 * time.Second, - ReconnectDelay: time.Second, - MaxReconnectDelay: 30 * time.Second, - ReconnectBackoff: float64(1.5), - } - - // Extract address - address, ok := options["address"].(string) - if !ok || address == "" { - return nil, fmt.Errorf("tcp_client sink requires 'address' option") - } - - // Validate address format - _, _, err := net.SplitHostPort(address) - if err != nil { - return nil, fmt.Errorf("invalid address format (expected host:port): %w", err) - } - cfg.Address = address - - // Extract other options - if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { - cfg.BufferSize = bufSize - } - if dialTimeout, ok := options["dial_timeout_seconds"].(int64); ok && dialTimeout > 0 { - cfg.DialTimeout = time.Duration(dialTimeout) * time.Second - } - if writeTimeout, ok := options["write_timeout_seconds"].(int64); ok && writeTimeout > 0 { - cfg.WriteTimeout = time.Duration(writeTimeout) * time.Second - } - if readTimeout, ok := options["read_timeout_seconds"].(int64); ok && readTimeout > 0 { - cfg.ReadTimeout = time.Duration(readTimeout) * time.Second - } - if keepAlive, ok := options["keep_alive_seconds"].(int64); ok && keepAlive > 0 { - cfg.KeepAlive = time.Duration(keepAlive) * time.Second - } - if reconnectDelay, ok := options["reconnect_delay_ms"].(int64); ok && reconnectDelay > 0 { - cfg.ReconnectDelay = time.Duration(reconnectDelay) * time.Millisecond - } - if maxReconnectDelay, ok := options["max_reconnect_delay_seconds"].(int64); ok && maxReconnectDelay > 0 { - cfg.MaxReconnectDelay = time.Duration(maxReconnectDelay) * time.Second - } - if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 { - cfg.ReconnectBackoff = backoff - } - if authType, ok := options["auth_type"].(string); ok { - switch authType { - case "none": - cfg.AuthType = authType - case "scram": - cfg.AuthType = authType - if username, ok := options["username"].(string); ok && username != "" { - cfg.Username = username - } else { - return nil, fmt.Errorf("invalid scram username") - } - if password, ok := options["password"].(string); ok && password != "" { - cfg.Password = password - } else { - return nil, fmt.Errorf("invalid scram password") - } - default: - return nil, fmt.Errorf("tcp_client sink: invalid auth_type '%s' (must be 'none' or 'scram')", authType) - } +func NewTCPClientSink(opts *config.TCPClientSinkOptions, logger *log.Logger, formatter format.Formatter) (*TCPClientSink, error) { + // Validation and defaults are handled in config package + if opts == nil { + return nil, fmt.Errorf("TCP client sink options cannot be nil") } t := &TCPClientSink{ - input: make(chan core.LogEntry, cfg.BufferSize), - config: cfg, + config: opts, + address: opts.Host + ":" + strconv.Itoa(int(opts.Port)), + input: make(chan core.LogEntry, opts.BufferSize), done: make(chan struct{}), startTime: time.Now(), logger: logger, @@ -167,7 +85,8 @@ func (t *TCPClientSink) Start(ctx context.Context) error { t.logger.Info("msg", "TCP client sink started", "component", "tcp_client_sink", - "address", t.config.Address) + "host", t.config.Host, + "port", t.config.Port) return nil } @@ -209,7 +128,7 @@ func (t *TCPClientSink) GetStats() SinkStats { StartTime: t.startTime, LastProcessed: lastProc, Details: map[string]any{ - "address": t.config.Address, + "address": t.address, "connected": connected, "reconnecting": t.reconnecting.Load(), "total_failed": t.totalFailed.Load(), @@ -223,7 +142,7 @@ func (t *TCPClientSink) GetStats() SinkStats { func (t *TCPClientSink) connectionManager(ctx context.Context) { defer t.wg.Done() - reconnectDelay := t.config.ReconnectDelay + reconnectDelay := time.Duration(t.config.ReconnectDelayMS) * time.Millisecond for { select { @@ -243,9 +162,9 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) { t.lastConnectErr = err t.logger.Warn("msg", "Failed to connect to TCP server", "component", "tcp_client_sink", - "address", t.config.Address, + "address", t.address, "error", err, - "retry_delay", reconnectDelay) + "retry_delay_ms", reconnectDelay) // Wait before retry select { @@ -258,15 +177,15 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) { // Exponential backoff reconnectDelay = time.Duration(float64(reconnectDelay) * t.config.ReconnectBackoff) - if reconnectDelay > t.config.MaxReconnectDelay { - reconnectDelay = t.config.MaxReconnectDelay + if reconnectDelay > time.Duration(t.config.MaxReconnectDelayMS)*time.Millisecond { + reconnectDelay = time.Duration(t.config.MaxReconnectDelayMS) } continue } // Connection successful t.lastConnectErr = nil - reconnectDelay = t.config.ReconnectDelay // Reset backoff + reconnectDelay = time.Duration(t.config.ReconnectDelayMS) * time.Millisecond // Reset backoff t.connectTime = time.Now() t.totalReconnects.Add(1) @@ -276,7 +195,7 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) { t.logger.Info("msg", "Connected to TCP server", "component", "tcp_client_sink", - "address", t.config.Address, + "address", t.address, "local_addr", conn.LocalAddr()) // Monitor connection @@ -293,18 +212,18 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) { t.logger.Warn("msg", "Lost connection to TCP server", "component", "tcp_client_sink", - "address", t.config.Address, + "address", t.address, "uptime", uptime) } } func (t *TCPClientSink) connect() (net.Conn, error) { dialer := &net.Dialer{ - Timeout: t.config.DialTimeout, - KeepAlive: t.config.KeepAlive, + Timeout: time.Duration(t.config.DialTimeout) * time.Second, + KeepAlive: time.Duration(t.config.KeepAlive) * time.Second, } - conn, err := dialer.Dial("tcp", t.config.Address) + conn, err := dialer.Dial("tcp", t.address) if err != nil { return nil, err } @@ -312,18 +231,18 @@ func (t *TCPClientSink) connect() (net.Conn, error) { // Set TCP keep-alive if tcpConn, ok := conn.(*net.TCPConn); ok { tcpConn.SetKeepAlive(true) - tcpConn.SetKeepAlivePeriod(t.config.KeepAlive) + tcpConn.SetKeepAlivePeriod(time.Duration(t.config.KeepAlive) * time.Second) } // SCRAM authentication if credentials configured - if t.config.AuthType == "scram" { + if t.config.Auth != nil && t.config.Auth.Type == "scram" { if err := t.performSCRAMAuth(conn); err != nil { conn.Close() return nil, fmt.Errorf("SCRAM authentication failed: %w", err) } t.logger.Debug("msg", "SCRAM authentication completed", "component", "tcp_client_sink", - "address", t.config.Address) + "address", t.address) } return conn, nil @@ -333,7 +252,17 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error { reader := bufio.NewReader(conn) // Create SCRAM client - scramClient := scram.NewClient(t.config.Username, t.config.Password) + scramClient := auth.NewScramClient(t.config.Auth.Username, t.config.Auth.Password) + + // Wait for AUTH_REQUIRED from server + authPrompt, err := reader.ReadString('\n') + if err != nil { + return fmt.Errorf("failed to read auth prompt: %w", err) + } + + if strings.TrimSpace(authPrompt) != "AUTH_REQUIRED" { + return fmt.Errorf("unexpected server greeting: %s", authPrompt) + } // Step 1: Send ClientFirst clientFirst, err := scramClient.StartAuthentication() @@ -341,8 +270,10 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error { return fmt.Errorf("failed to start SCRAM: %w", err) } - clientFirstJSON, _ := json.Marshal(clientFirst) - msg := fmt.Sprintf("SCRAM-FIRST %s\n", clientFirstJSON) + msg, err := auth.FormatSCRAMRequest("SCRAM-FIRST", clientFirst) + if err != nil { + return err + } if _, err := conn.Write([]byte(msg)); err != nil { return fmt.Errorf("failed to send SCRAM-FIRST: %w", err) @@ -354,13 +285,17 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error { return fmt.Errorf("failed to read SCRAM challenge: %w", err) } - parts := strings.Fields(strings.TrimSpace(response)) - if len(parts) != 2 || parts[0] != "SCRAM-CHALLENGE" { - return fmt.Errorf("unexpected server response: %s", response) + command, data, err := auth.ParseSCRAMResponse(response) + if err != nil { + return err } - var serverFirst scram.ServerFirst - if err := json.Unmarshal([]byte(parts[1]), &serverFirst); err != nil { + if command != "SCRAM-CHALLENGE" { + return fmt.Errorf("unexpected server response: %s", command) + } + + var serverFirst auth.ServerFirst + if err := json.Unmarshal([]byte(data), &serverFirst); err != nil { return fmt.Errorf("failed to parse server challenge: %w", err) } @@ -370,8 +305,10 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error { return fmt.Errorf("failed to process challenge: %w", err) } - clientFinalJSON, _ := json.Marshal(clientFinal) - msg = fmt.Sprintf("SCRAM-PROOF %s\n", clientFinalJSON) + msg, err = auth.FormatSCRAMRequest("SCRAM-PROOF", clientFinal) + if err != nil { + return err + } if _, err := conn.Write([]byte(msg)); err != nil { return fmt.Errorf("failed to send SCRAM-PROOF: %w", err) @@ -383,19 +320,15 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error { return fmt.Errorf("failed to read SCRAM result: %w", err) } - parts = strings.Fields(strings.TrimSpace(response)) - if len(parts) < 1 { - return fmt.Errorf("empty server response") + command, data, err = auth.ParseSCRAMResponse(response) + if err != nil { + return err } - switch parts[0] { + switch command { case "SCRAM-OK": - if len(parts) != 2 { - return fmt.Errorf("invalid SCRAM-OK response") - } - - var serverFinal scram.ServerFinal - if err := json.Unmarshal([]byte(parts[1]), &serverFinal); err != nil { + var serverFinal auth.ServerFinal + if err := json.Unmarshal([]byte(data), &serverFinal); err != nil { return fmt.Errorf("failed to parse server signature: %w", err) } @@ -406,21 +339,21 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error { t.logger.Info("msg", "SCRAM authentication successful", "component", "tcp_client_sink", - "address", t.config.Address, - "username", t.config.Username, + "address", t.address, + "username", t.config.Auth.Username, "session_id", serverFinal.SessionID) return nil case "SCRAM-FAIL": - reason := "unknown" - if len(parts) > 1 { - reason = strings.Join(parts[1:], " ") + reason := data + if reason == "" { + reason = "unknown" } return fmt.Errorf("authentication failed: %s", reason) default: - return fmt.Errorf("unexpected response: %s", response) + return fmt.Errorf("unexpected response: %s", command) } } @@ -436,7 +369,7 @@ func (t *TCPClientSink) monitorConnection(conn net.Conn) { return case <-ticker.C: // Set read deadline - if err := conn.SetReadDeadline(time.Now().Add(t.config.ReadTimeout)); err != nil { + if err := conn.SetReadDeadline(time.Now().Add(time.Duration(t.config.ReadTimeout) * time.Second)); err != nil { t.logger.Debug("msg", "Failed to set read deadline", "error", err) return } @@ -502,7 +435,7 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error { } // Set write deadline - if err := conn.SetWriteDeadline(time.Now().Add(t.config.WriteTimeout)); err != nil { + if err := conn.SetWriteDeadline(time.Now().Add(time.Duration(t.config.WriteTimeout) * time.Second)); err != nil { return fmt.Errorf("failed to set write deadline: %w", err) } @@ -518,10 +451,4 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error { } return nil -} - -// Not applicable, Clients authenticate to remote servers using Username/Password in config -func (h *TCPClientSink) SetAuth(authCfg *config.AuthConfig) { - // No-op: client sinks don't validate incoming connections - // They authenticate to remote servers using Username/Password fields } \ No newline at end of file diff --git a/src/internal/source/directory.go b/src/internal/source/directory.go index c28bea8..7d7764e 100644 --- a/src/internal/source/directory.go +++ b/src/internal/source/directory.go @@ -21,9 +21,7 @@ import ( // Monitors a directory for log files type DirectorySource struct { - path string - pattern string - checkInterval time.Duration + config *config.DirectorySourceOptions subscribers []chan core.LogEntry watchers map[string]*fileWatcher mu sync.RWMutex @@ -38,34 +36,16 @@ type DirectorySource struct { } // Creates a new directory monitoring source -func NewDirectorySource(options map[string]any, logger *log.Logger) (*DirectorySource, error) { - path, ok := options["path"].(string) - if !ok { - return nil, fmt.Errorf("directory source requires 'path' option") - } - - pattern, _ := options["pattern"].(string) - if pattern == "" { - pattern = "*" - } - - checkInterval := 100 * time.Millisecond - if ms, ok := options["check_interval_ms"].(int64); ok && ms > 0 { - checkInterval = time.Duration(ms) * time.Millisecond - } - - absPath, err := filepath.Abs(path) - if err != nil { - return nil, fmt.Errorf("invalid path %s: %w", path, err) +func NewDirectorySource(opts *config.DirectorySourceOptions, logger *log.Logger) (*DirectorySource, error) { + if opts == nil { + return nil, fmt.Errorf("directory source options cannot be nil") } ds := &DirectorySource{ - path: absPath, - pattern: pattern, - checkInterval: checkInterval, - watchers: make(map[string]*fileWatcher), - startTime: time.Now(), - logger: logger, + config: opts, + watchers: make(map[string]*fileWatcher), + startTime: time.Now(), + logger: logger, } ds.lastEntryTime.Store(time.Time{}) @@ -88,9 +68,9 @@ func (ds *DirectorySource) Start() error { ds.logger.Info("msg", "Directory source started", "component", "directory_source", - "path", ds.path, - "pattern", ds.pattern, - "check_interval_ms", ds.checkInterval.Milliseconds()) + "path", ds.config.Path, + "pattern", ds.config.Pattern, + "check_interval_ms", ds.config.CheckIntervalMS) return nil } @@ -111,7 +91,7 @@ func (ds *DirectorySource) Stop() { ds.logger.Info("msg", "Directory source stopped", "component", "directory_source", - "path", ds.path) + "path", ds.config.Path) } func (ds *DirectorySource) GetStats() SourceStats { @@ -171,7 +151,7 @@ func (ds *DirectorySource) monitorLoop() { ds.checkTargets() - ticker := time.NewTicker(ds.checkInterval) + ticker := time.NewTicker(time.Duration(ds.config.CheckIntervalMS) * time.Millisecond) defer ticker.Stop() for { @@ -189,8 +169,8 @@ func (ds *DirectorySource) checkTargets() { if err != nil { ds.logger.Warn("msg", "Failed to scan directory", "component", "directory_source", - "path", ds.path, - "pattern", ds.pattern, + "path", ds.config.Path, + "pattern", ds.config.Pattern, "error", err) return } @@ -203,13 +183,13 @@ func (ds *DirectorySource) checkTargets() { } func (ds *DirectorySource) scanDirectory() ([]string, error) { - entries, err := os.ReadDir(ds.path) + entries, err := os.ReadDir(ds.config.Path) if err != nil { return nil, err } // Convert glob pattern to regex - regexPattern := globToRegex(ds.pattern) + regexPattern := globToRegex(ds.config.Pattern) re, err := regexp.Compile(regexPattern) if err != nil { return nil, fmt.Errorf("invalid pattern regex: %w", err) @@ -223,7 +203,7 @@ func (ds *DirectorySource) scanDirectory() ([]string, error) { name := entry.Name() if re.MatchString(name) { - files = append(files, filepath.Join(ds.path, name)) + files = append(files, filepath.Join(ds.config.Path, name)) } } @@ -287,8 +267,4 @@ func globToRegex(glob string) string { regex = strings.ReplaceAll(regex, `\*`, `.*`) regex = strings.ReplaceAll(regex, `\?`, `.`) return "^" + regex + "$" -} - -func (ds *DirectorySource) SetAuth(auth *config.AuthConfig) { - // Authentication does not apply to directory source } \ No newline at end of file diff --git a/src/internal/source/http.go b/src/internal/source/http.go index b1275af..076d064 100644 --- a/src/internal/source/http.go +++ b/src/internal/source/http.go @@ -14,7 +14,6 @@ import ( "logwisp/src/internal/core" "logwisp/src/internal/limit" "logwisp/src/internal/tls" - "logwisp/src/internal/version" "github.com/lixenwraith/log" "github.com/valyala/fasthttp" @@ -22,12 +21,7 @@ import ( // Receives log entries via HTTP POST requests type HTTPSource struct { - // Config - host string - port int64 - path string - bufferSize int64 - maxRequestBodySize int64 + config *config.HTTPSourceOptions // Application server *fasthttp.Server @@ -42,11 +36,9 @@ type HTTPSource struct { // Security authenticator *auth.Authenticator - authConfig *config.AuthConfig authFailures atomic.Uint64 authSuccesses atomic.Uint64 tlsManager *tls.Manager - tlsConfig *config.TLSConfig // Statistics totalEntries atomic.Uint64 @@ -57,108 +49,52 @@ type HTTPSource struct { } // Creates a new HTTP server source -func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, error) { - host := "0.0.0.0" - if h, ok := options["host"].(string); ok && h != "" { - host = h - } - - port, ok := options["port"].(int64) - if !ok || port < 1 || port > 65535 { - return nil, fmt.Errorf("http source requires valid 'port' option") - } - - ingestPath := "/ingest" - if path, ok := options["path"].(string); ok && path != "" { - ingestPath = path - } - - bufferSize := int64(1000) - if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { - bufferSize = bufSize - } - - maxRequestBodySize := int64(10 * 1024 * 1024) // fasthttp default 10MB - if maxBodySize, ok := options["max_body_size"].(int64); ok && maxBodySize > 0 && maxBodySize < maxRequestBodySize { - maxRequestBodySize = maxBodySize +func NewHTTPSource(opts *config.HTTPSourceOptions, logger *log.Logger) (*HTTPSource, error) { + // Validation done in config package + if opts == nil { + return nil, fmt.Errorf("HTTP source options cannot be nil") } h := &HTTPSource{ - host: host, - port: port, - path: ingestPath, - bufferSize: bufferSize, - maxRequestBodySize: maxRequestBodySize, - done: make(chan struct{}), - startTime: time.Now(), - logger: logger, + config: opts, + done: make(chan struct{}), + startTime: time.Now(), + logger: logger, } h.lastEntryTime.Store(time.Time{}) // Initialize net limiter if configured - if nl, ok := options["net_limit"].(map[string]any); ok { - if enabled, _ := nl["enabled"].(bool); enabled { - cfg := config.NetLimitConfig{ - Enabled: true, - } - - if rps, ok := nl["requests_per_second"].(float64); ok { - cfg.RequestsPerSecond = rps - } - if burst, ok := nl["burst_size"].(int64); ok { - cfg.BurstSize = burst - } - if respCode, ok := nl["response_code"].(int64); ok { - cfg.ResponseCode = respCode - } - if msg, ok := nl["response_message"].(string); ok { - cfg.ResponseMessage = msg - } - if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok { - cfg.MaxConnectionsPerIP = maxPerIP - } - if maxTotal, ok := nl["max_connections_total"].(int64); ok { - cfg.MaxConnectionsTotal = maxTotal - } - - h.netLimiter = limit.NewNetLimiter(cfg, logger) - } + if opts.NetLimit != nil && (opts.NetLimit.Enabled || + len(opts.NetLimit.IPWhitelist) > 0 || + len(opts.NetLimit.IPBlacklist) > 0) { + h.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger) } - // Extract TLS config after existing options - if tc, ok := options["tls"].(map[string]any); ok { - h.tlsConfig = &config.TLSConfig{} - h.tlsConfig.Enabled, _ = tc["enabled"].(bool) - if certFile, ok := tc["cert_file"].(string); ok { - h.tlsConfig.CertFile = certFile + // Initialize TLS manager if configured + if opts.TLS != nil && opts.TLS.Enabled { + tlsManager, err := tls.NewManager(opts.TLS, logger) + if err != nil { + return nil, fmt.Errorf("failed to create TLS manager: %w", err) } - if keyFile, ok := tc["key_file"].(string); ok { - h.tlsConfig.KeyFile = keyFile - } - h.tlsConfig.ClientAuth, _ = tc["client_auth"].(bool) - if caFile, ok := tc["client_ca_file"].(string); ok { - h.tlsConfig.ClientCAFile = caFile - } - h.tlsConfig.VerifyClientCert, _ = tc["verify_client_cert"].(bool) - h.tlsConfig.InsecureSkipVerify, _ = tc["insecure_skip_verify"].(bool) - if minVer, ok := tc["min_version"].(string); ok { - h.tlsConfig.MinVersion = minVer - } - if maxVer, ok := tc["max_version"].(string); ok { - h.tlsConfig.MaxVersion = maxVer - } - if ciphers, ok := tc["cipher_suites"].(string); ok { - h.tlsConfig.CipherSuites = ciphers + h.tlsManager = tlsManager + } + + // Initialize authenticator if configured + if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" { + // Verify TLS is enabled for auth (validation should have caught this) + if h.tlsManager == nil { + return nil, fmt.Errorf("authentication requires TLS to be enabled") } - // Create TLS manager - if h.tlsConfig.Enabled { - tlsManager, err := tls.NewManager(h.tlsConfig, logger) - if err != nil { - return nil, fmt.Errorf("failed to create TLS manager: %w", err) - } - h.tlsManager = tlsManager + authenticator, err := auth.NewAuthenticator(opts.Auth, logger) + if err != nil { + return nil, fmt.Errorf("failed to create authenticator: %w", err) } + h.authenticator = authenticator + + logger.Info("msg", "Authentication configured for HTTP source", + "component", "http_source", + "auth_type", opts.Auth.Type) } return h, nil @@ -168,23 +104,24 @@ func (h *HTTPSource) Subscribe() <-chan core.LogEntry { h.mu.Lock() defer h.mu.Unlock() - ch := make(chan core.LogEntry, h.bufferSize) + ch := make(chan core.LogEntry, h.config.BufferSize) h.subscribers = append(h.subscribers, ch) return ch } func (h *HTTPSource) Start() error { h.server = &fasthttp.Server{ - Name: fmt.Sprintf("LogWisp/%s", version.Short()), Handler: h.requestHandler, DisableKeepalive: false, StreamRequestBody: true, CloseOnShutdown: true, - MaxRequestBodySize: int(h.maxRequestBodySize), + ReadTimeout: time.Duration(h.config.ReadTimeout) * time.Millisecond, + WriteTimeout: time.Duration(h.config.WriteTimeout) * time.Millisecond, + MaxRequestBodySize: int(h.config.MaxRequestBodySize), } // Use configured host and port - addr := fmt.Sprintf("%s:%d", h.host, h.port) + addr := fmt.Sprintf("%s:%d", h.config.Host, h.config.Port) // Start server in background h.wg.Add(1) @@ -193,35 +130,35 @@ func (h *HTTPSource) Start() error { defer h.wg.Done() h.logger.Info("msg", "HTTP source server starting", "component", "http_source", - "port", h.port, - "path", h.path, - "tls_enabled", h.tlsManager != nil) + "port", h.config.Port, + "ingest_path", h.config.IngestPath, + "tls_enabled", h.tlsManager != nil, + "auth_enabled", h.authenticator != nil) var err error - // Check for TLS manager and start the appropriate server type if h.tlsManager != nil { + // HTTPS server h.server.TLSConfig = h.tlsManager.GetHTTPConfig() - err = h.server.ListenAndServeTLS(addr, h.tlsConfig.CertFile, h.tlsConfig.KeyFile) + err = h.server.ListenAndServeTLS(addr, h.config.TLS.CertFile, h.config.TLS.KeyFile) } else { + // HTTP server err = h.server.ListenAndServe(addr) } if err != nil { h.logger.Error("msg", "HTTP source server failed", "component", "http_source", - "port", h.port, + "port", h.config.Port, "error", err) errChan <- err } }() - // Robust server startup check with timeout + // Wait briefly for server startup select { case err := <-errChan: - // Server failed to start return fmt.Errorf("HTTP server failed to start: %w", err) case <-time.After(250 * time.Millisecond): - // Server started successfully (no immediate error) return nil } } @@ -263,6 +200,21 @@ func (h *HTTPSource) GetStats() SourceStats { netLimitStats = h.netLimiter.GetStats() } + var authStats map[string]any + if h.authenticator != nil { + authStats = map[string]any{ + "enabled": true, + "type": h.config.Auth.Type, + "failures": h.authFailures.Load(), + "successes": h.authSuccesses.Load(), + } + } + + var tlsStats map[string]any + if h.tlsManager != nil { + tlsStats = h.tlsManager.GetStats() + } + return SourceStats{ Type: "http", TotalEntries: h.totalEntries.Load(), @@ -270,10 +222,13 @@ func (h *HTTPSource) GetStats() SourceStats { StartTime: h.startTime, LastEntryTime: lastEntry, Details: map[string]any{ - "port": h.port, - "path": h.path, + "host": h.config.Host, + "port": h.config.Port, + "path": h.config.IngestPath, "invalid_entries": h.invalidEntries.Load(), "net_limit": netLimitStats, + "auth": authStats, + "tls": tlsStats, }, } } @@ -307,17 +262,10 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) { } } - // 2.5. Check TLS requirement for auth (early reject) - if h.authenticator != nil && h.authConfig.Type != "none" { - // Check if connection is TLS + // 3. Check TLS requirement for auth + if h.authenticator != nil { isTLS := ctx.IsTLS() || h.tlsManager != nil - if !isTLS { - h.logger.Error("msg", "Authentication configured but connection is not TLS", - "component", "http_source", - "remote_addr", remoteAddr, - "auth_type", h.authConfig.Type) - ctx.SetStatusCode(fasthttp.StatusForbidden) ctx.SetContentType("application/json") json.NewEncoder(ctx).Encode(map[string]string{ @@ -326,21 +274,45 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) { }) return } + + // Authenticate request + authHeader := string(ctx.Request.Header.Peek("Authorization")) + session, err := h.authenticator.AuthenticateHTTP(authHeader, remoteAddr) + if err != nil { + h.authFailures.Add(1) + h.logger.Warn("msg", "Authentication failed", + "component", "http_source", + "remote_addr", remoteAddr, + "error", err) + + ctx.SetStatusCode(fasthttp.StatusUnauthorized) + if h.config.Auth.Type == "basic" && h.config.Auth.Basic != nil && h.config.Auth.Basic.Realm != "" { + ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, h.config.Auth.Basic.Realm)) + } + ctx.SetContentType("application/json") + json.NewEncoder(ctx).Encode(map[string]string{ + "error": "Authentication failed", + }) + return + } + + h.authSuccesses.Add(1) + _ = session // Session can be used for audit logging } - // 3. Path check (only process ingest path) + // 4. Path check path := string(ctx.Path()) - if path != h.path { + if path != h.config.IngestPath { ctx.SetStatusCode(fasthttp.StatusNotFound) ctx.SetContentType("application/json") json.NewEncoder(ctx).Encode(map[string]string{ "error": "Not Found", - "hint": fmt.Sprintf("POST logs to %s", h.path), + "hint": fmt.Sprintf("POST logs to %s", h.config.IngestPath), }) return } - // 4. Method check (only accept POST) + // 5. Method check (only accepts POST) if string(ctx.Method()) != "POST" { ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed) ctx.SetContentType("application/json") @@ -352,43 +324,10 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) { return } - // 5. Authentication check (if configured) - if h.authenticator != nil { - authHeader := string(ctx.Request.Header.Peek("Authorization")) - session, err := h.authenticator.AuthenticateHTTP(authHeader, remoteAddr) - if err != nil { - h.authFailures.Add(1) - h.logger.Warn("msg", "Authentication failed", - "component", "http_source", - "remote_addr", remoteAddr, - "error", err) - - ctx.SetStatusCode(fasthttp.StatusUnauthorized) - if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil { - realm := h.authConfig.BasicAuth.Realm - if realm == "" { - realm = "Restricted" - } - ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm)) - } else if h.authConfig.Type == "bearer" { - ctx.Response.Header.Set("WWW-Authenticate", "Bearer") - } - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(map[string]string{ - "error": "Unauthorized", - }) - return - } - h.authSuccesses.Add(1) - h.logger.Debug("msg", "Request authenticated", - "component", "http_source", - "remote_addr", remoteAddr, - "username", session.Username) - } - - // 6. Process request body + // 6. Process log entry body := ctx.PostBody() if len(body) == 0 { + h.invalidEntries.Add(1) ctx.SetStatusCode(fasthttp.StatusBadRequest) ctx.SetContentType("application/json") json.NewEncoder(ctx).Encode(map[string]string{ @@ -397,32 +336,34 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) { return } - // 7. Parse log entries - entries, err := h.parseEntries(body) - if err != nil { + var entry core.LogEntry + if err := json.Unmarshal(body, &entry); err != nil { h.invalidEntries.Add(1) ctx.SetStatusCode(fasthttp.StatusBadRequest) ctx.SetContentType("application/json") json.NewEncoder(ctx).Encode(map[string]string{ - "error": fmt.Sprintf("Invalid log format: %v", err), + "error": fmt.Sprintf("Invalid JSON: %v", err), }) return } - // 8. Publish entries to subscribers - accepted := 0 - for _, entry := range entries { - if h.publish(entry) { - accepted++ - } + // Set defaults + if entry.Time.IsZero() { + entry.Time = time.Now() } + if entry.Source == "" { + entry.Source = "http" + } + entry.RawSize = int64(len(body)) - // 9. Return success response + // Publish to subscribers + h.publish(entry) + + // Success response ctx.SetStatusCode(fasthttp.StatusAccepted) ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(map[string]any{ - "accepted": accepted, - "total": len(entries), + json.NewEncoder(ctx).Encode(map[string]string{ + "status": "accepted", }) } @@ -501,29 +442,22 @@ func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) { return entries, nil } -func (h *HTTPSource) publish(entry core.LogEntry) bool { +func (h *HTTPSource) publish(entry core.LogEntry) { h.mu.RLock() defer h.mu.RUnlock() h.totalEntries.Add(1) h.lastEntryTime.Store(entry.Time) - dropped := false for _, ch := range h.subscribers { select { case ch <- entry: default: - dropped = true h.droppedEntries.Add(1) + h.logger.Debug("msg", "Dropped log entry - subscriber buffer full", + "component", "http_source") } } - - if dropped { - h.logger.Debug("msg", "Dropped log entry - subscriber buffer full", - "component", "http_source") - } - - return true } // Splits bytes into lines, handling both \n and \r\n @@ -549,25 +483,4 @@ func splitLines(data []byte) [][]byte { } return lines -} - -// Configure HTTP source auth -func (h *HTTPSource) SetAuth(authCfg *config.AuthConfig) { - if authCfg == nil || authCfg.Type == "none" { - return - } - - h.authConfig = authCfg - authenticator, err := auth.NewAuthenticator(authCfg, h.logger) - if err != nil { - h.logger.Error("msg", "Failed to initialize authenticator for HTTP source", - "component", "http_source", - "error", err) - return - } - h.authenticator = authenticator - - h.logger.Info("msg", "Authentication configured for HTTP source", - "component", "http_source", - "auth_type", authCfg.Type) } \ No newline at end of file diff --git a/src/internal/source/source.go b/src/internal/source/source.go index 5571d97..4fe5d64 100644 --- a/src/internal/source/source.go +++ b/src/internal/source/source.go @@ -4,7 +4,6 @@ package source import ( "time" - "logwisp/src/internal/config" "logwisp/src/internal/core" ) @@ -21,9 +20,6 @@ type Source interface { // Returns source statistics GetStats() SourceStats - - // Configure authentication - SetAuth(auth *config.AuthConfig) } // Contains statistics about a source diff --git a/src/internal/source/stdin.go b/src/internal/source/stdin.go index 50fd5cc..826b5cc 100644 --- a/src/internal/source/stdin.go +++ b/src/internal/source/stdin.go @@ -15,24 +15,25 @@ import ( // Reads log entries from standard input type StdinSource struct { + config *config.StdinSourceOptions subscribers []chan core.LogEntry done chan struct{} totalEntries atomic.Uint64 droppedEntries atomic.Uint64 - bufferSize int64 startTime time.Time lastEntryTime atomic.Value // time.Time logger *log.Logger } -func NewStdinSource(options map[string]any, logger *log.Logger) (*StdinSource, error) { - bufferSize := int64(1000) // default - if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { - bufferSize = bufSize +func NewStdinSource(opts *config.StdinSourceOptions, logger *log.Logger) (*StdinSource, error) { + if opts == nil { + opts = &config.StdinSourceOptions{ + BufferSize: 1000, // Default + } } source := &StdinSource{ - bufferSize: bufferSize, + config: opts, subscribers: make([]chan core.LogEntry, 0), done: make(chan struct{}), logger: logger, @@ -43,7 +44,7 @@ func NewStdinSource(options map[string]any, logger *log.Logger) (*StdinSource, e } func (s *StdinSource) Subscribe() <-chan core.LogEntry { - ch := make(chan core.LogEntry, s.bufferSize) + ch := make(chan core.LogEntry, s.config.BufferSize) s.subscribers = append(s.subscribers, ch) return ch } @@ -119,8 +120,4 @@ func (s *StdinSource) publish(entry core.LogEntry) { "component", "stdin_source") } } -} - -func (s *StdinSource) SetAuth(auth *config.AuthConfig) { - // Authentication does not apply to stdin source } \ No newline at end of file diff --git a/src/internal/source/tcp.go b/src/internal/source/tcp.go index 31f4fa4..bfcd633 100644 --- a/src/internal/source/tcp.go +++ b/src/internal/source/tcp.go @@ -4,7 +4,6 @@ package source import ( "bytes" "context" - "encoding/base64" "encoding/json" "fmt" "net" @@ -17,7 +16,6 @@ import ( "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/limit" - "logwisp/src/internal/scram" "github.com/lixenwraith/log" "github.com/lixenwraith/log/compat" @@ -31,19 +29,19 @@ const ( // Receives log entries via TCP connections type TCPSource struct { - host string - port int64 - bufferSize int64 - server *tcpSourceServer - subscribers []chan core.LogEntry - mu sync.RWMutex - done chan struct{} - engine *gnet.Engine - engineMu sync.Mutex - wg sync.WaitGroup - netLimiter *limit.NetLimiter - logger *log.Logger - scramManager *scram.ScramManager + config *config.TCPSourceOptions + server *tcpSourceServer + subscribers []chan core.LogEntry + mu sync.RWMutex + done chan struct{} + engine *gnet.Engine + engineMu sync.Mutex + wg sync.WaitGroup + authenticator *auth.Authenticator + netLimiter *limit.NetLimiter + logger *log.Logger + scramManager *auth.ScramManager + scramProtocolHandler *auth.ScramProtocolHandler // Statistics totalEntries atomic.Uint64 @@ -57,60 +55,36 @@ type TCPSource struct { } // Creates a new TCP server source -func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error) { - host := "0.0.0.0" - if h, ok := options["host"].(string); ok && h != "" { - host = h - } - - port, ok := options["port"].(int64) - if !ok || port < 1 || port > 65535 { - return nil, fmt.Errorf("tcp source requires valid 'port' option") - } - - bufferSize := int64(1000) - if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { - bufferSize = bufSize +func NewTCPSource(opts *config.TCPSourceOptions, logger *log.Logger) (*TCPSource, error) { + // Accept typed config - validation done in config package + if opts == nil { + return nil, fmt.Errorf("TCP source options cannot be nil") } t := &TCPSource{ - host: host, - port: port, - bufferSize: bufferSize, - done: make(chan struct{}), - startTime: time.Now(), - logger: logger, + config: opts, + done: make(chan struct{}), + startTime: time.Now(), + logger: logger, } t.lastEntryTime.Store(time.Time{}) // Initialize net limiter if configured - if nl, ok := options["net_limit"].(map[string]any); ok { - if enabled, _ := nl["enabled"].(bool); enabled { - cfg := config.NetLimitConfig{ - Enabled: true, - } + if opts.NetLimit != nil && (opts.NetLimit.Enabled || + len(opts.NetLimit.IPWhitelist) > 0 || + len(opts.NetLimit.IPBlacklist) > 0) { + t.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger) + } - if rps, ok := nl["requests_per_second"].(float64); ok { - cfg.RequestsPerSecond = rps - } - if burst, ok := nl["burst_size"].(int64); ok { - cfg.BurstSize = burst - } - if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok { - cfg.MaxConnectionsPerIP = maxPerIP - } - if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok { - cfg.MaxConnectionsPerUser = maxPerUser - } - if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok { - cfg.MaxConnectionsPerToken = maxPerToken - } - if maxTotal, ok := nl["max_connections_total"].(int64); ok { - cfg.MaxConnectionsTotal = maxTotal - } - - t.netLimiter = limit.NewNetLimiter(cfg, logger) - } + // Initialize SCRAM + if opts.Auth != nil && opts.Auth.Type == "scram" && opts.Auth.Scram != nil { + t.scramManager = auth.NewScramManager(opts.Auth.Scram) + t.scramProtocolHandler = auth.NewScramProtocolHandler(t.scramManager, logger) + logger.Info("msg", "SCRAM authentication configured for TCP source", + "component", "tcp_source", + "users", len(opts.Auth.Scram.Users)) + } else if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" { + return nil, fmt.Errorf("TCP source only supports 'none' or 'scram' auth") } return t, nil @@ -120,7 +94,7 @@ func (t *TCPSource) Subscribe() <-chan core.LogEntry { t.mu.Lock() defer t.mu.Unlock() - ch := make(chan core.LogEntry, t.bufferSize) + ch := make(chan core.LogEntry, t.config.BufferSize) t.subscribers = append(t.subscribers, ch) return ch } @@ -132,7 +106,7 @@ func (t *TCPSource) Start() error { } // Use configured host and port - addr := fmt.Sprintf("tcp://%s:%d", t.host, t.port) + addr := fmt.Sprintf("tcp://%s:%d", t.config.Host, t.config.Port) // Create a gnet adapter using the existing logger instance gnetLogger := compat.NewGnetAdapter(t.logger) @@ -144,17 +118,19 @@ func (t *TCPSource) Start() error { defer t.wg.Done() t.logger.Info("msg", "TCP source server starting", "component", "tcp_source", - "port", t.port) + "port", t.config.Port, + "auth_enabled", t.authenticator != nil) err := gnet.Run(t.server, addr, gnet.WithLogger(gnetLogger), gnet.WithMulticore(true), gnet.WithReusePort(true), + gnet.WithTCPKeepAlive(time.Duration(t.config.KeepAlivePeriod)*time.Millisecond), ) if err != nil { t.logger.Error("msg", "TCP source server failed", "component", "tcp_source", - "port", t.port, + "port", t.config.Port, "error", err) } errChan <- err @@ -169,7 +145,7 @@ func (t *TCPSource) Start() error { return err case <-time.After(100 * time.Millisecond): // Server started successfully - t.logger.Info("msg", "TCP server started", "port", t.port) + t.logger.Info("msg", "TCP server started", "port", t.config.Port) return nil } } @@ -214,6 +190,16 @@ func (t *TCPSource) GetStats() SourceStats { netLimitStats = t.netLimiter.GetStats() } + var authStats map[string]any + if t.authenticator != nil { + authStats = map[string]any{ + "enabled": true, + "type": t.config.Auth.Type, + "failures": t.authFailures.Load(), + "successes": t.authSuccesses.Load(), + } + } + return SourceStats{ Type: "tcp", TotalEntries: t.totalEntries.Load(), @@ -221,49 +207,41 @@ func (t *TCPSource) GetStats() SourceStats { StartTime: t.startTime, LastEntryTime: lastEntry, Details: map[string]any{ - "port": t.port, + "port": t.config.Port, "active_connections": t.activeConns.Load(), "invalid_entries": t.invalidEntries.Load(), "net_limit": netLimitStats, + "auth": authStats, }, } } -func (t *TCPSource) publish(entry core.LogEntry) bool { +func (t *TCPSource) publish(entry core.LogEntry) { t.mu.RLock() defer t.mu.RUnlock() t.totalEntries.Add(1) t.lastEntryTime.Store(entry.Time) - dropped := false for _, ch := range t.subscribers { select { case ch <- entry: default: - dropped = true t.droppedEntries.Add(1) + t.logger.Debug("msg", "Dropped log entry - subscriber buffer full", + "component", "tcp_source") } } - - if dropped { - t.logger.Debug("msg", "Dropped log entry - subscriber buffer full", - "component", "tcp_source") - } - - return true } // Represents a connected TCP client type tcpClient struct { - conn gnet.Conn - buffer *bytes.Buffer - authenticated bool - authTimeout time.Time - session *auth.Session - maxBufferSeen int - cumulativeEncrypted int64 - scramState *scram.HandshakeState + conn gnet.Conn + buffer *bytes.Buffer + authenticated bool + authTimeout time.Time + session *auth.Session + maxBufferSeen int } // Handles gnet events @@ -282,7 +260,7 @@ func (s *tcpSourceServer) OnBoot(eng gnet.Engine) gnet.Action { s.source.logger.Debug("msg", "TCP source server booted", "component", "tcp_source", - "port", s.source.port) + "port", s.source.config.Port) return gnet.None } @@ -303,6 +281,16 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { return nil, gnet.Close } + // Check if connection is allowed + ip := tcpAddr.IP + if ip.To4() == nil { + // Reject IPv6 + s.source.logger.Warn("msg", "IPv6 connection rejected", + "component", "tcp_source", + "remote_addr", remoteAddr) + return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close + } + if !s.source.netLimiter.CheckTCP(tcpAddr) { s.source.logger.Warn("msg", "TCP connection net limited", "component", "tcp_source", @@ -311,49 +299,66 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { } // Track connection - s.source.netLimiter.AddConnection(remoteAddr) + // s.source.netLimiter.AddConnection(remoteAddr) + if !s.source.netLimiter.TrackConnection(ip.String(), "", "") { + s.source.logger.Warn("msg", "TCP connection limit exceeded", + "component", "tcp_source", + "remote_addr", remoteAddr) + return nil, gnet.Close + } } // Create client state client := &tcpClient{ conn: c, buffer: bytes.NewBuffer(nil), - authTimeout: time.Now().Add(30 * time.Second), - authenticated: s.source.scramManager == nil, + authenticated: s.source.authenticator == nil, // No auth = auto authenticated + } + + if s.source.authenticator != nil { + // Set auth timeout + client.authTimeout = time.Now().Add(10 * time.Second) + + // Send auth challenge for SCRAM + if s.source.config.Auth.Type == "scram" { + out = []byte("AUTH_REQUIRED\n") + } } s.mu.Lock() s.clients[c] = client s.mu.Unlock() - newCount := s.source.activeConns.Add(1) + s.source.activeConns.Add(1) s.source.logger.Debug("msg", "TCP connection opened", "component", "tcp_source", "remote_addr", remoteAddr, - "active_connections", newCount, - "requires_auth", s.source.scramManager != nil) + "auth_enabled", s.source.authenticator != nil) - return nil, gnet.None + return out, gnet.None } func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action { remoteAddr := c.RemoteAddr().String() + // Untrack connection + if s.source.netLimiter != nil { + if tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr); err == nil { + s.source.netLimiter.ReleaseConnection(tcpAddr.IP.String(), "", "") + // s.source.netLimiter.RemoveConnection(remoteAddr) + } + } + // Remove client state s.mu.Lock() delete(s.clients, c) s.mu.Unlock() - // Remove connection tracking - if s.source.netLimiter != nil { - s.source.netLimiter.RemoveConnection(remoteAddr) - } - - newCount := s.source.activeConns.Add(-1) + newConnectionCount := s.source.activeConns.Add(-1) s.source.logger.Debug("msg", "TCP connection closed", "component", "tcp_source", "remote_addr", remoteAddr, - "active_connections", newCount, + "active_connections", newConnectionCount, "error", err) return gnet.None } @@ -383,6 +388,8 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action { s.source.logger.Warn("msg", "Authentication timeout", "component", "tcp_source", "remote_addr", c.RemoteAddr().String()) + s.source.authFailures.Add(1) + c.AsyncWrite([]byte("AUTH_TIMEOUT\n"), nil) return gnet.Close } @@ -392,7 +399,12 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action { client.buffer.Write(data) - // Look for complete line + // Use centralized SCRAM protocol handler + if s.source.scramProtocolHandler == nil { + s.source.scramProtocolHandler = auth.NewScramProtocolHandler(s.source.scramManager, s.source.logger) + } + + // Look for complete auth line for { idx := bytes.IndexByte(client.buffer.Bytes(), '\n') if idx < 0 { @@ -402,85 +414,44 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action { line := client.buffer.Bytes()[:idx] client.buffer.Next(idx + 1) - // Parse SCRAM messages - parts := strings.Fields(string(line)) - if len(parts) < 2 { - c.AsyncWrite([]byte("SCRAM-FAIL Invalid message format\n"), nil) - return gnet.Close + // Process auth message through handler + authenticated, session, err := s.source.scramProtocolHandler.HandleAuthMessage(line, c) + if err != nil { + s.source.logger.Warn("msg", "SCRAM authentication failed", + "component", "tcp_source", + "remote_addr", c.RemoteAddr().String(), + "error", err) + + if strings.Contains(err.Error(), "unknown command") { + return gnet.Close + } + // Continue for other errors (might be multi-step auth) } - switch parts[0] { - case "SCRAM-FIRST": - // Parse ClientFirst JSON - var clientFirst scram.ClientFirst - if err := json.Unmarshal([]byte(parts[1]), &clientFirst); err != nil { - c.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil) - return gnet.Close - } - - // Process with SCRAM server - serverFirst, err := s.source.scramManager.HandleClientFirst(&clientFirst) - if err != nil { - // Still send challenge to prevent user enumeration - response, _ := json.Marshal(serverFirst) - c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil) - return gnet.Close - } - - // Send ServerFirst challenge - response, _ := json.Marshal(serverFirst) - c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil) - - case "SCRAM-PROOF": - // Parse ClientFinal JSON - var clientFinal scram.ClientFinal - if err := json.Unmarshal([]byte(parts[1]), &clientFinal); err != nil { - c.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil) - return gnet.Close - } - - // Verify proof - serverFinal, err := s.source.scramManager.HandleClientFinal(&clientFinal) - if err != nil { - s.source.logger.Warn("msg", "SCRAM authentication failed", - "component", "tcp_source", - "remote_addr", c.RemoteAddr().String(), - "error", err) - c.AsyncWrite([]byte("SCRAM-FAIL Authentication failed\n"), nil) - return gnet.Close - } - + if authenticated && session != nil { // Authentication successful s.mu.Lock() client.authenticated = true - client.session = &auth.Session{ - ID: serverFinal.SessionID, - Method: "scram-sha-256", - RemoteAddr: c.RemoteAddr().String(), - CreatedAt: time.Now(), - } + client.session = session s.mu.Unlock() - // Send ServerFinal with signature - response, _ := json.Marshal(serverFinal) - c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-OK %s\n", response)), nil) - s.source.logger.Info("msg", "Client authenticated via SCRAM", "component", "tcp_source", "remote_addr", c.RemoteAddr().String(), - "session_id", serverFinal.SessionID) + "session_id", session.ID) // Clear auth buffer client.buffer.Reset() - - default: - c.AsyncWrite([]byte("SCRAM-FAIL Unknown command\n"), nil) - return gnet.Close + break } } return gnet.None } + return s.processLogData(c, client, data) +} + +func (s *tcpSourceServer) processLogData(c gnet.Conn, client *tcpClient, data []byte) gnet.Action { // Check if appending the new data would exceed the client buffer limit. if client.buffer.Len()+len(data) > maxClientBufferSize { s.source.logger.Warn("msg", "Client buffer limit exceeded, closing connection.", @@ -571,48 +542,4 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action { } return gnet.None -} - -func (t *TCPSource) InitSCRAMManager(authCfg *config.AuthConfig) { - if authCfg == nil || authCfg.Type != "scram" || authCfg.ScramAuth == nil { - return - } - - t.scramManager = scram.NewScramManager() - - // Load users from SCRAM config - for _, user := range authCfg.ScramAuth.Users { - storedKey, _ := base64.StdEncoding.DecodeString(user.StoredKey) - serverKey, _ := base64.StdEncoding.DecodeString(user.ServerKey) - salt, _ := base64.StdEncoding.DecodeString(user.Salt) - - cred := &scram.Credential{ - Username: user.Username, - StoredKey: storedKey, - ServerKey: serverKey, - Salt: salt, - ArgonTime: user.ArgonTime, - ArgonMemory: user.ArgonMemory, - ArgonThreads: user.ArgonThreads, - } - t.scramManager.AddCredential(cred) - } - - t.logger.Info("msg", "SCRAM authentication configured", - "component", "tcp_source", - "users", len(authCfg.ScramAuth.Users)) -} - -// Configure TCP source auth -func (t *TCPSource) SetAuth(authCfg *config.AuthConfig) { - if authCfg == nil || authCfg.Type == "none" { - return - } - - // Initialize SCRAM manager - if authCfg.Type == "scram" { - t.InitSCRAMManager(authCfg) - t.logger.Info("msg", "SCRAM authentication configured for TCP source", - "component", "tcp_source") - } } \ No newline at end of file