diff --git a/.gitignore b/.gitignore index 11e3e03..4facf1e 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ script build *.log *.toml +build.sh diff --git a/src/cmd/logwisp/bootstrap.go b/src/cmd/logwisp/bootstrap.go index 5404f55..0303467 100644 --- a/src/cmd/logwisp/bootstrap.go +++ b/src/cmd/logwisp/bootstrap.go @@ -13,7 +13,7 @@ import ( "github.com/lixenwraith/log" ) -// Creates and initializes the log transport service +// bootstrapService creates and initializes the main log transport service and its pipelines. func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service, error) { // Create service with logger dependency injection svc := service.NewService(ctx, logger) @@ -45,7 +45,7 @@ func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service return svc, nil } -// Sets up the logger based on configuration +// initializeLogger sets up the global logger based on the application's configuration. func initializeLogger(cfg *config.Config) error { logger = log.NewLogger() logCfg := log.DefaultConfig() @@ -103,7 +103,7 @@ func initializeLogger(cfg *config.Config) error { return logger.ApplyConfig(logCfg) } -// Sets up file-based logging parameters +// configureFileLogging sets up file-based logging parameters from the configuration. func configureFileLogging(logCfg *log.Config, cfg *config.Config) { if cfg.Logging.File != nil { logCfg.Directory = cfg.Logging.File.Directory @@ -116,6 +116,7 @@ func configureFileLogging(logCfg *log.Config, cfg *config.Config) { } } +// parseLogLevel converts a string log level to its corresponding integer value. func parseLogLevel(level string) (int64, error) { switch strings.ToLower(level) { case "debug": diff --git a/src/cmd/logwisp/commands/auth.go b/src/cmd/logwisp/commands/auth.go deleted file mode 100644 index d7b9c05..0000000 --- a/src/cmd/logwisp/commands/auth.go +++ /dev/null @@ -1,355 +0,0 @@ -// FILE: src/cmd/logwisp/commands/auth.go -package commands - -import ( - "crypto/rand" - "encoding/base64" - "flag" - "fmt" - "io" - "os" - "strings" - "syscall" - - "logwisp/src/internal/auth" - "logwisp/src/internal/core" - - "golang.org/x/term" -) - -type AuthCommand struct { - output io.Writer - errOut io.Writer -} - -func NewAuthCommand() *AuthCommand { - return &AuthCommand{ - output: os.Stdout, - errOut: os.Stderr, - } -} - -func (ac *AuthCommand) Execute(args []string) error { - cmd := flag.NewFlagSet("auth", flag.ContinueOnError) - cmd.SetOutput(ac.errOut) - - var ( - // User credentials - username = cmd.String("u", "", "Username") - usernameLong = cmd.String("user", "", "Username") - password = cmd.String("p", "", "Password (will prompt if not provided)") - passwordLong = cmd.String("password", "", "Password (will prompt if not provided)") - - // Auth type selection (multiple ways to specify) - authType = cmd.String("t", "", "Auth type: basic, scram, or token") - authTypeLong = cmd.String("type", "", "Auth type: basic, scram, or token") - useScram = cmd.Bool("s", false, "Generate SCRAM credentials (TCP)") - useScramLong = cmd.Bool("scram", false, "Generate SCRAM credentials (TCP)") - useBasic = cmd.Bool("b", false, "Generate basic auth credentials (HTTP)") - useBasicLong = cmd.Bool("basic", false, "Generate basic auth credentials (HTTP)") - - // Token generation - genToken = cmd.Bool("k", false, "Generate random bearer token") - genTokenLong = cmd.Bool("token", false, "Generate random bearer token") - tokenLen = cmd.Int("l", 32, "Token length in bytes") - tokenLenLong = cmd.Int("length", 32, "Token length in bytes") - - // Migration option - migrate = cmd.Bool("m", false, "Convert basic auth PHC to SCRAM") - migrateLong = cmd.Bool("migrate", false, "Convert basic auth PHC to SCRAM") - phcHash = cmd.String("phc", "", "PHC hash to migrate (required with --migrate)") - ) - - cmd.Usage = func() { - fmt.Fprintln(ac.errOut, "Generate authentication credentials for LogWisp") - fmt.Fprintln(ac.errOut, "\nUsage: logwisp auth [options]") - fmt.Fprintln(ac.errOut, "\nExamples:") - fmt.Fprintln(ac.errOut, " # Generate basic auth hash for HTTP sources/sinks") - fmt.Fprintln(ac.errOut, " logwisp auth -u admin -b") - fmt.Fprintln(ac.errOut, " logwisp auth --user=admin --basic") - fmt.Fprintln(ac.errOut, " ") - fmt.Fprintln(ac.errOut, " # Generate SCRAM credentials for TCP") - fmt.Fprintln(ac.errOut, " logwisp auth -u tcpuser -s") - fmt.Fprintln(ac.errOut, " logwisp auth --user=tcpuser --scram") - fmt.Fprintln(ac.errOut, " ") - fmt.Fprintln(ac.errOut, " # Generate bearer token") - fmt.Fprintln(ac.errOut, " logwisp auth -k -l 64") - fmt.Fprintln(ac.errOut, " logwisp auth --token --length=64") - fmt.Fprintln(ac.errOut, "\nOptions:") - cmd.PrintDefaults() - } - - if err := cmd.Parse(args); err != nil { - return err - } - - // Check for unparsed arguments - if cmd.NArg() > 0 { - return fmt.Errorf("unexpected argument(s): %s", strings.Join(cmd.Args(), " ")) - } - - // Merge short and long form values - finalUsername := coalesceString(*username, *usernameLong) - finalPassword := coalesceString(*password, *passwordLong) - finalAuthType := coalesceString(*authType, *authTypeLong) - finalGenToken := coalesceBool(*genToken, *genTokenLong) - finalTokenLen := coalesceInt(*tokenLen, *tokenLenLong, core.DefaultTokenLength) - finalUseScram := coalesceBool(*useScram, *useScramLong) - finalUseBasic := coalesceBool(*useBasic, *useBasicLong) - finalMigrate := coalesceBool(*migrate, *migrateLong) - - // Handle migration mode - if finalMigrate { - if *phcHash == "" || finalUsername == "" || finalPassword == "" { - return fmt.Errorf("--migrate requires --user, --password, and --phc flags") - } - return ac.migrateToScram(finalUsername, finalPassword, *phcHash) - } - - // Determine auth type from flags - if finalGenToken || finalAuthType == "token" { - return ac.generateToken(finalTokenLen) - } - - // Determine credential type - credType := "basic" // default - - // Check explicit type flags - if finalUseScram || finalAuthType == "scram" { - credType = "scram" - } else if finalUseBasic || finalAuthType == "basic" { - credType = "basic" - } else if finalAuthType != "" { - return fmt.Errorf("invalid auth type: %s (valid: basic, scram, token)", finalAuthType) - } - - // Username required for password-based auth - if finalUsername == "" { - cmd.Usage() - return fmt.Errorf("username required for %s auth generation", credType) - } - - return ac.generatePasswordHash(finalUsername, finalPassword, credType) -} - -func (ac *AuthCommand) Description() string { - return "Generate authentication credentials (passwords, tokens, SCRAM)" -} - -func (ac *AuthCommand) Help() string { - return `Auth Command - Generate authentication credentials for LogWisp - -Usage: - logwisp auth [options] - -Authentication Types: - HTTP/HTTPS Sources & Sinks (TLS required): - - Basic Auth: Username/password with Argon2id hashing - - Bearer Token: Random cryptographic tokens - - TCP Sources & Sinks (No TLS): - - SCRAM: Argon2-SCRAM-SHA256 for plaintext connections - -Options: - -u, --user Username for credential generation - -p, --password Password (will prompt if not provided) - -t, --type Auth type: "basic", "scram", or "token" - -b, --basic Generate basic auth credentials (HTTP/HTTPS) - -s, --scram Generate SCRAM credentials (TCP) - -k, --token Generate random bearer token - -l, --length Token length in bytes (default: 32) - -Examples: -Examples: - # Generate basic auth hash for HTTP/HTTPS (with TLS) - logwisp auth -u admin -b - logwisp auth --user=admin --basic - - # Generate SCRAM credentials for TCP (without TLS) - logwisp auth -u tcpuser -s - logwisp auth --user=tcpuser --type=scram - - # Generate 64-byte bearer token - logwisp auth -k -l 64 - logwisp auth --token --length=64 - - # Convert existing basic auth to SCRAM (HTTPS to TCP conversion) - logwisp auth -u admin -m --phc='$argon2id$v=19$m=65536...' --password='secret' - -Output: - The command outputs configuration snippets ready to paste into logwisp.toml - and the raw credential values for external auth files. - -Security Notes: - - Basic auth and tokens require TLS encryption for HTTP connections - - SCRAM provides authentication but NOT encryption for TCP connections - - Use strong passwords (12+ characters with mixed case, numbers, symbols) - - Store credentials securely and never commit them to version control -` -} - -func (ac *AuthCommand) generatePasswordHash(username, password, credType string) error { - // Get password if not provided - if password == "" { - var err error - password, err = ac.promptForPassword() - if err != nil { - return err - } - } - - switch credType { - case "basic": - return ac.generateBasicAuth(username, password) - case "scram": - return ac.generateScramAuth(username, password) - default: - return fmt.Errorf("invalid credential type: %s", credType) - } -} - -// promptForPassword handles password prompting with confirmation -func (ac *AuthCommand) promptForPassword() (string, error) { - pass1 := ac.promptPassword("Enter password: ") - pass2 := ac.promptPassword("Confirm password: ") - if pass1 != pass2 { - return "", fmt.Errorf("passwords don't match") - } - return pass1, nil -} - -func (ac *AuthCommand) promptPassword(prompt string) string { - fmt.Fprint(ac.errOut, prompt) - password, err := term.ReadPassword(syscall.Stdin) - fmt.Fprintln(ac.errOut) - if err != nil { - fmt.Fprintf(ac.errOut, "Failed to read password: %v\n", err) - os.Exit(1) - } - return string(password) -} - -// generateBasicAuth creates Argon2id hash for HTTP basic auth -func (ac *AuthCommand) generateBasicAuth(username, password string) error { - // Generate salt - salt := make([]byte, core.Argon2SaltLen) - if _, err := rand.Read(salt); err != nil { - return fmt.Errorf("failed to generate salt: %w", err) - } - - // Generate Argon2id hash - cred, err := auth.DeriveCredential(username, password, salt, - core.Argon2Time, core.Argon2Memory, core.Argon2Threads) - if err != nil { - return fmt.Errorf("failed to derive credential: %w", err) - } - - // Output configuration snippets - fmt.Fprintln(ac.output, "\n# Basic Auth Configuration (HTTP sources/sinks)") - fmt.Fprintln(ac.output, "# REQUIRES HTTPS/TLS for security") - fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:") - fmt.Fprintln(ac.output, "") - fmt.Fprintln(ac.output, "[pipelines.auth]") - fmt.Fprintln(ac.output, `type = "basic"`) - fmt.Fprintln(ac.output, "") - fmt.Fprintln(ac.output, "[[pipelines.auth.basic_auth.users]]") - fmt.Fprintf(ac.output, "username = %q\n", username) - fmt.Fprintf(ac.output, "password_hash = %q\n\n", cred.PHCHash) - - fmt.Fprintln(ac.output, "# For external users file:") - fmt.Fprintf(ac.output, "%s:%s\n", username, cred.PHCHash) - - return nil -} - -// generateScramAuth creates Argon2id-SCRAM-SHA256 credentials for TCP -func (ac *AuthCommand) generateScramAuth(username, password string) error { - // Generate salt - salt := make([]byte, core.Argon2SaltLen) - if _, err := rand.Read(salt); err != nil { - return fmt.Errorf("failed to generate salt: %w", err) - } - - // Use internal auth package to derive SCRAM credentials - cred, err := auth.DeriveCredential(username, password, salt, - core.Argon2Time, core.Argon2Memory, core.Argon2Threads) - if err != nil { - return fmt.Errorf("failed to derive SCRAM credential: %w", err) - } - - // Output SCRAM configuration - fmt.Fprintln(ac.output, "\n# SCRAM Auth Configuration (TCP sources/sinks)") - fmt.Fprintln(ac.output, "# Provides authentication but NOT encryption") - fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:") - fmt.Fprintln(ac.output, "") - fmt.Fprintln(ac.output, "[pipelines.auth]") - fmt.Fprintln(ac.output, `type = "scram"`) - fmt.Fprintln(ac.output, "") - fmt.Fprintln(ac.output, "[[pipelines.auth.scram_auth.users]]") - fmt.Fprintf(ac.output, "username = %q\n", username) - fmt.Fprintf(ac.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey)) - fmt.Fprintf(ac.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey)) - fmt.Fprintf(ac.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt)) - fmt.Fprintf(ac.output, "argon_time = %d\n", cred.ArgonTime) - fmt.Fprintf(ac.output, "argon_memory = %d\n", cred.ArgonMemory) - fmt.Fprintf(ac.output, "argon_threads = %d\n\n", cred.ArgonThreads) - - fmt.Fprintln(ac.output, "# Note: SCRAM provides authentication only.") - fmt.Fprintln(ac.output, "# Use TLS/mTLS for encryption if needed.") - - return nil -} - -func (ac *AuthCommand) generateToken(length int) error { - if length < 16 { - fmt.Fprintln(ac.errOut, "Warning: tokens < 16 bytes are cryptographically weak") - } - if length > 512 { - return fmt.Errorf("token length exceeds maximum (512 bytes)") - } - - token := make([]byte, length) - if _, err := rand.Read(token); err != nil { - return fmt.Errorf("failed to generate random bytes: %w", err) - } - - b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token) - hex := fmt.Sprintf("%x", token) - - fmt.Fprintln(ac.output, "\n# Token Configuration") - fmt.Fprintln(ac.output, "# Add to logwisp.toml:") - fmt.Fprintf(ac.output, "tokens = [%q]\n\n", b64) - - fmt.Fprintln(ac.output, "# Generated Token:") - fmt.Fprintf(ac.output, "Base64: %s\n", b64) - fmt.Fprintf(ac.output, "Hex: %s\n", hex) - - return nil -} - -// migrateToScram converts basic auth PHC hash to SCRAM credentials -func (ac *AuthCommand) migrateToScram(username, password, phcHash string) error { - // CHANGED: Moved from internal/auth to CLI command layer - cred, err := auth.MigrateFromPHC(username, password, phcHash) - if err != nil { - return fmt.Errorf("migration failed: %w", err) - } - - // Output SCRAM configuration (reuse format from generateScramAuth) - fmt.Fprintln(ac.output, "\n# Migrated SCRAM Credentials") - fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:") - fmt.Fprintln(ac.output, "") - fmt.Fprintln(ac.output, "[pipelines.auth]") - fmt.Fprintln(ac.output, `type = "scram"`) - fmt.Fprintln(ac.output, "") - fmt.Fprintln(ac.output, "[[pipelines.auth.scram_auth.users]]") - fmt.Fprintf(ac.output, "username = %q\n", username) - fmt.Fprintf(ac.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey)) - fmt.Fprintf(ac.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey)) - fmt.Fprintf(ac.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt)) - fmt.Fprintf(ac.output, "argon_time = %d\n", cred.ArgonTime) - fmt.Fprintf(ac.output, "argon_memory = %d\n", cred.ArgonMemory) - fmt.Fprintf(ac.output, "argon_threads = %d\n", cred.ArgonThreads) - - return nil -} \ No newline at end of file diff --git a/src/cmd/logwisp/commands/help.go b/src/cmd/logwisp/commands/help.go index 26bac8f..c10c9f3 100644 --- a/src/cmd/logwisp/commands/help.go +++ b/src/cmd/logwisp/commands/help.go @@ -7,6 +7,7 @@ import ( "strings" ) +// generalHelpTemplate is the default help message shown when no specific command is requested. const generalHelpTemplate = `LogWisp: A flexible log transport and processing tool. Usage: @@ -37,9 +38,6 @@ Configuration Sources (Precedence: CLI > Env > File > Defaults): - TOML configuration file is the primary method Examples: - # Generate password for admin user - logwisp auth -u admin - # Start service with custom config logwisp -c /etc/logwisp/prod.toml @@ -49,17 +47,17 @@ Examples: For detailed configuration options, please refer to the documentation. ` -// HelpCommand handles help display +// HelpCommand handles the display of general or command-specific help messages. type HelpCommand struct { router *CommandRouter } -// NewHelpCommand creates a new help command +// NewHelpCommand creates a new help command handler. func NewHelpCommand(router *CommandRouter) *HelpCommand { return &HelpCommand{router: router} } -// Execute displays help information +// Execute displays the appropriate help message based on the provided arguments. func (c *HelpCommand) Execute(args []string) error { // Check if help is requested for a specific command if len(args) > 0 && args[0] != "" { @@ -78,7 +76,27 @@ func (c *HelpCommand) Execute(args []string) error { return nil } -// formatCommandList creates a formatted list of available commands +// Description returns a brief one-line description of the command. +func (c *HelpCommand) Description() string { + return "Display help information" +} + +// Help returns the detailed help text for the 'help' command itself. +func (c *HelpCommand) Help() string { + return `Help Command - Display help information + +Usage: + logwisp help Show general help + logwisp help Show help for a specific command + +Examples: + logwisp help # Show general help + logwisp help auth # Show auth command help + logwisp auth --help # Alternative way to get command help +` +} + +// formatCommandList creates a formatted and aligned list of all available commands. func (c *HelpCommand) formatCommandList() string { commands := c.router.GetCommands() @@ -102,22 +120,4 @@ func (c *HelpCommand) formatCommandList() string { } return strings.Join(lines, "\n") -} - -func (c *HelpCommand) Description() string { - return "Display help information" -} - -func (c *HelpCommand) Help() string { - return `Help Command - Display help information - -Usage: - logwisp help Show general help - logwisp help Show help for a specific command - -Examples: - logwisp help # Show general help - logwisp help auth # Show auth command help - logwisp auth --help # Alternative way to get command help -` } \ No newline at end of file diff --git a/src/cmd/logwisp/commands/router.go b/src/cmd/logwisp/commands/router.go index a68df00..22fce05 100644 --- a/src/cmd/logwisp/commands/router.go +++ b/src/cmd/logwisp/commands/router.go @@ -6,26 +6,25 @@ import ( "os" ) -// Handler defines the interface for subcommands +// Handler defines the interface required for all subcommands. type Handler interface { Execute(args []string) error Description() string Help() string } -// CommandRouter handles subcommand routing before main app initialization +// CommandRouter handles the routing of CLI arguments to the appropriate subcommand handler. type CommandRouter struct { commands map[string]Handler } -// NewCommandRouter creates and initializes the command router +// NewCommandRouter creates and initializes the command router with all available commands. func NewCommandRouter() *CommandRouter { router := &CommandRouter{ commands: make(map[string]Handler), } // Register available commands - router.commands["auth"] = NewAuthCommand() router.commands["tls"] = NewTLSCommand() router.commands["version"] = NewVersionCommand() router.commands["help"] = NewHelpCommand(router) @@ -33,7 +32,7 @@ func NewCommandRouter() *CommandRouter { return router } -// Route checks for and executes subcommands +// Route checks for and executes a subcommand based on the provided CLI arguments. func (r *CommandRouter) Route(args []string) (bool, error) { if len(args) < 2 { return false, nil // No command specified, let main app continue @@ -69,18 +68,18 @@ func (r *CommandRouter) Route(args []string) (bool, error) { return true, handler.Execute(args[2:]) } -// GetCommand returns a command handler by name +// GetCommand returns a specific command handler by its name. func (r *CommandRouter) GetCommand(name string) (Handler, bool) { cmd, exists := r.commands[name] return cmd, exists } -// GetCommands returns all registered commands +// GetCommands returns a map of all registered commands. func (r *CommandRouter) GetCommands() map[string]Handler { return r.commands } -// ShowCommands displays available subcommands +// ShowCommands displays a list of available subcommands to stderr. func (r *CommandRouter) ShowCommands() { for name, handler := range r.commands { fmt.Fprintf(os.Stderr, " %-10s %s\n", name, handler.Description()) @@ -88,7 +87,7 @@ func (r *CommandRouter) ShowCommands() { fmt.Fprintln(os.Stderr, "\nUse 'logwisp --help' for command-specific help") } -// Helper functions to merge short and long options +// coalesceString returns the first non-empty string from a list of arguments. func coalesceString(values ...string) string { for _, v := range values { if v != "" { @@ -98,6 +97,7 @@ func coalesceString(values ...string) string { return "" } +// coalesceInt returns the first non-default integer from a list of arguments. func coalesceInt(primary, secondary, defaultVal int) int { if primary != defaultVal { return primary @@ -108,6 +108,7 @@ func coalesceInt(primary, secondary, defaultVal int) int { return defaultVal } +// coalesceBool returns true if any of the boolean arguments is true. func coalesceBool(values ...bool) bool { for _, v := range values { if v { diff --git a/src/cmd/logwisp/commands/tls.go b/src/cmd/logwisp/commands/tls.go index 562d4d0..601726a 100644 --- a/src/cmd/logwisp/commands/tls.go +++ b/src/cmd/logwisp/commands/tls.go @@ -17,11 +17,13 @@ import ( "time" ) +// TLSCommand handles the generation of TLS certificates. type TLSCommand struct { output io.Writer errOut io.Writer } +// NewTLSCommand creates a new TLS command handler. func NewTLSCommand() *TLSCommand { return &TLSCommand{ output: os.Stdout, @@ -29,6 +31,7 @@ func NewTLSCommand() *TLSCommand { } } +// Execute parses flags and routes to the appropriate certificate generation function. func (tc *TLSCommand) Execute(args []string) error { cmd := flag.NewFlagSet("tls", flag.ContinueOnError) cmd.SetOutput(tc.errOut) @@ -133,10 +136,12 @@ func (tc *TLSCommand) Execute(args []string) error { } } +// Description returns a brief one-line description of the command. func (tc *TLSCommand) Description() string { return "Generate TLS certificates (CA, server, client, self-signed)" } +// Help returns the detailed help text for the command. func (tc *TLSCommand) Help() string { return `TLS Command - Generate TLS certificates for LogWisp @@ -195,7 +200,7 @@ Security Notes: ` } -// Create and manage private CA +// generateCA creates a new Certificate Authority (CA) certificate and private key. func (tc *TLSCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error { // Generate RSA key priv, err := rsa.GenerateKey(rand.Reader, bits) @@ -250,28 +255,7 @@ func (tc *TLSCommand) generateCA(cn, org, country string, days, bits int, certFi return nil } -func parseHosts(hostList string) ([]string, []net.IP) { - var dnsNames []string - var ipAddrs []net.IP - - if hostList == "" { - return dnsNames, ipAddrs - } - - hosts := strings.Split(hostList, ",") - for _, h := range hosts { - h = strings.TrimSpace(h) - if ip := net.ParseIP(h); ip != nil { - ipAddrs = append(ipAddrs, ip) - } else { - dnsNames = append(dnsNames, h) - } - } - - return dnsNames, ipAddrs -} - -// Generate self-signed certificate +// generateSelfSigned creates a new self-signed server certificate and private key. func (tc *TLSCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error { // 1. Generate an RSA private key with the specified bit size priv, err := rsa.GenerateKey(rand.Reader, bits) @@ -338,7 +322,7 @@ func (tc *TLSCommand) generateSelfSigned(cn, org, country, hosts string, days, b return nil } -// Generate server cert with CA +// generateServerCert creates a new server certificate signed by a provided CA. func (tc *TLSCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { caCert, caKey, err := loadCA(caFile, caKeyFile) if err != nil { @@ -401,7 +385,7 @@ func (tc *TLSCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyF return nil } -// Generate client cert with CA +// generateClientCert creates a new client certificate signed by a provided CA for mTLS. func (tc *TLSCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { caCert, caKey, err := loadCA(caFile, caKeyFile) if err != nil { @@ -458,7 +442,7 @@ func (tc *TLSCommand) generateClientCert(cn, org, country, caFile, caKeyFile str return nil } -// Load cert with CA +// loadCA reads and parses a CA certificate and its corresponding private key from files. func loadCA(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) { // Load CA certificate certPEM, err := os.ReadFile(certFile) @@ -517,6 +501,7 @@ func loadCA(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error return caCert, caKey, nil } +// saveCert saves a DER-encoded certificate to a file in PEM format. func saveCert(filename string, certDER []byte) error { certFile, err := os.Create(filename) if err != nil { @@ -539,6 +524,7 @@ func saveCert(filename string, certDER []byte) error { return nil } +// saveKey saves an RSA private key to a file in PEM format with restricted permissions. func saveKey(filename string, key *rsa.PrivateKey) error { keyFile, err := os.Create(filename) if err != nil { @@ -560,4 +546,26 @@ func saveKey(filename string, key *rsa.PrivateKey) error { } return nil +} + +// parseHosts splits a comma-separated string of hosts into slices of DNS names and IP addresses. +func parseHosts(hostList string) ([]string, []net.IP) { + var dnsNames []string + var ipAddrs []net.IP + + if hostList == "" { + return dnsNames, ipAddrs + } + + hosts := strings.Split(hostList, ",") + for _, h := range hosts { + h = strings.TrimSpace(h) + if ip := net.ParseIP(h); ip != nil { + ipAddrs = append(ipAddrs, ip) + } else { + dnsNames = append(dnsNames, h) + } + } + + return dnsNames, ipAddrs } \ No newline at end of file diff --git a/src/cmd/logwisp/commands/version.go b/src/cmd/logwisp/commands/version.go index 073e3ff..2fa1038 100644 --- a/src/cmd/logwisp/commands/version.go +++ b/src/cmd/logwisp/commands/version.go @@ -7,23 +7,26 @@ import ( "logwisp/src/internal/version" ) -// VersionCommand handles version display +// VersionCommand handles the display of the application's version information. type VersionCommand struct{} -// NewVersionCommand creates a new version command +// NewVersionCommand creates a new version command handler. func NewVersionCommand() *VersionCommand { return &VersionCommand{} } +// Execute prints the detailed version string to stdout. func (c *VersionCommand) Execute(args []string) error { fmt.Println(version.String()) return nil } +// Description returns a brief one-line description of the command. func (c *VersionCommand) Description() string { return "Show version information" } +// Help returns the detailed help text for the command. func (c *VersionCommand) Help() string { return `Version Command - Show LogWisp version information diff --git a/src/cmd/logwisp/main.go b/src/cmd/logwisp/main.go index 58ae93c..e2a558b 100644 --- a/src/cmd/logwisp/main.go +++ b/src/cmd/logwisp/main.go @@ -18,8 +18,10 @@ import ( "github.com/lixenwraith/log" ) +// logger is the global logger instance for the application. var logger *log.Logger +// main is the entry point for the LogWisp application. func main() { // Handle subcommands before any config loading // This prevents flag conflicts with lixenwraith/config @@ -185,6 +187,7 @@ func main() { logger.Info("msg", "Shutdown complete") } +// shutdownLogger gracefully shuts down the global logger. func shutdownLogger() { if logger != nil { if err := logger.Shutdown(2 * time.Second); err != nil { diff --git a/src/cmd/logwisp/output.go b/src/cmd/logwisp/output.go index 1dd17c5..0e49724 100644 --- a/src/cmd/logwisp/output.go +++ b/src/cmd/logwisp/output.go @@ -8,7 +8,7 @@ import ( "sync" ) -// Manages all application output respecting quiet mode +// OutputHandler manages all application output, respecting the global quiet mode. type OutputHandler struct { quiet bool mu sync.RWMutex @@ -16,10 +16,10 @@ type OutputHandler struct { stderr io.Writer } -// Global output handler instance +// output is the global instance of the OutputHandler. var output *OutputHandler -// Initializes the global output handler +// InitOutputHandler initializes the global output handler. func InitOutputHandler(quiet bool) { output = &OutputHandler{ quiet: quiet, @@ -28,7 +28,32 @@ func InitOutputHandler(quiet bool) { } } -// Writes to stdout if not in quiet mode +// Print writes to stdout. +func Print(format string, args ...any) { + if output != nil { + output.Print(format, args...) + } +} + +// Error writes to stderr. +func Error(format string, args ...any) { + if output != nil { + output.Error(format, args...) + } +} + +// FatalError writes to stderr and exits the application. +func FatalError(code int, format string, args ...any) { + if output != nil { + output.FatalError(code, format, args...) + } else { + // Fallback if handler not initialized + fmt.Fprintf(os.Stderr, format, args...) + os.Exit(code) + } +} + +// Print writes a formatted string to stdout if not in quiet mode. func (o *OutputHandler) Print(format string, args ...any) { o.mu.RLock() defer o.mu.RUnlock() @@ -38,7 +63,7 @@ func (o *OutputHandler) Print(format string, args ...any) { } } -// Writes to stderr if not in quiet mode +// Error writes a formatted string to stderr if not in quiet mode. func (o *OutputHandler) Error(format string, args ...any) { o.mu.RLock() defer o.mu.RUnlock() @@ -48,45 +73,22 @@ func (o *OutputHandler) Error(format string, args ...any) { } } -// Writes to stderr and exits (respects quiet mode) +// FatalError writes a formatted string to stderr and exits with the given code. func (o *OutputHandler) FatalError(code int, format string, args ...any) { o.Error(format, args...) os.Exit(code) } -// Returns the current quiet mode status +// IsQuiet returns the current quiet mode status. func (o *OutputHandler) IsQuiet() bool { o.mu.RLock() defer o.mu.RUnlock() return o.quiet } -// Updates quiet mode (useful for testing) +// SetQuiet updates the quiet mode status. func (o *OutputHandler) SetQuiet(quiet bool) { o.mu.Lock() defer o.mu.Unlock() o.quiet = quiet -} - -// Helper functions for global output handler -func Print(format string, args ...any) { - if output != nil { - output.Print(format, args...) - } -} - -func Error(format string, args ...any) { - if output != nil { - output.Error(format, args...) - } -} - -func FatalError(code int, format string, args ...any) { - if output != nil { - output.FatalError(code, format, args...) - } else { - // Fallback if handler not initialized - fmt.Fprintf(os.Stderr, format, args...) - os.Exit(code) - } } \ No newline at end of file diff --git a/src/cmd/logwisp/reload.go b/src/cmd/logwisp/reload.go index f6d6a6d..065c4f0 100644 --- a/src/cmd/logwisp/reload.go +++ b/src/cmd/logwisp/reload.go @@ -17,7 +17,7 @@ import ( "github.com/lixenwraith/log" ) -// Handles configuration hot reload +// ReloadManager handles the configuration hot-reloading functionality. type ReloadManager struct { configPath string service *service.Service @@ -35,7 +35,7 @@ type ReloadManager struct { statusReporterMu sync.Mutex } -// Creates a new reload manager +// NewReloadManager creates a new reload manager. func NewReloadManager(configPath string, initialCfg *config.Config, logger *log.Logger) *ReloadManager { return &ReloadManager{ configPath: configPath, @@ -45,7 +45,7 @@ func NewReloadManager(configPath string, initialCfg *config.Config, logger *log. } } -// Begins watching for configuration changes +// Start bootstraps the initial service and begins watching for configuration changes. func (rm *ReloadManager) Start(ctx context.Context) error { // Bootstrap initial service svc, err := bootstrapService(ctx, rm.cfg) @@ -90,7 +90,75 @@ func (rm *ReloadManager) Start(ctx context.Context) error { return nil } -// Monitors configuration changes +// Shutdown gracefully stops the reload manager and the currently active service. +func (rm *ReloadManager) Shutdown() { + rm.logger.Info("msg", "Shutting down reload manager") + + // Stop status reporter + rm.stopStatusReporter() + + // Stop watching + close(rm.shutdownCh) + rm.wg.Wait() + + // Stop config watching + if rm.lcfg != nil { + rm.lcfg.StopAutoUpdate() + } + + // Shutdown current services + rm.mu.RLock() + currentService := rm.service + rm.mu.RUnlock() + + if currentService != nil { + rm.logger.Info("msg", "Shutting down service") + currentService.Shutdown() + } +} + +// GetService returns the currently active service instance in a thread-safe manner. +func (rm *ReloadManager) GetService() *service.Service { + rm.mu.RLock() + defer rm.mu.RUnlock() + return rm.service +} + +// triggerReload initiates the configuration reload process. +func (rm *ReloadManager) triggerReload(ctx context.Context) { + // Prevent concurrent reloads + rm.reloadingMu.Lock() + if rm.isReloading { + rm.reloadingMu.Unlock() + rm.logger.Debug("msg", "Reload already in progress, skipping") + return + } + rm.isReloading = true + rm.reloadingMu.Unlock() + + defer func() { + rm.reloadingMu.Lock() + rm.isReloading = false + rm.reloadingMu.Unlock() + }() + + rm.logger.Info("msg", "Starting configuration hot reload") + + // Create reload context with timeout + reloadCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + if err := rm.performReload(reloadCtx); err != nil { + rm.logger.Error("msg", "Hot reload failed", + "error", err, + "action", "keeping current configuration and services") + return + } + + rm.logger.Info("msg", "Configuration hot reload completed successfully") +} + +// watchLoop is the main goroutine that monitors for configuration file changes. func (rm *ReloadManager) watchLoop(ctx context.Context) { defer rm.wg.Done() @@ -144,91 +212,7 @@ func (rm *ReloadManager) watchLoop(ctx context.Context) { } } -// Verify file permissions for security -func verifyFilePermissions(path string) error { - info, err := os.Stat(path) - if err != nil { - return fmt.Errorf("failed to stat config file: %w", err) - } - - // Extract file mode and system stats - mode := info.Mode() - stat, ok := info.Sys().(*syscall.Stat_t) - if !ok { - return fmt.Errorf("unable to get file ownership info") - } - - // Check ownership - must be current user or root - currentUID := uint32(os.Getuid()) - if stat.Uid != currentUID && stat.Uid != 0 { - return fmt.Errorf("config file owned by uid %d, expected %d or 0", stat.Uid, currentUID) - } - - // Check permissions - must not be writable by group or other - perm := mode.Perm() - if perm&0022 != 0 { - // Group or other has write permission - return fmt.Errorf("insecure permissions %04o - file must not be writable by group/other", perm) - } - - return nil -} - -// Determines if a config change requires service reload -func (rm *ReloadManager) shouldReload(path string) bool { - // Pipeline changes always require reload - if strings.HasPrefix(path, "pipelines.") || path == "pipelines" { - return true - } - - // Logging changes don't require service reload - if strings.HasPrefix(path, "logging.") { - return false - } - - // Status reporter changes - if path == "disable_status_reporter" { - return true - } - - return false -} - -// Performs the actual reload -func (rm *ReloadManager) triggerReload(ctx context.Context) { - // Prevent concurrent reloads - rm.reloadingMu.Lock() - if rm.isReloading { - rm.reloadingMu.Unlock() - rm.logger.Debug("msg", "Reload already in progress, skipping") - return - } - rm.isReloading = true - rm.reloadingMu.Unlock() - - defer func() { - rm.reloadingMu.Lock() - rm.isReloading = false - rm.reloadingMu.Unlock() - }() - - rm.logger.Info("msg", "Starting configuration hot reload") - - // Create reload context with timeout - reloadCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - if err := rm.performReload(reloadCtx); err != nil { - rm.logger.Error("msg", "Hot reload failed", - "error", err, - "action", "keeping current configuration and services") - return - } - - rm.logger.Info("msg", "Configuration hot reload completed successfully") -} - -// Executes the reload process +// performReload executes the steps to validate and apply a new configuration. func (rm *ReloadManager) performReload(ctx context.Context) error { // Get updated config from lconfig updatedCfg, err := rm.lcfg.AsStruct() @@ -272,7 +256,57 @@ func (rm *ReloadManager) performReload(ctx context.Context) error { return nil } -// Gracefully shuts down old services +// shouldReload determines if a given configuration change requires a full service reload. +func (rm *ReloadManager) shouldReload(path string) bool { + // Pipeline changes always require reload + if strings.HasPrefix(path, "pipelines.") || path == "pipelines" { + return true + } + + // Logging changes don't require service reload + if strings.HasPrefix(path, "logging.") { + return false + } + + // Status reporter changes + if path == "disable_status_reporter" { + return true + } + + return false +} + +// verifyFilePermissions checks the ownership and permissions of the config file for security. +func verifyFilePermissions(path string) error { + info, err := os.Stat(path) + if err != nil { + return fmt.Errorf("failed to stat config file: %w", err) + } + + // Extract file mode and system stats + mode := info.Mode() + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok { + return fmt.Errorf("unable to get file ownership info") + } + + // Check ownership - must be current user or root + currentUID := uint32(os.Getuid()) + if stat.Uid != currentUID && stat.Uid != 0 { + return fmt.Errorf("config file owned by uid %d, expected %d or 0", stat.Uid, currentUID) + } + + // Check permissions - must not be writable by group or other + perm := mode.Perm() + if perm&0022 != 0 { + // Group or other has write permission + return fmt.Errorf("insecure permissions %04o - file must not be writable by group/other", perm) + } + + return nil +} + +// shutdownOldServices gracefully shuts down the previous service instance after a successful reload. func (rm *ReloadManager) shutdownOldServices(svc *service.Service) { // Give connections time to drain rm.logger.Debug("msg", "Draining connections from old services") @@ -286,7 +320,7 @@ func (rm *ReloadManager) shutdownOldServices(svc *service.Service) { rm.logger.Debug("msg", "Old services shutdown complete") } -// Starts a new status reporter +// startStatusReporter starts a new status reporter for service. func (rm *ReloadManager) startStatusReporter(ctx context.Context, svc *service.Service) { rm.statusReporterMu.Lock() defer rm.statusReporterMu.Unlock() @@ -299,7 +333,19 @@ func (rm *ReloadManager) startStatusReporter(ctx context.Context, svc *service.S rm.logger.Debug("msg", "Started status reporter") } -// Stops old and starts new status reporter +// stopStatusReporter stops the currently running status reporter. +func (rm *ReloadManager) stopStatusReporter() { + rm.statusReporterMu.Lock() + defer rm.statusReporterMu.Unlock() + + if rm.statusReporterCancel != nil { + rm.statusReporterCancel() + rm.statusReporterCancel = nil + rm.logger.Debug("msg", "Stopped status reporter") + } +} + +// restartStatusReporter stops the old status reporter and starts a new one. func (rm *ReloadManager) restartStatusReporter(ctx context.Context, newService *service.Service) { if rm.cfg.DisableStatusReporter { // Just stop the old one if disabled @@ -322,50 +368,4 @@ func (rm *ReloadManager) restartStatusReporter(ctx context.Context, newService * go statusReporter(newService, reporterCtx) rm.logger.Debug("msg", "Started new status reporter") -} - -// Stops the status reporter -func (rm *ReloadManager) stopStatusReporter() { - rm.statusReporterMu.Lock() - defer rm.statusReporterMu.Unlock() - - if rm.statusReporterCancel != nil { - rm.statusReporterCancel() - rm.statusReporterCancel = nil - rm.logger.Debug("msg", "Stopped status reporter") - } -} - -// Stops the reload manager -func (rm *ReloadManager) Shutdown() { - rm.logger.Info("msg", "Shutting down reload manager") - - // Stop status reporter - rm.stopStatusReporter() - - // Stop watching - close(rm.shutdownCh) - rm.wg.Wait() - - // Stop config watching - if rm.lcfg != nil { - rm.lcfg.StopAutoUpdate() - } - - // Shutdown current services - rm.mu.RLock() - currentService := rm.service - rm.mu.RUnlock() - - if currentService != nil { - rm.logger.Info("msg", "Shutting down service") - currentService.Shutdown() - } -} - -// Returns the current service (thread-safe) -func (rm *ReloadManager) GetService() *service.Service { - rm.mu.RLock() - defer rm.mu.RUnlock() - return rm.service } \ No newline at end of file diff --git a/src/cmd/logwisp/signal.go b/src/cmd/logwisp/signal.go index 7d0f5ca..b7d3946 100644 --- a/src/cmd/logwisp/signal.go +++ b/src/cmd/logwisp/signal.go @@ -10,14 +10,14 @@ import ( "github.com/lixenwraith/log" ) -// Manages OS signals +// SignalHandler manages OS signals for shutdown and configuration reloads. type SignalHandler struct { reloadManager *ReloadManager logger *log.Logger sigChan chan os.Signal } -// Creates a signal handler +// NewSignalHandler creates a new signal handler. func NewSignalHandler(rm *ReloadManager, logger *log.Logger) *SignalHandler { sh := &SignalHandler{ reloadManager: rm, @@ -36,7 +36,7 @@ func NewSignalHandler(rm *ReloadManager, logger *log.Logger) *SignalHandler { return sh } -// Processes signals +// Handle blocks and processes incoming OS signals. func (sh *SignalHandler) Handle(ctx context.Context) os.Signal { for { select { @@ -58,7 +58,7 @@ func (sh *SignalHandler) Handle(ctx context.Context) os.Signal { } } -// Cleans up signal handling +// Stop cleans up the signal handling channel. func (sh *SignalHandler) Stop() { signal.Stop(sh.sigChan) close(sh.sigChan) diff --git a/src/cmd/logwisp/status.go b/src/cmd/logwisp/status.go index faba2c8..54490ab 100644 --- a/src/cmd/logwisp/status.go +++ b/src/cmd/logwisp/status.go @@ -10,7 +10,7 @@ import ( "logwisp/src/internal/service" ) -// Periodically logs service status +// statusReporter is a goroutine that periodically logs the health and statistics of the service. func statusReporter(service *service.Service, ctx context.Context) { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() @@ -60,55 +60,7 @@ func statusReporter(service *service.Service, ctx context.Context) { } } -// Logs the status of an individual pipeline -func logPipelineStatus(name string, stats map[string]any) { - statusFields := []any{ - "msg", "Pipeline status", - "pipeline", name, - } - - // Add processing statistics - if totalProcessed, ok := stats["total_processed"].(uint64); ok { - statusFields = append(statusFields, "entries_processed", totalProcessed) - } - if totalFiltered, ok := stats["total_filtered"].(uint64); ok { - statusFields = append(statusFields, "entries_filtered", totalFiltered) - } - - // Add source count - if sourceCount, ok := stats["source_count"].(int); ok { - statusFields = append(statusFields, "sources", sourceCount) - } - - // Add sink statistics - if sinks, ok := stats["sinks"].([]map[string]any); ok { - tcpConns := int64(0) - httpConns := int64(0) - - for _, sink := range sinks { - sinkType := sink["type"].(string) - if activeConns, ok := sink["active_connections"].(int64); ok { - switch sinkType { - case "tcp": - tcpConns += activeConns - case "http": - httpConns += activeConns - } - } - } - - if tcpConns > 0 { - statusFields = append(statusFields, "tcp_connections", tcpConns) - } - if httpConns > 0 { - statusFields = append(statusFields, "http_connections", httpConns) - } - } - - logger.Debug(statusFields...) -} - -// Logs the configured endpoints for a pipeline +// displayPipelineEndpoints logs the configured source and sink endpoints for a pipeline at startup. func displayPipelineEndpoints(cfg config.PipelineConfig) { // Display sink endpoints for i, sinkCfg := range cfg.Sinks { @@ -256,4 +208,52 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) { "pipeline", cfg.Name, "filter_count", len(cfg.Filters)) } +} + +// logPipelineStatus logs the detailed status and statistics of an individual pipeline. +func logPipelineStatus(name string, stats map[string]any) { + statusFields := []any{ + "msg", "Pipeline status", + "pipeline", name, + } + + // Add processing statistics + if totalProcessed, ok := stats["total_processed"].(uint64); ok { + statusFields = append(statusFields, "entries_processed", totalProcessed) + } + if totalFiltered, ok := stats["total_filtered"].(uint64); ok { + statusFields = append(statusFields, "entries_filtered", totalFiltered) + } + + // Add source count + if sourceCount, ok := stats["source_count"].(int); ok { + statusFields = append(statusFields, "sources", sourceCount) + } + + // Add sink statistics + if sinks, ok := stats["sinks"].([]map[string]any); ok { + tcpConns := int64(0) + httpConns := int64(0) + + for _, sink := range sinks { + sinkType := sink["type"].(string) + if activeConns, ok := sink["active_connections"].(int64); ok { + switch sinkType { + case "tcp": + tcpConns += activeConns + case "http": + httpConns += activeConns + } + } + } + + if tcpConns > 0 { + statusFields = append(statusFields, "tcp_connections", tcpConns) + } + if httpConns > 0 { + statusFields = append(statusFields, "http_connections", httpConns) + } + } + + logger.Debug(statusFields...) } \ No newline at end of file diff --git a/src/internal/auth/authenticator.go b/src/internal/auth/authenticator.go deleted file mode 100644 index 3640d41..0000000 --- a/src/internal/auth/authenticator.go +++ /dev/null @@ -1,213 +0,0 @@ -// FILE: logwisp/src/internal/auth/authenticator.go -package auth - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - "strings" - "sync" - "time" - - "logwisp/src/internal/config" - - "github.com/lixenwraith/log" -) - -// Prevent unbounded map growth -const maxAuthTrackedIPs = 10000 - -// Handles all authentication methods for a pipeline -type Authenticator struct { - config *config.ServerAuthConfig - logger *log.Logger - tokens map[string]bool // token -> valid - mu sync.RWMutex - - // Session tracking - sessions map[string]*Session - sessionMu sync.RWMutex -} - -// TODO: only one connection per user, token, mtls -// TODO: implement tracker logic -// Represents an authenticated connection -type Session struct { - ID string - Username string - Method string // basic, token, mtls - RemoteAddr string - CreatedAt time.Time - LastActivity time.Time -} - -// Creates a new authenticator from config -func NewAuthenticator(cfg *config.ServerAuthConfig, logger *log.Logger) (*Authenticator, error) { - // SCRAM is handled by ScramManager in sources - if cfg == nil || cfg.Type == "none" || cfg.Type == "scram" { - return nil, nil - } - - a := &Authenticator{ - config: cfg, - logger: logger, - tokens: make(map[string]bool), - sessions: make(map[string]*Session), - } - - // Initialize tokens - if cfg.Type == "token" && cfg.Token != nil { - for _, token := range cfg.Token.Tokens { - a.tokens[token] = true - } - } - - // Start session cleanup - go a.sessionCleanup() - - logger.Info("msg", "Authenticator initialized", - "component", "auth", - "type", cfg.Type) - - return a, nil -} - -// Handles HTTP authentication headers -func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) { - if a == nil || a.config.Type == "none" { - return &Session{ - ID: generateSessionID(), - Method: "none", - RemoteAddr: remoteAddr, - CreatedAt: time.Now(), - }, nil - } - - var session *Session - var err error - - switch a.config.Type { - case "token": - session, err = a.authenticateToken(authHeader, remoteAddr) - default: - err = fmt.Errorf("unsupported auth type: %s", a.config.Type) - } - - if err != nil { - time.Sleep(500 * time.Millisecond) - return nil, err - } - - return session, nil -} - -func (a *Authenticator) authenticateToken(authHeader, remoteAddr string) (*Session, error) { - if !strings.HasPrefix(authHeader, "Token") { - return nil, fmt.Errorf("invalid token auth header") - } - - token := authHeader[7:] - return a.validateToken(token, remoteAddr) -} - -func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error) { - // Check static tokens first - a.mu.RLock() - isValid := a.tokens[token] - a.mu.RUnlock() - - if !isValid { - return nil, fmt.Errorf("invalid token") - } - - session := &Session{ - ID: generateSessionID(), - Method: "token", - RemoteAddr: remoteAddr, - CreatedAt: time.Now(), - LastActivity: time.Now(), - } - a.storeSession(session) - return session, nil -} - -func (a *Authenticator) storeSession(session *Session) { - a.sessionMu.Lock() - a.sessions[session.ID] = session - a.sessionMu.Unlock() - - a.logger.Info("msg", "Session created", - "component", "auth", - "session_id", session.ID, - "username", session.Username, - "method", session.Method, - "remote_addr", session.RemoteAddr) -} - -func (a *Authenticator) sessionCleanup() { - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() - - for range ticker.C { - a.sessionMu.Lock() - now := time.Now() - for id, session := range a.sessions { - if now.Sub(session.LastActivity) > 30*time.Minute { - delete(a.sessions, id) - a.logger.Debug("msg", "Session expired", - "component", "auth", - "session_id", id) - } - } - a.sessionMu.Unlock() - } -} - -func generateSessionID() string { - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - // Fallback to a less secure method if crypto/rand fails - return fmt.Sprintf("fallback-%d", time.Now().UnixNano()) - } - return base64.URLEncoding.EncodeToString(b) -} - -// Checks if a session is still valid -func (a *Authenticator) ValidateSession(sessionID string) bool { - if a == nil { - return true - } - - a.sessionMu.RLock() - session, exists := a.sessions[sessionID] - a.sessionMu.RUnlock() - - if !exists { - return false - } - - // Update activity - a.sessionMu.Lock() - session.LastActivity = time.Now() - a.sessionMu.Unlock() - - return true -} - -// Returns authentication statistics -func (a *Authenticator) GetStats() map[string]any { - if a == nil { - return map[string]any{"enabled": false} - } - - a.sessionMu.RLock() - sessionCount := len(a.sessions) - a.sessionMu.RUnlock() - - return map[string]any{ - "enabled": true, - "type": a.config.Type, - "active_sessions": sessionCount, - "static_tokens": len(a.tokens), - } -} \ No newline at end of file diff --git a/src/internal/auth/scram_client.go b/src/internal/auth/scram_client.go deleted file mode 100644 index 3722269..0000000 --- a/src/internal/auth/scram_client.go +++ /dev/null @@ -1,106 +0,0 @@ -// FILE: src/internal/auth/scram_client.go -package auth - -import ( - "crypto/rand" - "crypto/sha256" - "crypto/subtle" - "encoding/base64" - "fmt" - - "golang.org/x/crypto/argon2" -) - -// Client handles SCRAM client-side authentication -type ScramClient struct { - Username string - Password string - - // Handshake state - clientNonce string - serverFirst *ServerFirst - authMessage string - serverKey []byte -} - -// NewScramClient creates SCRAM client -func NewScramClient(username, password string) *ScramClient { - return &ScramClient{ - Username: username, - Password: password, - } -} - -// StartAuthentication generates ClientFirst message -func (c *ScramClient) StartAuthentication() (*ClientFirst, error) { - // Generate client nonce - nonce := make([]byte, 32) - if _, err := rand.Read(nonce); err != nil { - return nil, fmt.Errorf("failed to generate nonce: %w", err) - } - c.clientNonce = base64.StdEncoding.EncodeToString(nonce) - - return &ClientFirst{ - Username: c.Username, - ClientNonce: c.clientNonce, - }, nil -} - -// ProcessServerFirst handles server challenge -func (c *ScramClient) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) { - c.serverFirst = msg - - // Decode salt - salt, err := base64.StdEncoding.DecodeString(msg.Salt) - if err != nil { - return nil, fmt.Errorf("invalid salt encoding: %w", err) - } - - // Derive keys using Argon2id - saltedPassword := argon2.IDKey([]byte(c.Password), salt, - msg.ArgonTime, msg.ArgonMemory, msg.ArgonThreads, 32) - - clientKey := computeHMAC(saltedPassword, []byte("Client Key")) - serverKey := computeHMAC(saltedPassword, []byte("Server Key")) - storedKey := sha256.Sum256(clientKey) - - // Build auth message - clientFirstBare := fmt.Sprintf("u=%s,n=%s", c.Username, c.clientNonce) - clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce) - c.authMessage = clientFirstBare + "," + msg.Marshal() + "," + clientFinalBare - - // Compute client proof - clientSignature := computeHMAC(storedKey[:], []byte(c.authMessage)) - clientProof := xorBytes(clientKey, clientSignature) - - // Store server key for verification - c.serverKey = serverKey - - return &ClientFinal{ - FullNonce: msg.FullNonce, - ClientProof: base64.StdEncoding.EncodeToString(clientProof), - }, nil -} - -// VerifyServerFinal validates server signature -func (c *ScramClient) VerifyServerFinal(msg *ServerFinal) error { - if c.authMessage == "" || c.serverKey == nil { - return fmt.Errorf("invalid handshake state") - } - - // Compute expected server signature - expectedSig := computeHMAC(c.serverKey, []byte(c.authMessage)) - - // Decode received signature - receivedSig, err := base64.StdEncoding.DecodeString(msg.ServerSignature) - if err != nil { - return fmt.Errorf("invalid signature encoding: %w", err) - } - - // ☢ SECURITY: Constant-time comparison - if subtle.ConstantTimeCompare(expectedSig, receivedSig) != 1 { - return fmt.Errorf("server authentication failed") - } - - return nil -} \ No newline at end of file diff --git a/src/internal/auth/scram_credential.go b/src/internal/auth/scram_credential.go deleted file mode 100644 index 1851a40..0000000 --- a/src/internal/auth/scram_credential.go +++ /dev/null @@ -1,108 +0,0 @@ -// FILE: src/internal/auth/scram_credential.go -package auth - -import ( - "crypto/hmac" - "crypto/sha256" - "crypto/subtle" - "encoding/base64" - "fmt" - "strings" - - "logwisp/src/internal/core" - - "golang.org/x/crypto/argon2" -) - -// Credential stores SCRAM authentication data -type Credential struct { - Username string - Salt []byte // 16+ bytes - ArgonTime uint32 // e.g., 3 - ArgonMemory uint32 // e.g., 64*1024 KiB - ArgonThreads uint8 // e.g., 4 - StoredKey []byte // SHA256(ClientKey) - ServerKey []byte // For server auth - PHCHash string -} - -// DeriveCredential creates SCRAM credential from password -func DeriveCredential(username, password string, salt []byte, time, memory uint32, threads uint8) (*Credential, error) { - if len(salt) < 16 { - return nil, fmt.Errorf("salt must be at least 16 bytes") - } - - // Derive salted password using Argon2id - saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, core.Argon2KeyLen) - - // Construct PHC format for basic auth compatibility - saltB64 := base64.RawStdEncoding.EncodeToString(salt) - hashB64 := base64.RawStdEncoding.EncodeToString(saltedPassword) - phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", - argon2.Version, memory, time, threads, saltB64, hashB64) - - // Derive keys - clientKey := computeHMAC(saltedPassword, []byte("Client Key")) - serverKey := computeHMAC(saltedPassword, []byte("Server Key")) - storedKey := sha256.Sum256(clientKey) - - return &Credential{ - Username: username, - Salt: salt, - ArgonTime: time, - ArgonMemory: memory, - ArgonThreads: threads, - StoredKey: storedKey[:], - ServerKey: serverKey, - PHCHash: phcHash, - }, nil -} - -// MigrateFromPHC converts existing Argon2 PHC hash to SCRAM credential -func MigrateFromPHC(username, password, phcHash string) (*Credential, error) { - // Parse PHC: $argon2id$v=19$m=65536,t=3,p=4$salt$hash - parts := strings.Split(phcHash, "$") - if len(parts) != 6 || parts[1] != "argon2id" { - return nil, fmt.Errorf("invalid PHC format") - } - - var memory, time uint32 - var threads uint8 - fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads) - - salt, err := base64.RawStdEncoding.DecodeString(parts[4]) - if err != nil { - return nil, fmt.Errorf("invalid salt encoding: %w", err) - } - - expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) - if err != nil { - return nil, fmt.Errorf("invalid hash encoding: %w", err) - } - - // Verify password matches - computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash))) - if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 { - return nil, fmt.Errorf("password verification failed") - } - - // Now derive SCRAM credential - return DeriveCredential(username, password, salt, time, memory, threads) -} - -func computeHMAC(key, message []byte) []byte { - mac := hmac.New(sha256.New, key) - mac.Write(message) - return mac.Sum(nil) -} - -func xorBytes(a, b []byte) []byte { - if len(a) != len(b) { - panic("xor length mismatch") - } - result := make([]byte, len(a)) - for i := range a { - result[i] = a[i] ^ b[i] - } - return result -} \ No newline at end of file diff --git a/src/internal/auth/scram_manager.go b/src/internal/auth/scram_manager.go deleted file mode 100644 index 230497a..0000000 --- a/src/internal/auth/scram_manager.go +++ /dev/null @@ -1,83 +0,0 @@ -// FILE: src/internal/auth/scram_manager.go -package auth - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - - "logwisp/src/internal/config" -) - -// ScramManager provides high-level SCRAM operations with rate limiting -type ScramManager struct { - server *ScramServer -} - -// NewScramManager creates SCRAM manager -func NewScramManager(scramAuthCfg *config.ScramAuthConfig) *ScramManager { - manager := &ScramManager{ - server: NewScramServer(), - } - - // Load users from SCRAM config - for _, user := range scramAuthCfg.Users { - storedKey, err := base64.StdEncoding.DecodeString(user.StoredKey) - if err != nil { - // Skip user with invalid stored key - continue - } - - serverKey, err := base64.StdEncoding.DecodeString(user.ServerKey) - if err != nil { - // Skip user with invalid server key - continue - } - - salt, err := base64.StdEncoding.DecodeString(user.Salt) - if err != nil { - // Skip user with invalid salt - continue - } - - cred := &Credential{ - Username: user.Username, - StoredKey: storedKey, - ServerKey: serverKey, - Salt: salt, - ArgonTime: user.ArgonTime, - ArgonMemory: user.ArgonMemory, - ArgonThreads: user.ArgonThreads, - } - manager.server.AddCredential(cred) - } - - return manager -} - -// RegisterUser creates new user credential -func (sm *ScramManager) RegisterUser(username, password string) error { - salt := make([]byte, 16) - if _, err := rand.Read(salt); err != nil { - return fmt.Errorf("salt generation failed: %w", err) - } - - cred, err := DeriveCredential(username, password, salt, - sm.server.DefaultTime, sm.server.DefaultMemory, sm.server.DefaultThreads) - if err != nil { - return err - } - - sm.server.AddCredential(cred) - return nil -} - -// HandleClientFirst wraps server's HandleClientFirst -func (sm *ScramManager) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) { - return sm.server.HandleClientFirst(msg) -} - -// HandleClientFinal wraps server's HandleClientFinal -func (sm *ScramManager) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) { - return sm.server.HandleClientFinal(msg) -} \ No newline at end of file diff --git a/src/internal/auth/scram_message.go b/src/internal/auth/scram_message.go deleted file mode 100644 index 37f0842..0000000 --- a/src/internal/auth/scram_message.go +++ /dev/null @@ -1,38 +0,0 @@ -// FILE: src/internal/auth/scram_message.go -package auth - -import ( - "fmt" -) - -// ClientFirst initiates authentication -type ClientFirst struct { - Username string `json:"u"` - ClientNonce string `json:"n"` -} - -// ServerFirst contains server challenge -type ServerFirst struct { - FullNonce string `json:"r"` // client_nonce + server_nonce - Salt string `json:"s"` // base64 - ArgonTime uint32 `json:"t"` - ArgonMemory uint32 `json:"m"` - ArgonThreads uint8 `json:"p"` -} - -// ClientFinal contains client proof -type ClientFinal struct { - FullNonce string `json:"r"` - ClientProof string `json:"p"` // base64 -} - -// ServerFinal contains server signature for mutual auth -type ServerFinal struct { - ServerSignature string `json:"v"` // base64 - SessionID string `json:"sid,omitempty"` -} - -func (sf *ServerFirst) Marshal() string { - return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d", - sf.FullNonce, sf.Salt, sf.ArgonTime, sf.ArgonMemory, sf.ArgonThreads) -} \ No newline at end of file diff --git a/src/internal/auth/scram_protocol.go b/src/internal/auth/scram_protocol.go deleted file mode 100644 index 03f384f..0000000 --- a/src/internal/auth/scram_protocol.go +++ /dev/null @@ -1,117 +0,0 @@ -// FILE: src/internal/auth/scram_protocol.go -package auth - -import ( - "encoding/json" - "fmt" - "strings" - "time" - - "github.com/lixenwraith/log" - "github.com/panjf2000/gnet/v2" -) - -// ScramProtocolHandler handles SCRAM message exchange for TCP -type ScramProtocolHandler struct { - manager *ScramManager - logger *log.Logger -} - -// NewScramProtocolHandler creates protocol handler -func NewScramProtocolHandler(manager *ScramManager, logger *log.Logger) *ScramProtocolHandler { - return &ScramProtocolHandler{ - manager: manager, - logger: logger, - } -} - -// HandleAuthMessage processes a complete auth line from buffer -func (sph *ScramProtocolHandler) HandleAuthMessage(line []byte, conn gnet.Conn) (authenticated bool, session *Session, err error) { - // Parse SCRAM messages - parts := strings.Fields(string(line)) - if len(parts) < 2 { - conn.AsyncWrite([]byte("SCRAM-FAIL Invalid message format\n"), nil) - return false, nil, fmt.Errorf("invalid message format") - } - - switch parts[0] { - case "SCRAM-FIRST": - // Parse ClientFirst JSON - var clientFirst ClientFirst - if err := json.Unmarshal([]byte(parts[1]), &clientFirst); err != nil { - conn.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil) - return false, nil, fmt.Errorf("invalid JSON") - } - - // Process with SCRAM server - serverFirst, err := sph.manager.HandleClientFirst(&clientFirst) - if err != nil { - // Still send challenge to prevent user enumeration - response, _ := json.Marshal(serverFirst) - conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil) - return false, nil, err - } - - // Send ServerFirst challenge - response, _ := json.Marshal(serverFirst) - conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil) - return false, nil, nil // Not authenticated yet - - case "SCRAM-PROOF": - // Parse ClientFinal JSON - var clientFinal ClientFinal - if err := json.Unmarshal([]byte(parts[1]), &clientFinal); err != nil { - conn.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil) - return false, nil, fmt.Errorf("invalid JSON") - } - - // Verify proof - serverFinal, err := sph.manager.HandleClientFinal(&clientFinal) - if err != nil { - conn.AsyncWrite([]byte("SCRAM-FAIL Authentication failed\n"), nil) - return false, nil, err - } - - // Authentication successful - session = &Session{ - ID: serverFinal.SessionID, - Method: "scram-sha-256", - RemoteAddr: conn.RemoteAddr().String(), - CreatedAt: time.Now(), - } - - // Send ServerFinal with signature - response, _ := json.Marshal(serverFinal) - conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-OK %s\n", response)), nil) - - return true, session, nil - - default: - conn.AsyncWrite([]byte("SCRAM-FAIL Unknown command\n"), nil) - return false, nil, fmt.Errorf("unknown command: %s", parts[0]) - } -} - -// FormatSCRAMRequest formats a SCRAM protocol message for TCP -func FormatSCRAMRequest(command string, data interface{}) (string, error) { - jsonData, err := json.Marshal(data) - if err != nil { - return "", fmt.Errorf("failed to marshal %s: %w", command, err) - } - return fmt.Sprintf("%s %s\n", command, jsonData), nil -} - -// ParseSCRAMResponse parses a SCRAM protocol response from TCP -func ParseSCRAMResponse(response string) (command string, data string, err error) { - response = strings.TrimSpace(response) - parts := strings.SplitN(response, " ", 2) - if len(parts) < 1 { - return "", "", fmt.Errorf("empty response") - } - - command = parts[0] - if len(parts) > 1 { - data = parts[1] - } - return command, data, nil -} \ No newline at end of file diff --git a/src/internal/auth/scram_server.go b/src/internal/auth/scram_server.go deleted file mode 100644 index b9252c9..0000000 --- a/src/internal/auth/scram_server.go +++ /dev/null @@ -1,174 +0,0 @@ -// FILE: src/internal/auth/scram_server.go -package auth - -import ( - "crypto/rand" - "crypto/sha256" - "crypto/subtle" - "encoding/base64" - "fmt" - "sync" - "time" - - "logwisp/src/internal/core" -) - -// Server handles SCRAM authentication -type ScramServer struct { - credentials map[string]*Credential - handshakes map[string]*HandshakeState - mu sync.RWMutex - - // TODO: configurability useful? to be included in config or refactor to use core.const directly for simplicity - // Default Argon2 params for new registrations - DefaultTime uint32 - DefaultMemory uint32 - DefaultThreads uint8 -} - -// HandshakeState tracks ongoing authentication -type HandshakeState struct { - Username string - ClientNonce string - ServerNonce string - FullNonce string - Credential *Credential - CreatedAt time.Time -} - -// NewScramServer creates SCRAM server -func NewScramServer() *ScramServer { - return &ScramServer{ - credentials: make(map[string]*Credential), - handshakes: make(map[string]*HandshakeState), - DefaultTime: core.Argon2Time, - DefaultMemory: core.Argon2Memory, - DefaultThreads: core.Argon2Threads, - } -} - -// AddCredential registers user credential -func (s *ScramServer) AddCredential(cred *Credential) { - s.mu.Lock() - defer s.mu.Unlock() - s.credentials[cred.Username] = cred -} - -// HandleClientFirst processes initial auth request -func (s *ScramServer) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) { - s.mu.Lock() - defer s.mu.Unlock() - - // Check if user exists - cred, exists := s.credentials[msg.Username] - if !exists { - // Prevent user enumeration - still generate response - salt := make([]byte, 16) - rand.Read(salt) - serverNonce := generateNonce() - - return &ServerFirst{ - FullNonce: msg.ClientNonce + serverNonce, - Salt: base64.StdEncoding.EncodeToString(salt), - ArgonTime: s.DefaultTime, - ArgonMemory: s.DefaultMemory, - ArgonThreads: s.DefaultThreads, - }, fmt.Errorf("invalid credentials") - } - - // Generate server nonce - serverNonce := generateNonce() - fullNonce := msg.ClientNonce + serverNonce - - // Store handshake state - state := &HandshakeState{ - Username: msg.Username, - ClientNonce: msg.ClientNonce, - ServerNonce: serverNonce, - FullNonce: fullNonce, - Credential: cred, - CreatedAt: time.Now(), - } - s.handshakes[fullNonce] = state - - // Cleanup old handshakes - s.cleanupHandshakes() - - return &ServerFirst{ - FullNonce: fullNonce, - Salt: base64.StdEncoding.EncodeToString(cred.Salt), - ArgonTime: cred.ArgonTime, - ArgonMemory: cred.ArgonMemory, - ArgonThreads: cred.ArgonThreads, - }, nil -} - -// HandleClientFinal verifies client proof -func (s *ScramServer) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) { - s.mu.Lock() - defer s.mu.Unlock() - - state, exists := s.handshakes[msg.FullNonce] - if !exists { - return nil, fmt.Errorf("invalid nonce or expired handshake") - } - defer delete(s.handshakes, msg.FullNonce) - - // Check timeout - if time.Since(state.CreatedAt) > 60*time.Second { - return nil, fmt.Errorf("handshake timeout") - } - - // Decode client proof - clientProof, err := base64.StdEncoding.DecodeString(msg.ClientProof) - if err != nil { - return nil, fmt.Errorf("invalid proof encoding") - } - - // Build auth message - clientFirstBare := fmt.Sprintf("u=%s,n=%s", state.Username, state.ClientNonce) - serverFirst := &ServerFirst{ - FullNonce: state.FullNonce, - Salt: base64.StdEncoding.EncodeToString(state.Credential.Salt), - ArgonTime: state.Credential.ArgonTime, - ArgonMemory: state.Credential.ArgonMemory, - ArgonThreads: state.Credential.ArgonThreads, - } - clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce) - authMessage := clientFirstBare + "," + serverFirst.Marshal() + "," + clientFinalBare - - // Compute client signature - clientSignature := computeHMAC(state.Credential.StoredKey, []byte(authMessage)) - - // XOR to get ClientKey - clientKey := xorBytes(clientProof, clientSignature) - - // Verify by computing StoredKey - computedStoredKey := sha256.Sum256(clientKey) - if subtle.ConstantTimeCompare(computedStoredKey[:], state.Credential.StoredKey) != 1 { - return nil, fmt.Errorf("authentication failed") - } - - // Generate server signature for mutual auth - serverSignature := computeHMAC(state.Credential.ServerKey, []byte(authMessage)) - - return &ServerFinal{ - ServerSignature: base64.StdEncoding.EncodeToString(serverSignature), - SessionID: generateSessionID(), - }, nil -} - -func (s *ScramServer) cleanupHandshakes() { - cutoff := time.Now().Add(-60 * time.Second) - for nonce, state := range s.handshakes { - if state.CreatedAt.Before(cutoff) { - delete(s.handshakes, nonce) - } - } -} - -func generateNonce() string { - b := make([]byte, 32) - rand.Read(b) - return base64.StdEncoding.EncodeToString(b) -} \ No newline at end of file diff --git a/src/internal/config/config.go b/src/internal/config/config.go index 886b2c3..8fe6d5d 100644 --- a/src/internal/config/config.go +++ b/src/internal/config/config.go @@ -3,6 +3,7 @@ package config // --- LogWisp Configuration Options --- +// Config is the top-level configuration structure for the LogWisp application. type Config struct { // Top-level flags for application control Background bool `toml:"background"` @@ -26,7 +27,7 @@ type Config struct { // --- Logging Options --- -// Represents logging configuration for LogWisp +// LogConfig represents the logging configuration for the LogWisp application itself. type LogConfig struct { // Output mode: "file", "stdout", "stderr", "split", "all", "none" Output string `toml:"output"` @@ -41,6 +42,7 @@ type LogConfig struct { Console *LogConsoleConfig `toml:"console"` } +// LogFileConfig defines settings for file-based application logging. type LogFileConfig struct { // Directory for log files Directory string `toml:"directory"` @@ -58,6 +60,7 @@ type LogFileConfig struct { RetentionHours float64 `toml:"retention_hours"` } +// LogConsoleConfig defines settings for console-based application logging. type LogConsoleConfig struct { // Target for console output: "stdout", "stderr", "split" // "split": info/debug to stdout, warn/error to stderr @@ -69,19 +72,19 @@ type LogConsoleConfig struct { // --- Pipeline Options --- +// PipelineConfig defines a complete data flow from sources to sinks. type PipelineConfig struct { Name string `toml:"name"` Sources []SourceConfig `toml:"sources"` RateLimit *RateLimitConfig `toml:"rate_limit"` Filters []FilterConfig `toml:"filters"` Format *FormatConfig `toml:"format"` - - Sinks []SinkConfig `toml:"sinks"` - // Auth *ServerAuthConfig `toml:"auth"` // Global auth for pipeline + Sinks []SinkConfig `toml:"sinks"` } // Common configuration structs used across components +// NetLimitConfig defines network-level access control and rate limiting rules. type NetLimitConfig struct { Enabled bool `toml:"enabled"` MaxConnections int64 `toml:"max_connections"` @@ -95,27 +98,37 @@ type NetLimitConfig struct { IPBlacklist []string `toml:"ip_blacklist"` } -type TLSConfig struct { - Enabled bool `toml:"enabled"` - CertFile string `toml:"cert_file"` - KeyFile string `toml:"key_file"` - CAFile string `toml:"ca_file"` - ServerName string `toml:"server_name"` // for client verification - SkipVerify bool `toml:"skip_verify"` +// TLSServerConfig defines TLS settings for a server (HTTP Source, HTTP Sink). +type TLSServerConfig struct { + Enabled bool `toml:"enabled"` + CertFile string `toml:"cert_file"` // Server's certificate file. + KeyFile string `toml:"key_file"` // Server's private key file. + ClientAuth bool `toml:"client_auth"` // Enable/disable mTLS. + ClientCAFile string `toml:"client_ca_file"` // CA for verifying client certificates. + VerifyClientCert bool `toml:"verify_client_cert"` // Require and verify client certs. - // Client certificate authentication - ClientAuth bool `toml:"client_auth"` - ClientCAFile string `toml:"client_ca_file"` - VerifyClientCert bool `toml:"verify_client_cert"` - - // TLS version constraints - MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3" - MaxVersion string `toml:"max_version"` - - // Cipher suites (comma-separated list) + // Common TLS settings + MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3" + MaxVersion string `toml:"max_version"` CipherSuites string `toml:"cipher_suites"` } +// TLSClientConfig defines TLS settings for a client (HTTP Client Sink). +type TLSClientConfig struct { + Enabled bool `toml:"enabled"` + ServerCAFile string `toml:"server_ca_file"` // CA for verifying the remote server's certificate. + ClientCertFile string `toml:"client_cert_file"` // Client's certificate for mTLS. + ClientKeyFile string `toml:"client_key_file"` // Client's private key for mTLS. + ServerName string `toml:"server_name"` // For server certificate validation (SNI). + InsecureSkipVerify bool `toml:"insecure_skip_verify"` // Use with caution. + + // Common TLS settings + MinVersion string `toml:"min_version"` + MaxVersion string `toml:"max_version"` + CipherSuites string `toml:"cipher_suites"` +} + +// HeartbeatConfig defines settings for periodic keep-alive or status messages. type HeartbeatConfig struct { Enabled bool `toml:"enabled"` IntervalMS int64 `toml:"interval_ms"` @@ -124,15 +137,15 @@ type HeartbeatConfig struct { Format string `toml:"format"` } +// TODO: Future implementation +// ClientAuthConfig defines settings for client-side authentication. type ClientAuthConfig struct { - Type string `toml:"type"` // "none", "basic", "token", "scram" - Username string `toml:"username"` - Password string `toml:"password"` - Token string `toml:"token"` + Type string `toml:"type"` // "none" } // --- Source Options --- +// SourceConfig is a polymorphic struct representing a single data source. type SourceConfig struct { Type string `toml:"type"` @@ -143,6 +156,7 @@ type SourceConfig struct { TCP *TCPSourceOptions `toml:"tcp,omitempty"` } +// DirectorySourceOptions defines settings for a directory-based source. type DirectorySourceOptions struct { Path string `toml:"path"` Pattern string `toml:"pattern"` // glob pattern @@ -150,10 +164,12 @@ type DirectorySourceOptions struct { Recursive bool `toml:"recursive"` // TODO: implement logic } +// StdinSourceOptions defines settings for a stdin-based source. type StdinSourceOptions struct { BufferSize int64 `toml:"buffer_size"` } +// HTTPSourceOptions defines settings for an HTTP server source. type HTTPSourceOptions struct { Host string `toml:"host"` Port int64 `toml:"port"` @@ -163,10 +179,11 @@ type HTTPSourceOptions struct { ReadTimeout int64 `toml:"read_timeout_ms"` WriteTimeout int64 `toml:"write_timeout_ms"` NetLimit *NetLimitConfig `toml:"net_limit"` - TLS *TLSConfig `toml:"tls"` + TLS *TLSServerConfig `toml:"tls"` Auth *ServerAuthConfig `toml:"auth"` } +// TCPSourceOptions defines settings for a TCP server source. type TCPSourceOptions struct { Host string `toml:"host"` Port int64 `toml:"port"` @@ -180,6 +197,7 @@ type TCPSourceOptions struct { // --- Sink Options --- +// SinkConfig is a polymorphic struct representing a single data sink. type SinkConfig struct { Type string `toml:"type"` @@ -192,12 +210,14 @@ type SinkConfig struct { TCPClient *TCPClientSinkOptions `toml:"tcp_client,omitempty"` } +// ConsoleSinkOptions defines settings for a console-based sink. type ConsoleSinkOptions struct { Target string `toml:"target"` // "stdout", "stderr", "split" Colorize bool `toml:"colorize"` BufferSize int64 `toml:"buffer_size"` } +// FileSinkOptions defines settings for a file-based sink. type FileSinkOptions struct { Directory string `toml:"directory"` Name string `toml:"name"` @@ -209,6 +229,7 @@ type FileSinkOptions struct { FlushInterval int64 `toml:"flush_interval_ms"` } +// HTTPSinkOptions defines settings for an HTTP server sink. type HTTPSinkOptions struct { Host string `toml:"host"` Port int64 `toml:"port"` @@ -218,10 +239,11 @@ type HTTPSinkOptions struct { WriteTimeout int64 `toml:"write_timeout_ms"` Heartbeat *HeartbeatConfig `toml:"heartbeat"` NetLimit *NetLimitConfig `toml:"net_limit"` - TLS *TLSConfig `toml:"tls"` + TLS *TLSServerConfig `toml:"tls"` Auth *ServerAuthConfig `toml:"auth"` } +// TCPSinkOptions defines settings for a TCP server sink. type TCPSinkOptions struct { Host string `toml:"host"` Port int64 `toml:"port"` @@ -234,6 +256,7 @@ type TCPSinkOptions struct { Auth *ServerAuthConfig `toml:"auth"` } +// HTTPClientSinkOptions defines settings for an HTTP client sink. type HTTPClientSinkOptions struct { URL string `toml:"url"` BufferSize int64 `toml:"buffer_size"` @@ -244,10 +267,11 @@ type HTTPClientSinkOptions struct { RetryDelayMS int64 `toml:"retry_delay_ms"` RetryBackoff float64 `toml:"retry_backoff"` InsecureSkipVerify bool `toml:"insecure_skip_verify"` - TLS *TLSConfig `toml:"tls"` + TLS *TLSClientConfig `toml:"tls"` Auth *ClientAuthConfig `toml:"auth"` } +// TCPClientSinkOptions defines settings for a TCP client sink. type TCPClientSinkOptions struct { Host string `toml:"host"` Port int64 `toml:"port"` @@ -264,7 +288,7 @@ type TCPClientSinkOptions struct { // --- Rate Limit Options --- -// Defines the action to take when a rate limit is exceeded. +// RateLimitPolicy defines the action to take when a rate limit is exceeded. type RateLimitPolicy int const ( @@ -274,7 +298,7 @@ const ( PolicyDrop ) -// Defines the configuration for pipeline-level rate limiting. +// RateLimitConfig defines the configuration for pipeline-level rate limiting. type RateLimitConfig struct { // Rate is the number of log entries allowed per second. Default: 0 (disabled). Rate float64 `toml:"rate"` @@ -288,23 +312,27 @@ type RateLimitConfig struct { // --- Filter Options --- -// Represents the filter type +// FilterType represents the filter's behavior (include or exclude). type FilterType string const ( + // FilterTypeInclude specifies that only matching logs will pass. FilterTypeInclude FilterType = "include" // Whitelist - only matching logs pass + // FilterTypeExclude specifies that matching logs will be dropped. FilterTypeExclude FilterType = "exclude" // Blacklist - matching logs are dropped ) -// Represents how multiple patterns are combined +// FilterLogic represents how multiple filter patterns are combined. type FilterLogic string const ( - FilterLogicOr FilterLogic = "or" // Match any pattern + // FilterLogicOr specifies that a match on any pattern is sufficient. + FilterLogicOr FilterLogic = "or" // Match any pattern + // FilterLogicAnd specifies that all patterns must match. FilterLogicAnd FilterLogic = "and" // Match all patterns ) -// Represents filter configuration +// FilterConfig represents the configuration for a single filter. type FilterConfig struct { Type FilterType `toml:"type"` Logic FilterLogic `toml:"logic"` @@ -313,6 +341,7 @@ type FilterConfig struct { // --- Formatter Options --- +// FormatConfig is a polymorphic struct representing log entry formatting options. type FormatConfig struct { // Format configuration - polymorphic like sources/sinks Type string `toml:"type"` // "json", "txt", "raw" @@ -323,6 +352,7 @@ type FormatConfig struct { RawFormatOptions *RawFormatterOptions `toml:"raw,omitempty"` } +// JSONFormatterOptions defines settings for the JSON formatter. type JSONFormatterOptions struct { Pretty bool `toml:"pretty"` TimestampField string `toml:"timestamp_field"` @@ -331,49 +361,21 @@ type JSONFormatterOptions struct { SourceField string `toml:"source_field"` } +// TxtFormatterOptions defines settings for the text template formatter. type TxtFormatterOptions struct { Template string `toml:"template"` TimestampFormat string `toml:"timestamp_format"` } +// RawFormatterOptions defines settings for the raw pass-through formatter. type RawFormatterOptions struct { AddNewLine bool `toml:"add_new_line"` } // --- Server-side Auth (for sources) --- -type BasicAuthConfig struct { - Users []BasicAuthUser `toml:"users"` - Realm string `toml:"realm"` -} - -type BasicAuthUser struct { - Username string `toml:"username"` - PasswordHash string `toml:"password_hash"` // Argon2 -} - -type ScramAuthConfig struct { - Users []ScramUser `toml:"users"` -} - -type ScramUser struct { - Username string `toml:"username"` - StoredKey string `toml:"stored_key"` // base64 - ServerKey string `toml:"server_key"` // base64 - Salt string `toml:"salt"` // base64 - ArgonTime uint32 `toml:"argon_time"` - ArgonMemory uint32 `toml:"argon_memory"` - ArgonThreads uint8 `toml:"argon_threads"` -} - -type TokenAuthConfig struct { - Tokens []string `toml:"tokens"` -} - -// Server auth wrapper (for sources accepting connections) +// TODO: future implementation +// ServerAuthConfig defines settings for server-side authentication. type ServerAuthConfig struct { - Type string `toml:"type"` // "none", "basic", "token", "scram" - Basic *BasicAuthConfig `toml:"basic,omitempty"` - Token *TokenAuthConfig `toml:"token,omitempty"` - Scram *ScramAuthConfig `toml:"scram,omitempty"` + Type string `toml:"type"` // "none" } \ No newline at end of file diff --git a/src/internal/config/loader.go b/src/internal/config/loader.go index 4daed16..73d519e 100644 --- a/src/internal/config/loader.go +++ b/src/internal/config/loader.go @@ -11,13 +11,66 @@ import ( lconfig "github.com/lixenwraith/config" ) +// configManager holds the global instance of the configuration manager. var configManager *lconfig.Config -// Hot reload access +// Load is the single entry point for loading all application configuration. +func Load(args []string) (*Config, error) { + configPath, isExplicit := resolveConfigPath(args) + // Build configuration with all sources + + // Create target config instance that will be populated + finalConfig := &Config{} + + // Builder handles loading, populating the target struct, and validation + cfg, err := lconfig.NewBuilder(). + WithTarget(finalConfig). // Typed target struct + WithDefaults(defaults()). // Default values + WithSources( + lconfig.SourceCLI, + lconfig.SourceEnv, + lconfig.SourceFile, + lconfig.SourceDefault, + ). + WithEnvTransform(customEnvTransform). // Convert '.' to '_' in env separation + WithEnvPrefix("LOGWISP_"). // Environment variable prefix + WithArgs(args). // Command-line arguments + WithFile(configPath). // TOML config file + WithFileFormat("toml"). // Explicit format + WithTypedValidator(ValidateConfig). // Centralized validation + WithSecurityOptions(lconfig.SecurityOptions{ + PreventPathTraversal: true, + MaxFileSize: 10 * 1024 * 1024, // 10MB max config + }). + Build() + + if err != nil { + // Handle file not found errors - maintain existing behavior + if errors.Is(err, lconfig.ErrConfigNotFound) { + if isExplicit { + return nil, fmt.Errorf("config file not found: %s", configPath) + } + // If the default config file is not found, it's not an error, default/cli/env will be used + } else { + return nil, fmt.Errorf("failed to load or validate config: %w", err) + } + } + + // Store the config file path for hot reload + finalConfig.ConfigFile = configPath + + // Store the manager for hot reload + configManager = cfg + + return finalConfig, nil +} + +// GetConfigManager returns the global configuration manager instance for hot-reloading. func GetConfigManager() *lconfig.Config { return configManager } +// defaults provides the default configuration values for the application. func defaults() *Config { return &Config{ // Top-level flag defaults @@ -76,58 +129,7 @@ func defaults() *Config { } } -// Single entry point for loading all configuration -func Load(args []string) (*Config, error) { - configPath, isExplicit := resolveConfigPath(args) - // Build configuration with all sources - - // Create target config instance that will be populated - finalConfig := &Config{} - - // Builder handles loading, populating the target struct, and validation - cfg, err := lconfig.NewBuilder(). - WithTarget(finalConfig). // Typed target struct - WithDefaults(defaults()). // Default values - WithSources( - lconfig.SourceCLI, - lconfig.SourceEnv, - lconfig.SourceFile, - lconfig.SourceDefault, - ). - WithEnvTransform(customEnvTransform). // Convert '.' to '_' in env separation - WithEnvPrefix("LOGWISP_"). // Environment variable prefix - WithArgs(args). // Command-line arguments - WithFile(configPath). // TOML config file - WithFileFormat("toml"). // Explicit format - WithTypedValidator(ValidateConfig). // Centralized validation - WithSecurityOptions(lconfig.SecurityOptions{ - PreventPathTraversal: true, - MaxFileSize: 10 * 1024 * 1024, // 10MB max config - }). - Build() - - if err != nil { - // Handle file not found errors - maintain existing behavior - if errors.Is(err, lconfig.ErrConfigNotFound) { - if isExplicit { - return nil, fmt.Errorf("config file not found: %s", configPath) - } - // If the default config file is not found, it's not an error, default/cli/env will be used - } else { - return nil, fmt.Errorf("failed to load or validate config: %w", err) - } - } - - // Store the config file path for hot reload - finalConfig.ConfigFile = configPath - - // Store the manager for hot reload - configManager = cfg - - return finalConfig, nil -} - -// Returns the configuration file path +// resolveConfigPath determines the configuration file path based on CLI args, env vars, and default locations. func resolveConfigPath(args []string) (path string, isExplicit bool) { // 1. Check for --config flag in command-line arguments (highest precedence) for i, arg := range args { @@ -163,6 +165,7 @@ func resolveConfigPath(args []string) (path string, isExplicit bool) { return "logwisp.toml", false } +// customEnvTransform converts TOML-style config paths (e.g., logging.level) to environment variable format (LOGGING_LEVEL). func customEnvTransform(path string) string { env := strings.ReplaceAll(path, ".", "_") env = strings.ToUpper(env) diff --git a/src/internal/config/validation.go b/src/internal/config/validation.go index af29348..e83dcc1 100644 --- a/src/internal/config/validation.go +++ b/src/internal/config/validation.go @@ -11,8 +11,7 @@ import ( lconfig "github.com/lixenwraith/config" ) -// validateConfig is the centralized validator for the entire configuration -// This replaces the old (c *Config) validate() method +// ValidateConfig is the centralized validator for the entire configuration structure. func ValidateConfig(cfg *Config) error { if cfg == nil { return fmt.Errorf("config is nil") @@ -39,6 +38,7 @@ func ValidateConfig(cfg *Config) error { return nil } +// validateLogConfig validates the application's own logging settings. func validateLogConfig(cfg *LogConfig) error { validOutputs := map[string]bool{ "file": true, "stdout": true, "stderr": true, @@ -74,6 +74,7 @@ func validateLogConfig(cfg *LogConfig) error { return nil } +// validatePipeline validates a single pipeline's configuration. func validatePipeline(index int, p *PipelineConfig, pipelineNames map[string]bool, allPorts map[int64]string) error { // Validate pipeline name if err := lconfig.NonEmpty(p.Name); err != nil { @@ -131,7 +132,7 @@ func validatePipeline(index int, p *PipelineConfig, pipelineNames map[string]boo return nil } -// validateSourceConfig validates typed source configuration +// validateSourceConfig validates a polymorphic source configuration. func validateSourceConfig(pipelineName string, index int, s *SourceConfig) error { if err := lconfig.NonEmpty(s.Type); err != nil { return fmt.Errorf("pipeline '%s' source[%d]: missing type", pipelineName, index) @@ -186,177 +187,7 @@ func validateSourceConfig(pipelineName string, index int, s *SourceConfig) error } } -func validateDirectorySource(pipelineName string, index int, opts *DirectorySourceOptions) error { - if err := lconfig.NonEmpty(opts.Path); err != nil { - return fmt.Errorf("pipeline '%s' source[%d]: directory requires 'path'", pipelineName, index) - } else { - absPath, err := filepath.Abs(opts.Path) - if err != nil { - return fmt.Errorf("invalid path %s: %w", opts.Path, err) - } - opts.Path = absPath - } - - // Check for directory traversal - // TODO: traversal check only if optional security settings from cli/env set - if strings.Contains(opts.Path, "..") { - return fmt.Errorf("pipeline '%s' source[%d]: path contains directory traversal", pipelineName, index) - } - - // Validate pattern if provided - if opts.Pattern != "" { - if strings.Count(opts.Pattern, "*") == 0 && strings.Count(opts.Pattern, "?") == 0 { - // If no wildcards, ensure valid filename - if filepath.Base(opts.Pattern) != opts.Pattern { - return fmt.Errorf("pipeline '%s' source[%d]: pattern contains path separators", pipelineName, index) - } - } - } else { - opts.Pattern = "*" - } - - // Validate check interval - if opts.CheckIntervalMS < 10 { - return fmt.Errorf("pipeline '%s' source[%d]: check_interval_ms must be at least 10ms", pipelineName, index) - } - - return nil -} - -func validateStdinSource(pipelineName string, index int, opts *StdinSourceOptions) error { - if opts.BufferSize < 0 { - return fmt.Errorf("pipeline '%s' source[%d]: buffer_size must be positive", pipelineName, index) - } else if opts.BufferSize == 0 { - opts.BufferSize = 1000 - } - return nil -} - -func validateHTTPSource(pipelineName string, index int, opts *HTTPSourceOptions) error { - // Validate port - if err := lconfig.Port(opts.Port); err != nil { - return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) - } - - // Set defaults - if opts.Host == "" { - opts.Host = "0.0.0.0" - } - if opts.IngestPath == "" { - opts.IngestPath = "/ingest" - } - if opts.MaxRequestBodySize <= 0 { - opts.MaxRequestBodySize = 10 * 1024 * 1024 // 10MB default - } - if opts.ReadTimeout <= 0 { - opts.ReadTimeout = 5000 // 5 seconds - } - if opts.WriteTimeout <= 0 { - opts.WriteTimeout = 5000 // 5 seconds - } - - // Validate host if specified - if opts.Host != "" && opts.Host != "0.0.0.0" { - if err := lconfig.IPAddress(opts.Host); err != nil { - return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) - } - } - - // Validate paths - if !strings.HasPrefix(opts.IngestPath, "/") { - return fmt.Errorf("pipeline '%s' source[%d]: ingest_path must start with /", pipelineName, index) - } - - // Validate auth configuration - validHTTPSourceAuthTypes := map[string]bool{"basic": true, "token": true, "mtls": true} - if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" { - if !validHTTPSourceAuthTypes[opts.Auth.Type] { - return fmt.Errorf("pipeline '%s' source[%d]: %s is not a valid auth type", - pipelineName, index, opts.Auth.Type) - } - // All non-none auth types require TLS for HTTP - if opts.TLS == nil || !opts.TLS.Enabled { - return fmt.Errorf("pipeline '%s' source[%d]: %s auth requires TLS to be enabled", - pipelineName, index, opts.Auth.Type) - } - - // Validate specific auth types - if err := validateServerAuth(pipelineName, opts.Auth); err != nil { - return fmt.Errorf("source[%d]: %w", index, err) - } - } - - // Validate nested configs - if opts.NetLimit != nil { - if err := validateNetLimit(pipelineName, fmt.Sprintf("source[%d]", index), opts.NetLimit); err != nil { - return err - } - } - - if opts.TLS != nil { - if err := validateTLS(pipelineName, fmt.Sprintf("source[%d]", index), opts.TLS); err != nil { - return err - } - } - - return nil -} - -func validateTCPSource(pipelineName string, index int, opts *TCPSourceOptions) error { - // Validate port - if err := lconfig.Port(opts.Port); err != nil { - return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) - } - - // Set defaults - if opts.Host == "" { - opts.Host = "0.0.0.0" - } - if opts.ReadTimeout <= 0 { - opts.ReadTimeout = 5000 // 5 seconds - } - if !opts.KeepAlive { - opts.KeepAlive = true // Default enabled - } - if opts.KeepAlivePeriod <= 0 { - opts.KeepAlivePeriod = 30000 // 30 seconds - } - - // Validate host if specified - if opts.Host != "" && opts.Host != "0.0.0.0" { - if err := lconfig.IPAddress(opts.Host); err != nil { - return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) - } - } - - // TCP source does NOT support TLS - // Validate auth configuration - only none and scram are allowed - if opts.Auth != nil { - switch opts.Auth.Type { - case "", "none": - // OK - case "scram": - // SCRAM doesn't require TLS - if err := validateServerAuth(pipelineName, opts.Auth); err != nil { - return fmt.Errorf("source[%d]: %w", index, err) - } - default: - return fmt.Errorf("pipeline '%s' source[%d]: TCP source only supports 'none' or 'scram' auth (got '%s')", - pipelineName, index, opts.Auth.Type) - } - } - - // Validate NetLimit if present - if opts.NetLimit != nil { - if err := validateNetLimit(pipelineName, fmt.Sprintf("source[%d]", index), opts.NetLimit); err != nil { - return err - } - } - - return nil -} - -// validateSinkConfig validates typed sink configuration +// validateSinkConfig validates a polymorphic sink configuration. func validateSinkConfig(pipelineName string, index int, s *SinkConfig, allPorts map[int64]string) error { if err := lconfig.NonEmpty(s.Type); err != nil { return fmt.Errorf("pipeline '%s' sink[%d]: missing type", pipelineName, index) @@ -423,6 +254,268 @@ func validateSinkConfig(pipelineName string, index int, s *SinkConfig, allPorts } } +// validateFormatterConfig validates formatter configuration +func validateFormatterConfig(p *PipelineConfig) error { + if p.Format == nil { + p.Format = &FormatConfig{ + Type: "raw", + } + } else if p.Format.Type == "" { + p.Format.Type = "raw" // Default + } + + switch p.Format.Type { + + case "raw": + if p.Format.RawFormatOptions == nil { + p.Format.RawFormatOptions = &RawFormatterOptions{} + } + + case "txt": + if p.Format.TxtFormatOptions == nil { + p.Format.TxtFormatOptions = &TxtFormatterOptions{} + } + + // Default template format + templateStr := "[{{.Timestamp | FmtTime}}] [{{.Level | ToUpper}}] {{.Source}} - {{.Message}}{{ if .Fields }} {{.Fields}}{{ end }}" + if p.Format.TxtFormatOptions.Template != "" { + p.Format.TxtFormatOptions.Template = templateStr + } + + // Default timestamp format + timestampFormat := time.RFC3339 + if p.Format.TxtFormatOptions.TimestampFormat != "" { + p.Format.TxtFormatOptions.TimestampFormat = timestampFormat + } + + case "json": + if p.Format.JSONFormatOptions == nil { + p.Format.JSONFormatOptions = &JSONFormatterOptions{} + } + } + + return nil +} + +// validateRateLimit validates the pipeline-level rate limit settings. +func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error { + if cfg == nil { + return nil + } + + if cfg.Rate < 0 { + return fmt.Errorf("pipeline '%s': rate limit rate cannot be negative", pipelineName) + } + + if cfg.Burst < 0 { + return fmt.Errorf("pipeline '%s': rate limit burst cannot be negative", pipelineName) + } + + if cfg.MaxEntrySizeBytes < 0 { + return fmt.Errorf("pipeline '%s': max entry size bytes cannot be negative", pipelineName) + } + + // Validate policy + switch strings.ToLower(cfg.Policy) { + case "", "pass", "drop": + // Valid policies + default: + return fmt.Errorf("pipeline '%s': invalid rate limit policy '%s' (must be 'pass' or 'drop')", + pipelineName, cfg.Policy) + } + + return nil +} + +// validateFilter validates a single filter's configuration. +func validateFilter(pipelineName string, filterIndex int, cfg *FilterConfig) error { + // Validate filter type + switch cfg.Type { + case FilterTypeInclude, FilterTypeExclude, "": + // Valid types + default: + return fmt.Errorf("pipeline '%s' filter[%d]: invalid type '%s' (must be 'include' or 'exclude')", + pipelineName, filterIndex, cfg.Type) + } + + // Validate filter logic + switch cfg.Logic { + case FilterLogicOr, FilterLogicAnd, "": + // Valid logic + default: + return fmt.Errorf("pipeline '%s' filter[%d]: invalid logic '%s' (must be 'or' or 'and')", + pipelineName, filterIndex, cfg.Logic) + } + + // Empty patterns is valid - passes everything + if len(cfg.Patterns) == 0 { + return nil + } + + // Validate regex patterns + for i, pattern := range cfg.Patterns { + if _, err := regexp.Compile(pattern); err != nil { + return fmt.Errorf("pipeline '%s' filter[%d] pattern[%d] '%s': invalid regex: %w", + pipelineName, filterIndex, i, pattern, err) + } + } + + return nil +} + +// validateDirectorySource validates the settings for a directory source. +func validateDirectorySource(pipelineName string, index int, opts *DirectorySourceOptions) error { + if err := lconfig.NonEmpty(opts.Path); err != nil { + return fmt.Errorf("pipeline '%s' source[%d]: directory requires 'path'", pipelineName, index) + } else { + absPath, err := filepath.Abs(opts.Path) + if err != nil { + return fmt.Errorf("invalid path %s: %w", opts.Path, err) + } + opts.Path = absPath + } + + // Check for directory traversal + // TODO: traversal check only if optional security settings from cli/env set + if strings.Contains(opts.Path, "..") { + return fmt.Errorf("pipeline '%s' source[%d]: path contains directory traversal", pipelineName, index) + } + + // Validate pattern if provided + if opts.Pattern != "" { + if strings.Count(opts.Pattern, "*") == 0 && strings.Count(opts.Pattern, "?") == 0 { + // If no wildcards, ensure valid filename + if filepath.Base(opts.Pattern) != opts.Pattern { + return fmt.Errorf("pipeline '%s' source[%d]: pattern contains path separators", pipelineName, index) + } + } + } else { + opts.Pattern = "*" + } + + // Validate check interval + if opts.CheckIntervalMS < 10 { + return fmt.Errorf("pipeline '%s' source[%d]: check_interval_ms must be at least 10ms", pipelineName, index) + } + + return nil +} + +// validateStdinSource validates the settings for a stdin source. +func validateStdinSource(pipelineName string, index int, opts *StdinSourceOptions) error { + if opts.BufferSize < 0 { + return fmt.Errorf("pipeline '%s' source[%d]: buffer_size must be positive", pipelineName, index) + } else if opts.BufferSize == 0 { + opts.BufferSize = 1000 + } + return nil +} + +// validateHTTPSource validates the settings for an HTTP source. +func validateHTTPSource(pipelineName string, index int, opts *HTTPSourceOptions) error { + // Validate port + if err := lconfig.Port(opts.Port); err != nil { + return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) + } + + // Set defaults + if opts.Host == "" { + opts.Host = "0.0.0.0" + } + if opts.IngestPath == "" { + opts.IngestPath = "/ingest" + } + if opts.MaxRequestBodySize <= 0 { + opts.MaxRequestBodySize = 10 * 1024 * 1024 // 10MB default + } + if opts.ReadTimeout <= 0 { + opts.ReadTimeout = 5000 // 5 seconds + } + if opts.WriteTimeout <= 0 { + opts.WriteTimeout = 5000 // 5 seconds + } + + // Validate host if specified + if opts.Host != "" && opts.Host != "0.0.0.0" { + if err := lconfig.IPAddress(opts.Host); err != nil { + return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) + } + } + + // Validate paths + if !strings.HasPrefix(opts.IngestPath, "/") { + return fmt.Errorf("pipeline '%s' source[%d]: ingest_path must start with /", pipelineName, index) + } + + // Validate auth configuration + validHTTPSourceAuthTypes := map[string]bool{"basic": true, "token": true, "mtls": true} + if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" { + if !validHTTPSourceAuthTypes[opts.Auth.Type] { + return fmt.Errorf("pipeline '%s' source[%d]: %s is not a valid auth type", + pipelineName, index, opts.Auth.Type) + } + // All non-none auth types require TLS for HTTP + if opts.TLS == nil || !opts.TLS.Enabled { + return fmt.Errorf("pipeline '%s' source[%d]: %s auth requires TLS to be enabled", + pipelineName, index, opts.Auth.Type) + } + } + + // Validate nested configs + if opts.NetLimit != nil { + if err := validateNetLimit(pipelineName, fmt.Sprintf("source[%d]", index), opts.NetLimit); err != nil { + return err + } + } + + if opts.TLS != nil { + if err := validateTLSServer(pipelineName, fmt.Sprintf("source[%d]", index), opts.TLS); err != nil { + return err + } + } + + return nil +} + +// validateTCPSource validates the settings for a TCP source. +func validateTCPSource(pipelineName string, index int, opts *TCPSourceOptions) error { + // Validate port + if err := lconfig.Port(opts.Port); err != nil { + return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) + } + + // Set defaults + if opts.Host == "" { + opts.Host = "0.0.0.0" + } + if opts.ReadTimeout <= 0 { + opts.ReadTimeout = 5000 // 5 seconds + } + if !opts.KeepAlive { + opts.KeepAlive = true // Default enabled + } + if opts.KeepAlivePeriod <= 0 { + opts.KeepAlivePeriod = 30000 // 30 seconds + } + + // Validate host if specified + if opts.Host != "" && opts.Host != "0.0.0.0" { + if err := lconfig.IPAddress(opts.Host); err != nil { + return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err) + } + } + + // Validate NetLimit if present + if opts.NetLimit != nil { + if err := validateNetLimit(pipelineName, fmt.Sprintf("source[%d]", index), opts.NetLimit); err != nil { + return err + } + } + + return nil +} + +// validateConsoleSink validates the settings for a console sink. func validateConsoleSink(pipelineName string, index int, opts *ConsoleSinkOptions) error { if opts.BufferSize < 1 { return fmt.Errorf("pipeline '%s' sink[%d]: buffer_size must be positive", pipelineName, index) @@ -430,6 +523,7 @@ func validateConsoleSink(pipelineName string, index int, opts *ConsoleSinkOption return nil } +// validateFileSink validates the settings for a file sink. func validateFileSink(pipelineName string, index int, opts *FileSinkOptions) error { if err := lconfig.NonEmpty(opts.Directory); err != nil { return fmt.Errorf("pipeline '%s' sink[%d]: file requires 'directory'", pipelineName, index) @@ -463,6 +557,7 @@ func validateFileSink(pipelineName string, index int, opts *FileSinkOptions) err return nil } +// validateHTTPSink validates the settings for an HTTP sink. func validateHTTPSink(pipelineName string, index int, opts *HTTPSinkOptions, allPorts map[int64]string) error { // Validate port if err := lconfig.Port(opts.Port); err != nil { @@ -511,7 +606,7 @@ func validateHTTPSink(pipelineName string, index int, opts *HTTPSinkOptions, all } if opts.TLS != nil { - if err := validateTLS(pipelineName, fmt.Sprintf("sink[%d]", index), opts.TLS); err != nil { + if err := validateTLSServer(pipelineName, fmt.Sprintf("sink[%d]", index), opts.TLS); err != nil { return err } } @@ -519,6 +614,7 @@ func validateHTTPSink(pipelineName string, index int, opts *HTTPSinkOptions, all return nil } +// validateTCPSink validates the settings for a TCP sink. func validateTCPSink(pipelineName string, index int, opts *TCPSinkOptions, allPorts map[int64]string) error { // Validate port if err := lconfig.Port(opts.Port); err != nil { @@ -560,6 +656,7 @@ func validateTCPSink(pipelineName string, index int, opts *TCPSinkOptions, allPo return nil } +// validateHTTPClientSink validates the settings for an HTTP client sink. func validateHTTPClientSink(pipelineName string, index int, opts *HTTPClientSinkOptions) error { // Validate URL if err := lconfig.NonEmpty(opts.URL); err != nil { @@ -575,8 +672,6 @@ func validateHTTPClientSink(pipelineName string, index int, opts *HTTPClientSink return fmt.Errorf("pipeline '%s' sink[%d]: URL must use http or https scheme", pipelineName, index) } - isHTTPS := parsedURL.Scheme == "https" - // Set defaults for unspecified fields if opts.BufferSize <= 0 { opts.BufferSize = 1000 @@ -600,57 +695,9 @@ func validateHTTPClientSink(pipelineName string, index int, opts *HTTPClientSink opts.RetryBackoff = 2.0 } - // Validate auth configuration - if opts.Auth != nil { - switch opts.Auth.Type { - case "basic": - if opts.Auth.Username == "" || opts.Auth.Password == "" { - return fmt.Errorf("pipeline '%s' sink[%d]: username and password required for basic auth", - pipelineName, index) - } - if !isHTTPS && !opts.InsecureSkipVerify { - return fmt.Errorf("pipeline '%s' sink[%d]: basic auth requires HTTPS (security: credentials would be sent in plaintext)", - pipelineName, index) - } - - case "token": - if opts.Auth.Token == "" { - return fmt.Errorf("pipeline '%s' sink[%d]: token required for %s auth", - pipelineName, index, opts.Auth.Type) - } - if !isHTTPS && !opts.InsecureSkipVerify { - return fmt.Errorf("pipeline '%s' sink[%d]: %s auth requires HTTPS (security: token would be sent in plaintext)", - pipelineName, index, opts.Auth.Type) - } - - case "mtls": - if !isHTTPS { - return fmt.Errorf("pipeline '%s' sink[%d]: mTLS requires HTTPS", - pipelineName, index) - } - // mTLS certs should be in TLS config, not auth config - if opts.TLS == nil || opts.TLS.CertFile == "" || opts.TLS.KeyFile == "" { - return fmt.Errorf("pipeline '%s' sink[%d]: cert_file and key_file required in TLS config for mTLS auth", - pipelineName, index) - } - - case "none", "": - // Clear any credentials if auth is "none" or empty - if opts.Auth != nil { - opts.Auth.Username = "" - opts.Auth.Password = "" - opts.Auth.Token = "" - } - - default: - return fmt.Errorf("pipeline '%s' sink[%d]: invalid auth type '%s' (valid: none, basic, token, mtls)", - pipelineName, index, opts.Auth.Type) - } - } - // Validate TLS config if present if opts.TLS != nil { - if err := validateTLS(pipelineName, fmt.Sprintf("sink[%d]", index), opts.TLS); err != nil { + if err := validateTLSClient(pipelineName, fmt.Sprintf("sink[%d]", index), opts.TLS); err != nil { return err } } @@ -658,6 +705,7 @@ func validateHTTPClientSink(pipelineName string, index int, opts *HTTPClientSink return nil } +// validateTCPClientSink validates the settings for a TCP client sink. func validateTCPClientSink(pipelineName string, index int, opts *TCPClientSinkOptions) error { // Validate host and port if err := lconfig.NonEmpty(opts.Host); err != nil { @@ -694,78 +742,10 @@ func validateTCPClientSink(pipelineName string, index int, opts *TCPClientSinkOp opts.ReconnectBackoff = 1.5 } - // Validate auth configuration - if opts.Auth != nil { - switch opts.Auth.Type { - - case "scram": - if opts.Auth.Username == "" || opts.Auth.Password == "" { - return fmt.Errorf("pipeline '%s' sink[%d]: username and password required for SCRAM auth", - pipelineName, index) - } - // SCRAM doesn't require TLS as it uses challenge-response - - case "none", "": - // Clear credentials - if opts.Auth != nil { - opts.Auth.Username = "" - opts.Auth.Password = "" - opts.Auth.Token = "" - } - - default: - return fmt.Errorf("pipeline '%s' sink[%d]: invalid auth type '%s' (valid: none, basic, token, scram, mtls)", - pipelineName, index, opts.Auth.Type) - } - } - return nil } -// validateFormatterConfig validates formatter configuration -func validateFormatterConfig(p *PipelineConfig) error { - if p.Format == nil { - p.Format = &FormatConfig{ - Type: "raw", - } - } else if p.Format.Type == "" { - p.Format.Type = "raw" // Default - } - - switch p.Format.Type { - - case "raw": - if p.Format.RawFormatOptions == nil { - p.Format.RawFormatOptions = &RawFormatterOptions{} - } - - case "txt": - if p.Format.TxtFormatOptions == nil { - p.Format.TxtFormatOptions = &TxtFormatterOptions{} - } - - // Default template format - templateStr := "[{{.Timestamp | FmtTime}}] [{{.Level | ToUpper}}] {{.Source}} - {{.Message}}{{ if .Fields }} {{.Fields}}{{ end }}" - if p.Format.TxtFormatOptions.Template != "" { - p.Format.TxtFormatOptions.Template = templateStr - } - - // Default timestamp format - timestampFormat := time.RFC3339 - if p.Format.TxtFormatOptions.TimestampFormat != "" { - p.Format.TxtFormatOptions.TimestampFormat = timestampFormat - } - - case "json": - if p.Format.JSONFormatOptions == nil { - p.Format.JSONFormatOptions = &JSONFormatterOptions{} - } - } - - return nil -} - -// Helper validation functions for nested configs +// validateNetLimit validates nested NetLimitConfig settings. func validateNetLimit(pipelineName, location string, nl *NetLimitConfig) error { if !nl.Enabled { return nil // Skip validation if disabled @@ -782,21 +762,45 @@ func validateNetLimit(pipelineName, location string, nl *NetLimitConfig) error { return nil } -func validateTLS(pipelineName, location string, tls *TLSConfig) error { +// validateTLSServer validates the new TLSServerConfig struct. +func validateTLSServer(pipelineName, location string, tls *TLSServerConfig) error { if !tls.Enabled { return nil // Skip validation if disabled } - // If TLS enabled, cert and key files required (unless skip verify) - if !tls.SkipVerify { - if tls.CertFile == "" || tls.KeyFile == "" { - return fmt.Errorf("pipeline '%s' %s: TLS enabled requires cert_file and key_file", pipelineName, location) - } + // If TLS is enabled for a server, cert and key files are mandatory. + if tls.CertFile == "" || tls.KeyFile == "" { + return fmt.Errorf("pipeline '%s' %s: TLS enabled requires both cert_file and key_file", pipelineName, location) + } + + // If mTLS (ClientAuth) is enabled, a client CA file is mandatory. + if tls.ClientAuth && tls.ClientCAFile == "" { + return fmt.Errorf("pipeline '%s' %s: client_auth is enabled, which requires a client_ca_file", pipelineName, location) } return nil } +// validateTLSClient validates the new TLSClientConfig struct. +func validateTLSClient(pipelineName, location string, tls *TLSClientConfig) error { + if !tls.Enabled { + return nil // Skip validation if disabled + } + + // If verification is not skipped, a server CA file must be provided. + if !tls.InsecureSkipVerify && tls.ServerCAFile == "" { + return fmt.Errorf("pipeline '%s' %s: TLS verification is enabled (insecure_skip_verify=false) but server_ca_file is not provided", pipelineName, location) + } + + // For client mTLS, both the cert and key must be provided together. + if (tls.ClientCertFile != "" && tls.ClientKeyFile == "") || (tls.ClientCertFile == "" && tls.ClientKeyFile != "") { + return fmt.Errorf("pipeline '%s' %s: for client mTLS, both client_cert_file and client_key_file must be provided", pipelineName, location) + } + + return nil +} + +// validateHeartbeat validates nested HeartbeatConfig settings. func validateHeartbeat(pipelineName, location string, hb *HeartbeatConfig) error { if !hb.Enabled { return nil // Skip validation if disabled @@ -806,140 +810,5 @@ func validateHeartbeat(pipelineName, location string, hb *HeartbeatConfig) error return fmt.Errorf("pipeline '%s' %s: heartbeat interval must be at least 1000ms", pipelineName, location) } - return nil -} - -func validateServerAuth(pipelineName string, auth *ServerAuthConfig) error { - if auth.Type == "" || auth.Type == "none" { - return nil - } - - // Count populated auth configs - populated := 0 - var populatedType string - - if auth.Basic != nil { - populated++ - populatedType = "basic" - } - if auth.Token != nil { - populated++ - populatedType = "token" - } - if auth.Scram != nil { - populated++ - populatedType = "scram" - } - - if populated == 0 { - return fmt.Errorf("pipeline '%s': auth type '%s' specified but config missing", pipelineName, auth.Type) - } - if populated > 1 { - return fmt.Errorf("pipeline '%s': multiple auth configurations provided", pipelineName) - } - if populatedType != auth.Type { - return fmt.Errorf("pipeline '%s': auth type mismatch - type is '%s' but config is for '%s'", - pipelineName, auth.Type, populatedType) - } - - // Validate specific auth type - switch auth.Type { - case "basic": - if len(auth.Basic.Users) == 0 { - return fmt.Errorf("pipeline '%s': basic auth requires at least one user", pipelineName) - } - for i, user := range auth.Basic.Users { - if err := lconfig.NonEmpty(user.Username); err != nil { - return fmt.Errorf("pipeline '%s': basic auth user[%d] missing username", pipelineName, i) - } - if err := lconfig.NonEmpty(user.PasswordHash); err != nil { - return fmt.Errorf("pipeline '%s': basic auth user[%d] missing password_hash", pipelineName, i) - } - } - case "token": - if len(auth.Token.Tokens) == 0 { - return fmt.Errorf("pipeline '%s': token auth requires at least one token", pipelineName) - } - case "scram": - if len(auth.Scram.Users) == 0 { - return fmt.Errorf("pipeline '%s': scram auth requires at least one user", pipelineName) - } - for i, user := range auth.Scram.Users { - if err := lconfig.NonEmpty(user.Username); err != nil { - return fmt.Errorf("pipeline '%s': scram auth user[%d] missing username", pipelineName, i) - } - // Validate required SCRAM fields - if user.StoredKey == "" || user.ServerKey == "" || user.Salt == "" { - return fmt.Errorf("pipeline '%s': scram auth user[%d] missing required fields", pipelineName, i) - } - } - default: - return fmt.Errorf("pipeline '%s': unknown auth type '%s'", pipelineName, auth.Type) - } - - return nil -} - -func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error { - if cfg == nil { - return nil - } - - if cfg.Rate < 0 { - return fmt.Errorf("pipeline '%s': rate limit rate cannot be negative", pipelineName) - } - - if cfg.Burst < 0 { - return fmt.Errorf("pipeline '%s': rate limit burst cannot be negative", pipelineName) - } - - if cfg.MaxEntrySizeBytes < 0 { - return fmt.Errorf("pipeline '%s': max entry size bytes cannot be negative", pipelineName) - } - - // Validate policy - switch strings.ToLower(cfg.Policy) { - case "", "pass", "drop": - // Valid policies - default: - return fmt.Errorf("pipeline '%s': invalid rate limit policy '%s' (must be 'pass' or 'drop')", - pipelineName, cfg.Policy) - } - - return nil -} - -func validateFilter(pipelineName string, filterIndex int, cfg *FilterConfig) error { - // Validate filter type - switch cfg.Type { - case FilterTypeInclude, FilterTypeExclude, "": - // Valid types - default: - return fmt.Errorf("pipeline '%s' filter[%d]: invalid type '%s' (must be 'include' or 'exclude')", - pipelineName, filterIndex, cfg.Type) - } - - // Validate filter logic - switch cfg.Logic { - case FilterLogicOr, FilterLogicAnd, "": - // Valid logic - default: - return fmt.Errorf("pipeline '%s' filter[%d]: invalid logic '%s' (must be 'or' or 'and')", - pipelineName, filterIndex, cfg.Logic) - } - - // Empty patterns is valid - passes everything - if len(cfg.Patterns) == 0 { - return nil - } - - // Validate regex patterns - for i, pattern := range cfg.Patterns { - if _, err := regexp.Compile(pattern); err != nil { - return fmt.Errorf("pipeline '%s' filter[%d] pattern[%d] '%s': invalid regex: %w", - pipelineName, filterIndex, i, pattern, err) - } - } - return nil } \ No newline at end of file diff --git a/src/internal/core/const.go b/src/internal/core/const.go deleted file mode 100644 index e8b4a64..0000000 --- a/src/internal/core/const.go +++ /dev/null @@ -1,13 +0,0 @@ -// FILE: logwisp/src/internal/core/const.go -package core - -// Argon2id parameters -const ( - Argon2Time = 3 - Argon2Memory = 64 * 1024 // 64 MB - Argon2Threads = 4 - Argon2SaltLen = 16 - Argon2KeyLen = 32 -) - -const DefaultTokenLength = 32 \ No newline at end of file diff --git a/src/internal/core/types.go b/src/internal/core/data.go similarity index 82% rename from src/internal/core/types.go rename to src/internal/core/data.go index ac79878..d12c289 100644 --- a/src/internal/core/types.go +++ b/src/internal/core/data.go @@ -1,4 +1,4 @@ -// FILE: logwisp/src/internal/core/types.go +// FILE: logwisp/src/internal/core/data.go package core import ( @@ -6,6 +6,8 @@ import ( "time" ) +const MaxSessionTime = time.Minute * 30 + // Represents a single log record flowing through the pipeline type LogEntry struct { Time time.Time `json:"time"` diff --git a/src/internal/filter/chain.go b/src/internal/filter/chain.go index 41e171a..d646849 100644 --- a/src/internal/filter/chain.go +++ b/src/internal/filter/chain.go @@ -11,7 +11,7 @@ import ( "github.com/lixenwraith/log" ) -// Manages multiple filters in sequence +// Chain manages a sequence of filters, applying them in order. type Chain struct { filters []*Filter logger *log.Logger @@ -21,7 +21,7 @@ type Chain struct { totalPassed atomic.Uint64 } -// Creates a new filter chain from configurations +// NewChain creates a new filter chain from a slice of filter configurations. func NewChain(configs []config.FilterConfig, logger *log.Logger) (*Chain, error) { chain := &Chain{ filters: make([]*Filter, 0, len(configs)), @@ -42,7 +42,7 @@ func NewChain(configs []config.FilterConfig, logger *log.Logger) (*Chain, error) return chain, nil } -// Runs all filters in sequence, returns true if the entry passes all filters +// Apply runs a log entry through all filters in the chain. func (c *Chain) Apply(entry core.LogEntry) bool { c.totalProcessed.Add(1) @@ -67,7 +67,7 @@ func (c *Chain) Apply(entry core.LogEntry) bool { return true } -// Returns chain statistics +// GetStats returns aggregated statistics for the entire chain. func (c *Chain) GetStats() map[string]any { filterStats := make([]map[string]any, len(c.filters)) for i, filter := range c.filters { diff --git a/src/internal/filter/filter.go b/src/internal/filter/filter.go index 6f1045c..894651b 100644 --- a/src/internal/filter/filter.go +++ b/src/internal/filter/filter.go @@ -13,7 +13,7 @@ import ( "github.com/lixenwraith/log" ) -// Applies regex-based filtering to log entries +// Filter applies regex-based filtering to log entries. type Filter struct { config config.FilterConfig patterns []*regexp.Regexp @@ -26,7 +26,7 @@ type Filter struct { totalDropped atomic.Uint64 } -// Creates a new filter from configuration +// NewFilter creates a new filter from a configuration. func NewFilter(cfg config.FilterConfig, logger *log.Logger) (*Filter, error) { // Set defaults if cfg.Type == "" { @@ -60,7 +60,7 @@ func NewFilter(cfg config.FilterConfig, logger *log.Logger) (*Filter, error) { return f, nil } -// Checks if a log entry should be passed through +// Apply determines if a log entry should be passed through based on the filter's rules. func (f *Filter) Apply(entry core.LogEntry) bool { f.totalProcessed.Add(1) @@ -130,7 +130,44 @@ func (f *Filter) Apply(entry core.LogEntry) bool { return shouldPass } -// Checks if text matches the patterns according to the logic +// GetStats returns the filter's current statistics. +func (f *Filter) GetStats() map[string]any { + return map[string]any{ + "type": f.config.Type, + "logic": f.config.Logic, + "pattern_count": len(f.patterns), + "total_processed": f.totalProcessed.Load(), + "total_matched": f.totalMatched.Load(), + "total_dropped": f.totalDropped.Load(), + } +} + +// UpdatePatterns allows for dynamic, thread-safe updates to the filter's regex patterns. +func (f *Filter) UpdatePatterns(patterns []string) error { + compiled := make([]*regexp.Regexp, 0, len(patterns)) + + // Compile all patterns first + for i, pattern := range patterns { + re, err := regexp.Compile(pattern) + if err != nil { + return fmt.Errorf("invalid regex pattern[%d] '%s': %w", i, pattern, err) + } + compiled = append(compiled, re) + } + + // Update atomically + f.mu.Lock() + f.patterns = compiled + f.config.Patterns = patterns + f.mu.Unlock() + + f.logger.Info("msg", "Filter patterns updated", + "component", "filter", + "pattern_count", len(patterns)) + return nil +} + +// matches checks if the given text matches the filter's patterns according to its logic. func (f *Filter) matches(text string) bool { switch f.config.Logic { case config.FilterLogicOr: @@ -158,41 +195,4 @@ func (f *Filter) matches(text string) bool { "logic", f.config.Logic) return false } -} - -// Returns filter statistics -func (f *Filter) GetStats() map[string]any { - return map[string]any{ - "type": f.config.Type, - "logic": f.config.Logic, - "pattern_count": len(f.patterns), - "total_processed": f.totalProcessed.Load(), - "total_matched": f.totalMatched.Load(), - "total_dropped": f.totalDropped.Load(), - } -} - -// Allows dynamic pattern updates -func (f *Filter) UpdatePatterns(patterns []string) error { - compiled := make([]*regexp.Regexp, 0, len(patterns)) - - // Compile all patterns first - for i, pattern := range patterns { - re, err := regexp.Compile(pattern) - if err != nil { - return fmt.Errorf("invalid regex pattern[%d] '%s': %w", i, pattern, err) - } - compiled = append(compiled, re) - } - - // Update atomically - f.mu.Lock() - f.patterns = compiled - f.config.Patterns = patterns - f.mu.Unlock() - - f.logger.Info("msg", "Filter patterns updated", - "component", "filter", - "pattern_count", len(patterns)) - return nil } \ No newline at end of file diff --git a/src/internal/format/format.go b/src/internal/format/format.go index d0f8b18..8617155 100644 --- a/src/internal/format/format.go +++ b/src/internal/format/format.go @@ -10,16 +10,16 @@ import ( "github.com/lixenwraith/log" ) -// Defines the interface for transforming a LogEntry into a byte slice. +// Formatter defines the interface for transforming a LogEntry into a byte slice. type Formatter interface { // Format takes a LogEntry and returns the formatted log as a byte slice. Format(entry core.LogEntry) ([]byte, error) - // Name returns the formatter type name + // Name returns the formatter's type name (e.g., "json", "raw"). Name() string } -// Creates a new Formatter based on the provided configuration. +// NewFormatter is a factory function that creates a Formatter based on the provided configuration. func NewFormatter(cfg *config.FormatConfig, logger *log.Logger) (Formatter, error) { switch cfg.Type { case "json": diff --git a/src/internal/format/json.go b/src/internal/format/json.go index 282310c..4151444 100644 --- a/src/internal/format/json.go +++ b/src/internal/format/json.go @@ -12,13 +12,13 @@ import ( "github.com/lixenwraith/log" ) -// Produces structured JSON logs +// JSONFormatter produces structured JSON logs from LogEntry objects. type JSONFormatter struct { config *config.JSONFormatterOptions logger *log.Logger } -// Creates a new JSON formatter +// NewJSONFormatter creates a new JSON formatter from configuration options. func NewJSONFormatter(opts *config.JSONFormatterOptions, logger *log.Logger) (*JSONFormatter, error) { f := &JSONFormatter{ config: opts, @@ -28,7 +28,7 @@ func NewJSONFormatter(opts *config.JSONFormatterOptions, logger *log.Logger) (*J return f, nil } -// Formats the log entry as JSON +// Format transforms a single LogEntry into a JSON byte slice. func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) { // Start with a clean map output := make(map[string]any) @@ -92,13 +92,12 @@ func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) { return append(result, '\n'), nil } -// Returns the formatter name +// Name returns the formatter's type name. func (f *JSONFormatter) Name() string { return "json" } -// Formats multiple entries as a JSON array -// This is a special method for sinks that need to batch entries +// FormatBatch transforms a slice of LogEntry objects into a single JSON array byte slice. func (f *JSONFormatter) FormatBatch(entries []core.LogEntry) ([]byte, error) { // For batching, we need to create an array of formatted objects batch := make([]json.RawMessage, 0, len(entries)) diff --git a/src/internal/format/raw.go b/src/internal/format/raw.go index 6c3f745..f9ae43f 100644 --- a/src/internal/format/raw.go +++ b/src/internal/format/raw.go @@ -8,13 +8,13 @@ import ( "github.com/lixenwraith/log" ) -// Outputs the log message as-is with a newline +// RawFormatter outputs the raw log message, optionally with a newline. type RawFormatter struct { config *config.RawFormatterOptions logger *log.Logger } -// Creates a new raw formatter +// NewRawFormatter creates a new raw pass-through formatter. func NewRawFormatter(cfg *config.RawFormatterOptions, logger *log.Logger) (*RawFormatter, error) { return &RawFormatter{ config: cfg, @@ -22,7 +22,7 @@ func NewRawFormatter(cfg *config.RawFormatterOptions, logger *log.Logger) (*RawF }, nil } -// Returns the message with a newline appended +// Format returns the raw message from the LogEntry as a byte slice. func (f *RawFormatter) Format(entry core.LogEntry) ([]byte, error) { // TODO: Standardize not to add "\n" when processing raw, check lixenwraith/log for consistency if f.config.AddNewLine { @@ -32,7 +32,7 @@ func (f *RawFormatter) Format(entry core.LogEntry) ([]byte, error) { } } -// Returns the formatter name +// Name returns the formatter's type name. func (f *RawFormatter) Name() string { return "raw" } \ No newline at end of file diff --git a/src/internal/format/txt.go b/src/internal/format/txt.go index aef9d96..d8d95af 100644 --- a/src/internal/format/txt.go +++ b/src/internal/format/txt.go @@ -14,14 +14,14 @@ import ( "github.com/lixenwraith/log" ) -// Produces human-readable text logs using templates +// TxtFormatter produces human-readable, template-based text logs. type TxtFormatter struct { config *config.TxtFormatterOptions template *template.Template logger *log.Logger } -// Creates a new text formatter +// NewTxtFormatter creates a new text formatter from a template configuration. func NewTxtFormatter(opts *config.TxtFormatterOptions, logger *log.Logger) (*TxtFormatter, error) { f := &TxtFormatter{ config: opts, @@ -47,7 +47,7 @@ func NewTxtFormatter(opts *config.TxtFormatterOptions, logger *log.Logger) (*Txt return f, nil } -// Formats the log entry using the template +// Format transforms a LogEntry into a text byte slice using the configured template. func (f *TxtFormatter) Format(entry core.LogEntry) ([]byte, error) { // Prepare data for template data := map[string]any{ @@ -91,7 +91,7 @@ func (f *TxtFormatter) Format(entry core.LogEntry) ([]byte, error) { return result, nil } -// Returns the formatter name +// Name returns the formatter's type name. func (f *TxtFormatter) Name() string { return "txt" } \ No newline at end of file diff --git a/src/internal/limit/net.go b/src/internal/limit/net.go index 022ab4a..b157a1a 100644 --- a/src/internal/limit/net.go +++ b/src/internal/limit/net.go @@ -14,7 +14,7 @@ import ( "github.com/lixenwraith/log" ) -// DenialReason indicates why a request was denied +// DenialReason indicates why a network request was denied. type DenialReason string // ** THIS PROGRAM IS IPV4 ONLY !!** @@ -32,7 +32,7 @@ const ( ReasonInvalidIP DenialReason = "Invalid IP address" ) -// NetLimiter manages net limiting for a transport +// NetLimiter manages network-level limiting including ACLs, rate limits, and connection counts. type NetLimiter struct { config *config.NetLimitConfig logger *log.Logger @@ -75,20 +75,21 @@ type NetLimiter struct { cleanupDone chan struct{} } +// ipLimiter holds the rate limiting and activity state for a single IP address. type ipLimiter struct { bucket *TokenBucket lastSeen time.Time connections atomic.Int64 } -// Connection tracking with activity timestamp +// connTracker tracks active connections and their last activity. type connTracker struct { connections atomic.Int64 lastSeen time.Time mu sync.Mutex } -// Creates a new net limiter +// NewNetLimiter creates a new network limiter from configuration. func NewNetLimiter(cfg *config.NetLimitConfig, logger *log.Logger) *NetLimiter { if cfg == nil { return nil @@ -145,120 +146,7 @@ func NewNetLimiter(cfg *config.NetLimitConfig, logger *log.Logger) *NetLimiter { return l } -// parseIPLists parses and validates IP whitelist/blacklist -func (l *NetLimiter) parseIPLists() { - // Parse whitelist - for _, entry := range l.config.IPWhitelist { - if ipNet := l.parseIPEntry(entry, "whitelist"); ipNet != nil { - l.ipWhitelist = append(l.ipWhitelist, ipNet) - } - } - - // Parse blacklist - for _, entry := range l.config.IPBlacklist { - if ipNet := l.parseIPEntry(entry, "blacklist"); ipNet != nil { - l.ipBlacklist = append(l.ipBlacklist, ipNet) - } - } -} - -// parseIPEntry parses a single IP or CIDR entry -func (l *NetLimiter) parseIPEntry(entry, listType string) *net.IPNet { - // Handle single IP - if !strings.Contains(entry, "/") { - ip := net.ParseIP(entry) - if ip == nil { - l.logger.Warn("msg", "Invalid IP entry", - "component", "netlimit", - "list", listType, - "entry", entry) - return nil - } - - // Reject IPv6 - if ip.To4() == nil { - l.logger.Warn("msg", "IPv6 address rejected", - "component", "netlimit", - "list", listType, - "entry", entry, - "reason", IPv4Only) - return nil - } - - return &net.IPNet{IP: ip.To4(), Mask: net.CIDRMask(32, 32)} - } - - // Parse CIDR - ipAddr, ipNet, err := net.ParseCIDR(entry) - if err != nil { - l.logger.Warn("msg", "Invalid CIDR entry", - "component", "netlimit", - "list", listType, - "entry", entry, - "error", err) - return nil - } - - // Reject IPv6 CIDR - if ipAddr.To4() == nil { - l.logger.Warn("msg", "IPv6 CIDR rejected", - "component", "netlimit", - "list", listType, - "entry", entry, - "reason", IPv4Only) - return nil - } - - // Ensure mask is IPv4 - _, bits := ipNet.Mask.Size() - if bits != 32 { - l.logger.Warn("msg", "Non-IPv4 CIDR mask rejected", - "component", "netlimit", - "list", listType, - "entry", entry, - "mask_bits", bits, - "reason", IPv4Only) - return nil - } - - return &net.IPNet{IP: ipAddr.To4(), Mask: ipNet.Mask} -} - -// checkIPAccess checks if an IP is allowed by ACLs -func (l *NetLimiter) checkIPAccess(ip net.IP) DenialReason { - // 1. Check blacklist first (deny takes precedence) - for _, ipNet := range l.ipBlacklist { - if ipNet.Contains(ip) { - l.blockedByBlacklist.Add(1) - l.logger.Debug("msg", "IP denied by blacklist", - "component", "netlimit", - "ip", ip.String(), - "rule", ipNet.String()) - return ReasonBlacklisted - } - } - - // 2. If whitelist is configured, IP must be in it - if len(l.ipWhitelist) > 0 { - for _, ipNet := range l.ipWhitelist { - if ipNet.Contains(ip) { - l.logger.Debug("msg", "IP allowed by whitelist", - "component", "netlimit", - "ip", ip.String(), - "rule", ipNet.String()) - return ReasonAllowed - } - } - l.blockedByWhitelist.Add(1) - l.logger.Debug("msg", "IP not in whitelist", - "component", "netlimit", - "ip", ip.String()) - return ReasonNotWhitelisted - } - - return ReasonAllowed -} - +// Shutdown gracefully stops the net limiter's background cleanup processes. func (l *NetLimiter) Shutdown() { if l == nil { return @@ -278,7 +166,7 @@ func (l *NetLimiter) Shutdown() { } } -// Checks if an HTTP request should be allowed: IP access control + connection limits (IP only) + calls +// CheckHTTP checks if an incoming HTTP request is allowed based on all configured limits. func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int64, message string) { if l == nil { return true, 0, "" @@ -361,20 +249,7 @@ func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int6 return true, 0, "" } -// Update connection activity -func (l *NetLimiter) updateConnectionActivity(ip string) { - l.connMu.RLock() - tracker, exists := l.ipConnections[ip] - l.connMu.RUnlock() - - if exists { - tracker.mu.Lock() - tracker.lastSeen = time.Now() - tracker.mu.Unlock() - } -} - -// Checks if a TCP connection should be allowed: IP access control + calls checkIPLimit() +// CheckTCP checks if an incoming TCP connection is allowed based on ACLs and rate limits. func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool { if l == nil { return true @@ -422,11 +297,7 @@ func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool { return true } -func isIPv4(ip net.IP) bool { - return ip.To4() != nil -} - -// Tracks a new connection for an IP +// AddConnection tracks a new connection from a specific remote address (for HTTP). func (l *NetLimiter) AddConnection(remoteAddr string) { if l == nil { return @@ -477,7 +348,7 @@ func (l *NetLimiter) AddConnection(remoteAddr string) { "connections", newCount) } -// Removes a connection for an IP +// RemoveConnection removes a tracked connection (for HTTP). func (l *NetLimiter) RemoveConnection(remoteAddr string) { if l == nil { return @@ -527,7 +398,113 @@ func (l *NetLimiter) RemoveConnection(remoteAddr string) { } } -// Returns net limiter statistics +// TrackConnection checks connection limits and tracks a new connection (for TCP). +func (l *NetLimiter) TrackConnection(ip string, user string, token string) bool { + if l == nil { + return true + } + + l.connMu.Lock() + defer l.connMu.Unlock() + + // Check total connections limit (0 = disabled) + if l.config.MaxConnectionsTotal > 0 { + currentTotal := l.totalConnections.Load() + if currentTotal >= l.config.MaxConnectionsTotal { + l.blockedByConnLimit.Add(1) + l.logger.Debug("msg", "TCP connection blocked by total limit", + "component", "netlimit", + "current_total", currentTotal, + "max_connections_total", l.config.MaxConnectionsTotal) + return false + } + } + + // Check per-IP connection limit (0 = disabled) + if l.config.MaxConnectionsPerIP > 0 && ip != "" { + tracker, exists := l.ipConnections[ip] + if !exists { + tracker = &connTracker{lastSeen: time.Now()} + l.ipConnections[ip] = tracker + } + if tracker.connections.Load() >= l.config.MaxConnectionsPerIP { + l.blockedByConnLimit.Add(1) + l.logger.Debug("msg", "TCP connection blocked by IP limit", + "component", "netlimit", + "ip", ip, + "current", tracker.connections.Load(), + "max", l.config.MaxConnectionsPerIP) + return false + } + } + + // All checks passed, increment counters + l.totalConnections.Add(1) + + if ip != "" && l.config.MaxConnectionsPerIP > 0 { + if tracker, exists := l.ipConnections[ip]; exists { + tracker.connections.Add(1) + tracker.mu.Lock() + tracker.lastSeen = time.Now() + tracker.mu.Unlock() + } + } + + return true +} + +// ReleaseConnection decrements connection counters when a connection is closed (for TCP). +func (l *NetLimiter) ReleaseConnection(ip string, user string, token string) { + if l == nil { + return + } + + l.connMu.Lock() + defer l.connMu.Unlock() + + // Decrement total + if l.totalConnections.Load() > 0 { + l.totalConnections.Add(-1) + } + + // Decrement IP counter + if ip != "" { + if tracker, exists := l.ipConnections[ip]; exists { + if tracker.connections.Load() > 0 { + tracker.connections.Add(-1) + } + tracker.mu.Lock() + tracker.lastSeen = time.Now() + tracker.mu.Unlock() + } + } + + // Decrement user counter + if user != "" { + if tracker, exists := l.userConnections[user]; exists { + if tracker.connections.Load() > 0 { + tracker.connections.Add(-1) + } + tracker.mu.Lock() + tracker.lastSeen = time.Now() + tracker.mu.Unlock() + } + } + + // Decrement token counter + if token != "" { + if tracker, exists := l.tokenConnections[token]; exists { + if tracker.connections.Load() > 0 { + tracker.connections.Add(-1) + } + tracker.mu.Lock() + tracker.lastSeen = time.Now() + tracker.mu.Unlock() + } + } +} + +// GetStats returns a map of the net limiter's current statistics. func (l *NetLimiter) GetStats() map[string]any { if l == nil { return map[string]any{"enabled": false} @@ -613,53 +590,26 @@ func (l *NetLimiter) GetStats() map[string]any { } } -// Performs IP net limit check (req/sec) -func (l *NetLimiter) checkIPLimit(ip string) bool { - // Validate IP format - parsedIP := net.ParseIP(ip) - if parsedIP == nil || !isIPv4(parsedIP) { - l.logger.Warn("msg", "Invalid or non-IPv4 address in rate limiter", - "component", "netlimit", - "ip", ip) - return false - } +// cleanupLoop runs a periodic cleanup of stale limiter and tracker entries. +func (l *NetLimiter) cleanupLoop() { + defer close(l.cleanupDone) - // Maybe run cleanup - l.maybeCleanup() + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() - // IP limit - l.ipMu.Lock() - lim, exists := l.ipLimiters[ip] - if !exists { - // Create new limiter for this IP - lim = &ipLimiter{ - bucket: NewTokenBucket( - float64(l.config.BurstSize), - l.config.RequestsPerSecond, - ), - lastSeen: time.Now(), + for { + select { + case <-l.ctx.Done(): + // Exit when context is cancelled + l.logger.Debug("msg", "Cleanup loop stopping", "component", "netlimit") + return + case <-ticker.C: + l.cleanup() } - l.ipLimiters[ip] = lim - l.uniqueIPs.Add(1) - - l.logger.Debug("msg", "Created new IP limiter", - "ip", ip, - "total_ips", l.uniqueIPs.Load()) - } else { - lim.lastSeen = time.Now() } - l.ipMu.Unlock() - - // Rate limit check - allowed := lim.bucket.Allow() - if !allowed { - l.blockedByRateLimit.Add(1) - } - - return allowed } -// Runs cleanup if enough time has passed +// maybeCleanup triggers an asynchronous cleanup if enough time has passed since the last one. func (l *NetLimiter) maybeCleanup() { l.cleanupMu.Lock() @@ -685,7 +635,7 @@ func (l *NetLimiter) maybeCleanup() { }() } -// Removes stale IP limiters +// cleanup removes stale IP limiters and connection trackers from memory. func (l *NetLimiter) cleanup() { staleTimeout := 5 * time.Minute now := time.Now() @@ -767,127 +717,180 @@ func (l *NetLimiter) cleanup() { } } -// Runs periodic cleanup -func (l *NetLimiter) cleanupLoop() { - defer close(l.cleanupDone) +// checkIPAccess verifies if an IP address is permitted by the configured ACLs. +func (l *NetLimiter) checkIPAccess(ip net.IP) DenialReason { + // 1. Check blacklist first (deny takes precedence) + for _, ipNet := range l.ipBlacklist { + if ipNet.Contains(ip) { + l.blockedByBlacklist.Add(1) + l.logger.Debug("msg", "IP denied by blacklist", + "component", "netlimit", + "ip", ip.String(), + "rule", ipNet.String()) + return ReasonBlacklisted + } + } - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() + // 2. If whitelist is configured, IP must be in it + if len(l.ipWhitelist) > 0 { + for _, ipNet := range l.ipWhitelist { + if ipNet.Contains(ip) { + l.logger.Debug("msg", "IP allowed by whitelist", + "component", "netlimit", + "ip", ip.String(), + "rule", ipNet.String()) + return ReasonAllowed + } + } + l.blockedByWhitelist.Add(1) + l.logger.Debug("msg", "IP not in whitelist", + "component", "netlimit", + "ip", ip.String()) + return ReasonNotWhitelisted + } - for { - select { - case <-l.ctx.Done(): - // Exit when context is cancelled - l.logger.Debug("msg", "Cleanup loop stopping", "component", "netlimit") - return - case <-ticker.C: - l.cleanup() + return ReasonAllowed +} + +// checkIPLimit enforces the requests-per-second limit for a given IP address. +func (l *NetLimiter) checkIPLimit(ip string) bool { + // Validate IP format + parsedIP := net.ParseIP(ip) + if parsedIP == nil || !isIPv4(parsedIP) { + l.logger.Warn("msg", "Invalid or non-IPv4 address in rate limiter", + "component", "netlimit", + "ip", ip) + return false + } + + // Maybe run cleanup + l.maybeCleanup() + + // IP limit + l.ipMu.Lock() + lim, exists := l.ipLimiters[ip] + if !exists { + // Create new limiter for this IP + lim = &ipLimiter{ + bucket: NewTokenBucket( + float64(l.config.BurstSize), + l.config.RequestsPerSecond, + ), + lastSeen: time.Now(), + } + l.ipLimiters[ip] = lim + l.uniqueIPs.Add(1) + + l.logger.Debug("msg", "Created new IP limiter", + "ip", ip, + "total_ips", l.uniqueIPs.Load()) + } else { + lim.lastSeen = time.Now() + } + l.ipMu.Unlock() + + // Rate limit check + allowed := lim.bucket.Allow() + if !allowed { + l.blockedByRateLimit.Add(1) + } + + return allowed +} + +// parseIPLists converts the string-based IP rules from the config into parsed net.IPNet objects. +func (l *NetLimiter) parseIPLists() { + // Parse whitelist + for _, entry := range l.config.IPWhitelist { + if ipNet := l.parseIPEntry(entry, "whitelist"); ipNet != nil { + l.ipWhitelist = append(l.ipWhitelist, ipNet) + } + } + + // Parse blacklist + for _, entry := range l.config.IPBlacklist { + if ipNet := l.parseIPEntry(entry, "blacklist"); ipNet != nil { + l.ipBlacklist = append(l.ipBlacklist, ipNet) } } } -// Tracks a new connection with optional user/token info: Connection limits (IP/user/token/total) for TCP only -func (l *NetLimiter) TrackConnection(ip string, user string, token string) bool { - if l == nil { - return true - } - - l.connMu.Lock() - defer l.connMu.Unlock() - - // Check total connections limit (0 = disabled) - if l.config.MaxConnectionsTotal > 0 { - currentTotal := l.totalConnections.Load() - if currentTotal >= l.config.MaxConnectionsTotal { - l.blockedByConnLimit.Add(1) - l.logger.Debug("msg", "TCP connection blocked by total limit", +// parseIPEntry parses a single IP address or CIDR notation string into a net.IPNet object. +func (l *NetLimiter) parseIPEntry(entry, listType string) *net.IPNet { + // Handle single IP + if !strings.Contains(entry, "/") { + ip := net.ParseIP(entry) + if ip == nil { + l.logger.Warn("msg", "Invalid IP entry", "component", "netlimit", - "current_total", currentTotal, - "max_connections_total", l.config.MaxConnectionsTotal) - return false + "list", listType, + "entry", entry) + return nil } - } - // Check per-IP connection limit (0 = disabled) - if l.config.MaxConnectionsPerIP > 0 && ip != "" { - tracker, exists := l.ipConnections[ip] - if !exists { - tracker = &connTracker{lastSeen: time.Now()} - l.ipConnections[ip] = tracker - } - if tracker.connections.Load() >= l.config.MaxConnectionsPerIP { - l.blockedByConnLimit.Add(1) - l.logger.Debug("msg", "TCP connection blocked by IP limit", + // Reject IPv6 + if ip.To4() == nil { + l.logger.Warn("msg", "IPv6 address rejected", "component", "netlimit", - "ip", ip, - "current", tracker.connections.Load(), - "max", l.config.MaxConnectionsPerIP) - return false + "list", listType, + "entry", entry, + "reason", IPv4Only) + return nil } + + return &net.IPNet{IP: ip.To4(), Mask: net.CIDRMask(32, 32)} } - // All checks passed, increment counters - l.totalConnections.Add(1) - - if ip != "" && l.config.MaxConnectionsPerIP > 0 { - if tracker, exists := l.ipConnections[ip]; exists { - tracker.connections.Add(1) - tracker.mu.Lock() - tracker.lastSeen = time.Now() - tracker.mu.Unlock() - } + // Parse CIDR + ipAddr, ipNet, err := net.ParseCIDR(entry) + if err != nil { + l.logger.Warn("msg", "Invalid CIDR entry", + "component", "netlimit", + "list", listType, + "entry", entry, + "error", err) + return nil } - return true + // Reject IPv6 CIDR + if ipAddr.To4() == nil { + l.logger.Warn("msg", "IPv6 CIDR rejected", + "component", "netlimit", + "list", listType, + "entry", entry, + "reason", IPv4Only) + return nil + } + + // Ensure mask is IPv4 + _, bits := ipNet.Mask.Size() + if bits != 32 { + l.logger.Warn("msg", "Non-IPv4 CIDR mask rejected", + "component", "netlimit", + "list", listType, + "entry", entry, + "mask_bits", bits, + "reason", IPv4Only) + return nil + } + + return &net.IPNet{IP: ipAddr.To4(), Mask: ipNet.Mask} } -// Releases a tracked connection -func (l *NetLimiter) ReleaseConnection(ip string, user string, token string) { - if l == nil { - return - } +// updateConnectionActivity updates the last seen timestamp for a connection tracker. +func (l *NetLimiter) updateConnectionActivity(ip string) { + l.connMu.RLock() + tracker, exists := l.ipConnections[ip] + l.connMu.RUnlock() - l.connMu.Lock() - defer l.connMu.Unlock() - - // Decrement total - if l.totalConnections.Load() > 0 { - l.totalConnections.Add(-1) + if exists { + tracker.mu.Lock() + tracker.lastSeen = time.Now() + tracker.mu.Unlock() } +} - // Decrement IP counter - if ip != "" { - if tracker, exists := l.ipConnections[ip]; exists { - if tracker.connections.Load() > 0 { - tracker.connections.Add(-1) - } - tracker.mu.Lock() - tracker.lastSeen = time.Now() - tracker.mu.Unlock() - } - } - - // Decrement user counter - if user != "" { - if tracker, exists := l.userConnections[user]; exists { - if tracker.connections.Load() > 0 { - tracker.connections.Add(-1) - } - tracker.mu.Lock() - tracker.lastSeen = time.Now() - tracker.mu.Unlock() - } - } - - // Decrement token counter - if token != "" { - if tracker, exists := l.tokenConnections[token]; exists { - if tracker.connections.Load() > 0 { - tracker.connections.Add(-1) - } - tracker.mu.Lock() - tracker.lastSeen = time.Now() - tracker.mu.Unlock() - } - } +// isIPv4 is a helper function to check if a net.IP is an IPv4 address. +func isIPv4(ip net.IP) bool { + return ip.To4() != nil } \ No newline at end of file diff --git a/src/internal/limit/rate.go b/src/internal/limit/rate.go index 9cca25c..da54244 100644 --- a/src/internal/limit/rate.go +++ b/src/internal/limit/rate.go @@ -11,7 +11,7 @@ import ( "github.com/lixenwraith/log" ) -// Enforces rate limits on log entries flowing through a pipeline. +// RateLimiter enforces rate limits on log entries flowing through a pipeline. type RateLimiter struct { bucket *TokenBucket policy config.RateLimitPolicy @@ -23,7 +23,7 @@ type RateLimiter struct { droppedCount atomic.Uint64 } -// Creates a new rate limiter. If cfg.Rate is 0, it returns nil. +// NewRateLimiter creates a new pipeline-level rate limiter from configuration. func NewRateLimiter(cfg config.RateLimitConfig, logger *log.Logger) (*RateLimiter, error) { if cfg.Rate <= 0 { return nil, nil // No rate limit @@ -56,8 +56,7 @@ func NewRateLimiter(cfg config.RateLimitConfig, logger *log.Logger) (*RateLimite return l, nil } -// Checks if a log entry is allowed to pass based on the rate limit. -// It returns true if the entry should pass, false if it should be dropped. +// Allow checks if a log entry is permitted to pass based on the rate limit. func (l *RateLimiter) Allow(entry core.LogEntry) bool { if l == nil || l.policy == config.PolicyPass { return true @@ -83,7 +82,7 @@ func (l *RateLimiter) Allow(entry core.LogEntry) bool { return true } -// GetStats returns the statistics for the limiter. +// GetStats returns statistics for the rate limiter. func (l *RateLimiter) GetStats() map[string]any { if l == nil { return map[string]any{ @@ -106,7 +105,7 @@ func (l *RateLimiter) GetStats() map[string]any { return stats } -// policyString returns the string representation of the policy. +// policyString returns the string representation of a rate limit policy. func policyString(p config.RateLimitPolicy) string { switch p { case config.PolicyDrop: diff --git a/src/internal/limit/token_bucket.go b/src/internal/limit/token_bucket.go index 66ea854..ad60fa7 100644 --- a/src/internal/limit/token_bucket.go +++ b/src/internal/limit/token_bucket.go @@ -6,8 +6,7 @@ import ( "time" ) -// TokenBucket implements a token bucket rate limiter -// Safe for concurrent use. +// TokenBucket implements a thread-safe token bucket rate limiter. type TokenBucket struct { capacity float64 tokens float64 @@ -16,7 +15,7 @@ type TokenBucket struct { mu sync.Mutex } -// Creates a new token bucket with given capacity and refill rate +// NewTokenBucket creates a new token bucket with a given capacity and refill rate. func NewTokenBucket(capacity float64, refillRate float64) *TokenBucket { return &TokenBucket{ capacity: capacity, @@ -26,12 +25,12 @@ func NewTokenBucket(capacity float64, refillRate float64) *TokenBucket { } } -// Attempts to consume one token, returns true if allowed +// Allow attempts to consume one token, returning true if successful. func (tb *TokenBucket) Allow() bool { return tb.AllowN(1) } -// Attempts to consume n tokens, returns true if allowed +// AllowN attempts to consume n tokens, returning true if successful. func (tb *TokenBucket) AllowN(n float64) bool { tb.mu.Lock() defer tb.mu.Unlock() @@ -45,7 +44,7 @@ func (tb *TokenBucket) AllowN(n float64) bool { return false } -// Returns the current number of available tokens +// Tokens returns the current number of available tokens in the bucket. func (tb *TokenBucket) Tokens() float64 { tb.mu.Lock() defer tb.mu.Unlock() @@ -54,8 +53,7 @@ func (tb *TokenBucket) Tokens() float64 { return tb.tokens } -// Adds tokens based on time elapsed since last refill -// MUST be called with mutex held +// refill adds new tokens to the bucket based on the elapsed time. func (tb *TokenBucket) refill() { now := time.Now() elapsed := now.Sub(tb.lastRefill).Seconds() diff --git a/src/internal/service/pipeline.go b/src/internal/service/pipeline.go index af6de71..a345df0 100644 --- a/src/internal/service/pipeline.go +++ b/src/internal/service/pipeline.go @@ -18,7 +18,7 @@ import ( "github.com/lixenwraith/log" ) -// Manages the flow of data from sources through filters to sinks +// Pipeline manages the flow of data from sources, through filters, to sinks. type Pipeline struct { Config *config.PipelineConfig Sources []source.Source @@ -33,7 +33,7 @@ type Pipeline struct { wg sync.WaitGroup } -// Contains statistics for a pipeline +// PipelineStats contains runtime statistics for a pipeline. type PipelineStats struct { StartTime time.Time TotalEntriesProcessed atomic.Uint64 @@ -44,7 +44,7 @@ type PipelineStats struct { FilterStats map[string]any } -// Creates and starts a new pipeline +// NewPipeline creates, configures, and starts a new pipeline within the service. func (s *Service) NewPipeline(cfg *config.PipelineConfig) error { s.mu.Lock() defer s.mu.Unlock() @@ -149,7 +149,7 @@ func (s *Service) NewPipeline(cfg *config.PipelineConfig) error { return nil } -// Gracefully stops the pipeline +// Shutdown gracefully stops the pipeline and all its components. func (p *Pipeline) Shutdown() { p.logger.Info("msg", "Shutting down pipeline", "component", "pipeline", @@ -187,7 +187,7 @@ func (p *Pipeline) Shutdown() { "pipeline", p.Config.Name) } -// Returns pipeline statistics +// GetStats returns a map of the pipeline's current statistics. func (p *Pipeline) GetStats() map[string]any { // Recovery to handle concurrent access during shutdown // When service is shutting down, sources/sinks might be nil or partially stopped @@ -263,7 +263,8 @@ func (p *Pipeline) GetStats() map[string]any { } } -// Runs periodic stats updates +// TODO: incomplete implementation +// startStatsUpdater runs a periodic stats updater. func (p *Pipeline) startStatsUpdater(ctx context.Context) { go func() { ticker := time.NewTicker(1 * time.Second) diff --git a/src/internal/service/service.go b/src/internal/service/service.go index a2ba640..18b63d6 100644 --- a/src/internal/service/service.go +++ b/src/internal/service/service.go @@ -15,7 +15,7 @@ import ( "github.com/lixenwraith/log" ) -// Service manages multiple pipelines +// Service manages a collection of log processing pipelines. type Service struct { pipelines map[string]*Pipeline mu sync.RWMutex @@ -25,7 +25,7 @@ type Service struct { logger *log.Logger } -// Creates a new service +// NewService creates a new, empty service. func NewService(ctx context.Context, logger *log.Logger) *Service { serviceCtx, cancel := context.WithCancel(ctx) return &Service{ @@ -36,7 +36,97 @@ func NewService(ctx context.Context, logger *log.Logger) *Service { } } -// Connects sources to sinks through filters +// GetPipeline returns a pipeline by its name. +func (s *Service) GetPipeline(name string) (*Pipeline, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + pipeline, exists := s.pipelines[name] + if !exists { + return nil, fmt.Errorf("pipeline '%s' not found", name) + } + return pipeline, nil +} + +// ListPipelines returns the names of all currently managed pipelines. +func (s *Service) ListPipelines() []string { + s.mu.RLock() + defer s.mu.RUnlock() + + names := make([]string, 0, len(s.pipelines)) + for name := range s.pipelines { + names = append(names, name) + } + return names +} + +// RemovePipeline stops and removes a pipeline from the service. +func (s *Service) RemovePipeline(name string) error { + s.mu.Lock() + defer s.mu.Unlock() + + pipeline, exists := s.pipelines[name] + if !exists { + err := fmt.Errorf("pipeline '%s' not found", name) + s.logger.Warn("msg", "Cannot remove non-existent pipeline", + "component", "service", + "pipeline", name, + "error", err) + return err + } + + s.logger.Info("msg", "Removing pipeline", "pipeline", name) + pipeline.Shutdown() + delete(s.pipelines, name) + return nil +} + +// Shutdown gracefully stops all pipelines managed by the service. +func (s *Service) Shutdown() { + s.logger.Info("msg", "Service shutdown initiated") + + s.mu.Lock() + pipelines := make([]*Pipeline, 0, len(s.pipelines)) + for _, pipeline := range s.pipelines { + pipelines = append(pipelines, pipeline) + } + s.mu.Unlock() + + // Stop all pipelines concurrently + var wg sync.WaitGroup + for _, pipeline := range pipelines { + wg.Add(1) + go func(p *Pipeline) { + defer wg.Done() + p.Shutdown() + }(pipeline) + } + wg.Wait() + + s.cancel() + s.wg.Wait() + + s.logger.Info("msg", "Service shutdown complete") +} + +// GetGlobalStats returns statistics for all pipelines. +func (s *Service) GetGlobalStats() map[string]any { + s.mu.RLock() + defer s.mu.RUnlock() + + stats := map[string]any{ + "pipelines": make(map[string]any), + "total_pipelines": len(s.pipelines), + } + + for name, pipeline := range s.pipelines { + stats["pipelines"].(map[string]any)[name] = pipeline.GetStats() + } + + return stats +} + +// wirePipeline connects a pipeline's sources to its sinks through its filter chain. func (s *Service) wirePipeline(p *Pipeline) { // For each source, subscribe and process entries for _, src := range p.Sources { @@ -113,7 +203,7 @@ func (s *Service) wirePipeline(p *Pipeline) { } } -// Creates a source instance based on configuration +// createSource is a factory function for creating a source instance from configuration. func (s *Service) createSource(cfg *config.SourceConfig) (source.Source, error) { switch cfg.Type { case "directory": @@ -129,7 +219,7 @@ func (s *Service) createSource(cfg *config.SourceConfig) (source.Source, error) } } -// Creates a sink instance based on configuration +// createSink is a factory function for creating a sink instance from configuration. func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter) (sink.Sink, error) { switch cfg.Type { @@ -156,94 +246,4 @@ func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter) default: return nil, fmt.Errorf("unknown sink type: %s", cfg.Type) } -} - -// Returns a pipeline by name -func (s *Service) GetPipeline(name string) (*Pipeline, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - pipeline, exists := s.pipelines[name] - if !exists { - return nil, fmt.Errorf("pipeline '%s' not found", name) - } - return pipeline, nil -} - -// Returns all pipeline names -func (s *Service) ListPipelines() []string { - s.mu.RLock() - defer s.mu.RUnlock() - - names := make([]string, 0, len(s.pipelines)) - for name := range s.pipelines { - names = append(names, name) - } - return names -} - -// Stops and removes a pipeline -func (s *Service) RemovePipeline(name string) error { - s.mu.Lock() - defer s.mu.Unlock() - - pipeline, exists := s.pipelines[name] - if !exists { - err := fmt.Errorf("pipeline '%s' not found", name) - s.logger.Warn("msg", "Cannot remove non-existent pipeline", - "component", "service", - "pipeline", name, - "error", err) - return err - } - - s.logger.Info("msg", "Removing pipeline", "pipeline", name) - pipeline.Shutdown() - delete(s.pipelines, name) - return nil -} - -// Stops all pipelines -func (s *Service) Shutdown() { - s.logger.Info("msg", "Service shutdown initiated") - - s.mu.Lock() - pipelines := make([]*Pipeline, 0, len(s.pipelines)) - for _, pipeline := range s.pipelines { - pipelines = append(pipelines, pipeline) - } - s.mu.Unlock() - - // Stop all pipelines concurrently - var wg sync.WaitGroup - for _, pipeline := range pipelines { - wg.Add(1) - go func(p *Pipeline) { - defer wg.Done() - p.Shutdown() - }(pipeline) - } - wg.Wait() - - s.cancel() - s.wg.Wait() - - s.logger.Info("msg", "Service shutdown complete") -} - -// Returns statistics for all pipelines -func (s *Service) GetGlobalStats() map[string]any { - s.mu.RLock() - defer s.mu.RUnlock() - - stats := map[string]any{ - "pipelines": make(map[string]any), - "total_pipelines": len(s.pipelines), - } - - for name, pipeline := range s.pipelines { - stats["pipelines"].(map[string]any)[name] = pipeline.GetStats() - } - - return stats } \ No newline at end of file diff --git a/src/internal/session/session.go b/src/internal/session/session.go new file mode 100644 index 0000000..9fef46e --- /dev/null +++ b/src/internal/session/session.go @@ -0,0 +1,290 @@ +// FILE: src/internal/session/session.go +package session + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "sync" + "time" +) + +// Session represents a connection session. +type Session struct { + ID string // Unique session identifier + RemoteAddr string // Client address + CreatedAt time.Time // Session creation time + LastActivity time.Time // Last activity timestamp + Metadata map[string]any // Optional metadata (e.g., TLS info) + + // Connection context + Source string // Source type: "tcp_source", "http_source", "tcp_sink", etc. +} + +// Manager handles the lifecycle of sessions. +type Manager struct { + sessions map[string]*Session + mu sync.RWMutex + + // Cleanup configuration + maxIdleTime time.Duration + cleanupTicker *time.Ticker + done chan struct{} + + // Expiry callbacks by source type + expiryCallbacks map[string]func(sessionID, remoteAddr string) + callbacksMu sync.RWMutex +} + +// NewManager creates a new session manager with a specified idle timeout. +func NewManager(maxIdleTime time.Duration) *Manager { + if maxIdleTime == 0 { + maxIdleTime = 30 * time.Minute // Default idle timeout + } + + m := &Manager{ + sessions: make(map[string]*Session), + maxIdleTime: maxIdleTime, + done: make(chan struct{}), + } + + // Start cleanup routine + m.startCleanup() + + return m +} + +// CreateSession creates and stores a new session for a connection. +func (m *Manager) CreateSession(remoteAddr string, source string, metadata map[string]any) *Session { + session := &Session{ + ID: generateSessionID(), + RemoteAddr: remoteAddr, + CreatedAt: time.Now(), + LastActivity: time.Now(), + Source: source, + Metadata: metadata, + } + + if metadata == nil { + session.Metadata = make(map[string]any) + } + + m.StoreSession(session) + return session +} + +// StoreSession adds a session to the manager. +func (m *Manager) StoreSession(session *Session) { + m.mu.Lock() + defer m.mu.Unlock() + m.sessions[session.ID] = session +} + +// GetSession retrieves a session by its unique ID. +func (m *Manager) GetSession(sessionID string) (*Session, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + session, exists := m.sessions[sessionID] + return session, exists +} + +// RemoveSession removes a session from the manager. +func (m *Manager) RemoveSession(sessionID string) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.sessions, sessionID) +} + +// UpdateActivity updates the last activity timestamp for a session. +func (m *Manager) UpdateActivity(sessionID string) { + m.mu.Lock() + defer m.mu.Unlock() + + if session, exists := m.sessions[sessionID]; exists { + session.LastActivity = time.Now() + } +} + +// IsSessionActive checks if a session exists and has not been idle for too long. +func (m *Manager) IsSessionActive(sessionID string) bool { + m.mu.RLock() + defer m.mu.RUnlock() + + if session, exists := m.sessions[sessionID]; exists { + // Session exists and hasn't exceeded idle timeout + return time.Since(session.LastActivity) < m.maxIdleTime + } + return false +} + +// GetActiveSessions returns a snapshot of all currently active sessions. +func (m *Manager) GetActiveSessions() []*Session { + m.mu.RLock() + defer m.mu.RUnlock() + + sessions := make([]*Session, 0, len(m.sessions)) + for _, session := range m.sessions { + sessions = append(sessions, session) + } + return sessions +} + +// GetSessionCount returns the number of active sessions. +func (m *Manager) GetSessionCount() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.sessions) +} + +// GetSessionsBySource returns all sessions matching a specific source type. +func (m *Manager) GetSessionsBySource(source string) []*Session { + m.mu.RLock() + defer m.mu.RUnlock() + + var sessions []*Session + for _, session := range m.sessions { + if session.Source == source { + sessions = append(sessions, session) + } + } + return sessions +} + +// GetActiveSessionsBySource returns all active sessions for a given source. +func (m *Manager) GetActiveSessionsBySource(source string) []*Session { + m.mu.RLock() + defer m.mu.RUnlock() + + var sessions []*Session + now := time.Now() + + for _, session := range m.sessions { + if session.Source == source && now.Sub(session.LastActivity) < m.maxIdleTime { + sessions = append(sessions, session) + } + } + return sessions +} + +// GetStats returns statistics about the session manager. +func (m *Manager) GetStats() map[string]any { + m.mu.RLock() + defer m.mu.RUnlock() + + sourceCounts := make(map[string]int) + var totalSessions int + var oldestSession time.Time + var newestSession time.Time + + for _, session := range m.sessions { + totalSessions++ + sourceCounts[session.Source]++ + + if oldestSession.IsZero() || session.CreatedAt.Before(oldestSession) { + oldestSession = session.CreatedAt + } + if newestSession.IsZero() || session.CreatedAt.After(newestSession) { + newestSession = session.CreatedAt + } + } + + stats := map[string]any{ + "total_sessions": totalSessions, + "sessions_by_type": sourceCounts, + "max_idle_time": m.maxIdleTime.String(), + } + + if !oldestSession.IsZero() { + stats["oldest_session_age"] = time.Since(oldestSession).String() + } + if !newestSession.IsZero() { + stats["newest_session_age"] = time.Since(newestSession).String() + } + + return stats +} + +// Stop gracefully stops the session manager and its cleanup goroutine. +func (m *Manager) Stop() { + close(m.done) + if m.cleanupTicker != nil { + m.cleanupTicker.Stop() + } +} + +// RegisterExpiryCallback registers a callback function to be executed when a session expires. +func (m *Manager) RegisterExpiryCallback(source string, callback func(sessionID, remoteAddr string)) { + m.callbacksMu.Lock() + defer m.callbacksMu.Unlock() + + if m.expiryCallbacks == nil { + m.expiryCallbacks = make(map[string]func(sessionID, remoteAddr string)) + } + m.expiryCallbacks[source] = callback +} + +// UnregisterExpiryCallback removes an expiry callback for a given source type. +func (m *Manager) UnregisterExpiryCallback(source string) { + m.callbacksMu.Lock() + defer m.callbacksMu.Unlock() + + delete(m.expiryCallbacks, source) +} + +// startCleanup initializes the periodic cleanup of idle sessions. +func (m *Manager) startCleanup() { + m.cleanupTicker = time.NewTicker(5 * time.Minute) + + go func() { + for { + select { + case <-m.cleanupTicker.C: + m.cleanupIdleSessions() + case <-m.done: + return + } + } + }() +} + +// cleanupIdleSessions removes sessions that have exceeded the maximum idle time. +func (m *Manager) cleanupIdleSessions() { + m.mu.Lock() + defer m.mu.Unlock() + + now := time.Now() + expiredSessions := make([]*Session, 0) + + for id, session := range m.sessions { + idleTime := now.Sub(session.LastActivity) + + if idleTime > m.maxIdleTime { + expiredSessions = append(expiredSessions, session) + delete(m.sessions, id) + } + } + m.mu.Unlock() + + // Call callbacks outside of lock + if len(expiredSessions) > 0 { + m.callbacksMu.RLock() + defer m.callbacksMu.RUnlock() + + for _, session := range expiredSessions { + if callback, exists := m.expiryCallbacks[session.Source]; exists { + // Call callback to notify owner + go callback(session.ID, session.RemoteAddr) + } + } + } +} + +// generateSessionID creates a unique, random session identifier. +func generateSessionID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + // Fallback to timestamp-based ID + return fmt.Sprintf("session_%d", time.Now().UnixNano()) + } + return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b) +} \ No newline at end of file diff --git a/src/internal/sink/console.go b/src/internal/sink/console.go index 6ac6139..01669a4 100644 --- a/src/internal/sink/console.go +++ b/src/internal/sink/console.go @@ -16,7 +16,7 @@ import ( "github.com/lixenwraith/log" ) -// ConsoleSink writes log entries to the console (stdout/stderr) using an dedicated logger instance +// ConsoleSink writes log entries to the console (stdout/stderr) using an dedicated logger instance. type ConsoleSink struct { config *config.ConsoleSinkOptions input chan core.LogEntry @@ -31,7 +31,7 @@ type ConsoleSink struct { lastProcessed atomic.Value // time.Time } -// Creates a new console sink +// NewConsoleSink creates a new console sink. func NewConsoleSink(opts *config.ConsoleSinkOptions, appLogger *log.Logger, formatter format.Formatter) (*ConsoleSink, error) { if opts == nil { return nil, fmt.Errorf("console sink options cannot be nil") @@ -73,10 +73,12 @@ func NewConsoleSink(opts *config.ConsoleSinkOptions, appLogger *log.Logger, form return s, nil } +// Input returns the channel for sending log entries. func (s *ConsoleSink) Input() chan<- core.LogEntry { return s.input } +// Start begins the processing loop for the sink. func (s *ConsoleSink) Start(ctx context.Context) error { // Start the internal writer's processing goroutine. if err := s.writer.Start(); err != nil { @@ -89,6 +91,7 @@ func (s *ConsoleSink) Start(ctx context.Context) error { return nil } +// Stop gracefully shuts down the sink. func (s *ConsoleSink) Stop() { target := s.writer.GetConfig().ConsoleTarget s.logger.Info("msg", "Stopping console sink", "target", target) @@ -103,6 +106,7 @@ func (s *ConsoleSink) Stop() { s.logger.Info("msg", "Console sink stopped", "target", target) } +// GetStats returns the sink's statistics. func (s *ConsoleSink) GetStats() SinkStats { lastProc, _ := s.lastProcessed.Load().(time.Time) @@ -117,7 +121,7 @@ func (s *ConsoleSink) GetStats() SinkStats { } } -// processLoop reads entries, formats them, and passes them to the internal writer. +// processLoop reads entries, formats them, and writes to the console. func (s *ConsoleSink) processLoop(ctx context.Context) { for { select { diff --git a/src/internal/sink/file.go b/src/internal/sink/file.go index e8d801c..d8a7f1a 100644 --- a/src/internal/sink/file.go +++ b/src/internal/sink/file.go @@ -15,7 +15,7 @@ import ( "github.com/lixenwraith/log" ) -// Writes log entries to files with rotation +// FileSink writes log entries to files with rotation. type FileSink struct { config *config.FileSinkOptions input chan core.LogEntry @@ -30,7 +30,7 @@ type FileSink struct { lastProcessed atomic.Value // time.Time } -// Creates a new file sink +// NewFileSink creates a new file sink. func NewFileSink(opts *config.FileSinkOptions, logger *log.Logger, formatter format.Formatter) (*FileSink, error) { if opts == nil { return nil, fmt.Errorf("file sink options cannot be nil") @@ -63,10 +63,12 @@ func NewFileSink(opts *config.FileSinkOptions, logger *log.Logger, formatter for return fs, nil } +// Input returns the channel for sending log entries. func (fs *FileSink) Input() chan<- core.LogEntry { return fs.input } +// Start begins the processing loop for the sink. func (fs *FileSink) Start(ctx context.Context) error { // Start the internal file writer if err := fs.writer.Start(); err != nil { @@ -78,6 +80,7 @@ func (fs *FileSink) Start(ctx context.Context) error { return nil } +// Stop gracefully shuts down the sink. func (fs *FileSink) Stop() { fs.logger.Info("msg", "Stopping file sink") close(fs.done) @@ -92,6 +95,7 @@ func (fs *FileSink) Stop() { fs.logger.Info("msg", "File sink stopped") } +// GetStats returns the sink's statistics. func (fs *FileSink) GetStats() SinkStats { lastProc, _ := fs.lastProcessed.Load().(time.Time) @@ -104,6 +108,7 @@ func (fs *FileSink) GetStats() SinkStats { } } +// processLoop reads entries, formats them, and writes to a file. func (fs *FileSink) processLoop(ctx context.Context) { for { select { diff --git a/src/internal/sink/http.go b/src/internal/sink/http.go index 8a7b2a5..ca0a9ab 100644 --- a/src/internal/sink/http.go +++ b/src/internal/sink/http.go @@ -5,18 +5,19 @@ import ( "bufio" "bytes" "context" + "crypto/tls" "encoding/json" "fmt" "sync" "sync/atomic" "time" - "logwisp/src/internal/auth" "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/format" "logwisp/src/internal/limit" - "logwisp/src/internal/tls" + "logwisp/src/internal/session" + ltls "logwisp/src/internal/tls" "logwisp/src/internal/version" "github.com/lixenwraith/log" @@ -24,7 +25,7 @@ import ( "github.com/valyala/fasthttp" ) -// Streams log entries via Server-Sent Events +// HTTPSink streams log entries via Server-Sent Events (SSE). type HTTPSink struct { // Configuration reference (NOT a copy) config *config.HTTPSinkOptions @@ -46,10 +47,11 @@ type HTTPSink struct { unregister chan uint64 nextClientID atomic.Uint64 - // Security components - authenticator *auth.Authenticator - tlsManager *tls.Manager - authConfig *config.ServerAuthConfig + // Session and security + sessionManager *session.Manager + clientSessions map[uint64]string // clientID -> sessionID + sessionsMu sync.RWMutex + tlsManager *ltls.ServerManager // Net limiting netLimiter *limit.NetLimiter @@ -57,31 +59,32 @@ type HTTPSink struct { // Statistics totalProcessed atomic.Uint64 lastProcessed atomic.Value // time.Time - authFailures atomic.Uint64 - authSuccesses atomic.Uint64 } -// Creates a new HTTP streaming sink +// NewHTTPSink creates a new HTTP streaming sink. func NewHTTPSink(opts *config.HTTPSinkOptions, logger *log.Logger, formatter format.Formatter) (*HTTPSink, error) { if opts == nil { return nil, fmt.Errorf("HTTP sink options cannot be nil") } h := &HTTPSink{ - config: opts, // Direct reference to config struct - input: make(chan core.LogEntry, opts.BufferSize), - startTime: time.Now(), - done: make(chan struct{}), - logger: logger, - formatter: formatter, - clients: make(map[uint64]chan core.LogEntry), + config: opts, + input: make(chan core.LogEntry, opts.BufferSize), + startTime: time.Now(), + done: make(chan struct{}), + logger: logger, + formatter: formatter, + clients: make(map[uint64]chan core.LogEntry), + unregister: make(chan uint64), + sessionManager: session.NewManager(30 * time.Minute), + clientSessions: make(map[uint64]string), } h.lastProcessed.Store(time.Time{}) // Initialize TLS manager if configured if opts.TLS != nil && opts.TLS.Enabled { - tlsManager, err := tls.NewManager(opts.TLS, logger) + tlsManager, err := ltls.NewServerManager(opts.TLS, logger) if err != nil { return nil, fmt.Errorf("failed to create TLS manager: %w", err) } @@ -97,32 +100,21 @@ func NewHTTPSink(opts *config.HTTPSinkOptions, logger *log.Logger, formatter for h.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger) } - // Initialize authenticator if auth is not "none" - if opts.Auth != nil && opts.Auth.Type != "none" { - // Only "basic" and "token" are valid for HTTP sink - if opts.Auth.Type != "basic" && opts.Auth.Type != "token" { - return nil, fmt.Errorf("invalid auth type '%s' for HTTP sink (valid: none, basic, token)", opts.Auth.Type) - } - - authenticator, err := auth.NewAuthenticator(opts.Auth, logger) - if err != nil { - return nil, fmt.Errorf("failed to create authenticator: %w", err) - } - h.authenticator = authenticator - h.authConfig = opts.Auth - logger.Info("msg", "Authentication enabled", - "component", "http_sink", - "type", opts.Auth.Type) - } - return h, nil } +// Input returns the channel for sending log entries. func (h *HTTPSink) Input() chan<- core.LogEntry { return h.input } +// Start initializes the HTTP server and begins the broker loop. func (h *HTTPSink) Start(ctx context.Context) error { + // Register expiry callback + h.sessionManager.RegisterExpiryCallback("http_sink", func(sessionID, remoteAddr string) { + h.handleSessionExpiry(sessionID, remoteAddr) + }) + // Start central broker goroutine h.wg.Add(1) go h.brokerLoop(ctx) @@ -144,6 +136,16 @@ func (h *HTTPSink) Start(ctx context.Context) error { // Configure TLS if enabled if h.tlsManager != nil { h.server.TLSConfig = h.tlsManager.GetHTTPConfig() + + // Enforce mTLS configuration + if h.config.TLS.ClientAuth { + if h.config.TLS.VerifyClientCert { + h.server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + } else { + h.server.TLSConfig.ClientAuth = tls.RequireAnyClientCert + } + } + h.logger.Info("msg", "TLS enabled for HTTP sink", "component", "http_sink", "port", h.config.Port) @@ -183,7 +185,7 @@ func (h *HTTPSink) Start(ctx context.Context) error { if h.server != nil { shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - h.server.ShutdownWithContext(shutdownCtx) + _ = h.server.ShutdownWithContext(shutdownCtx) } }() @@ -197,7 +199,105 @@ func (h *HTTPSink) Start(ctx context.Context) error { } } -// Broadcasts only to active clients +// Stop gracefully shuts down the HTTP server and all client connections. +func (h *HTTPSink) Stop() { + h.logger.Info("msg", "Stopping HTTP sink") + + // Unregister callback + h.sessionManager.UnregisterExpiryCallback("http_sink") + + // Signal all client handlers to stop + close(h.done) + + // Shutdown HTTP server + if h.server != nil { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = h.server.ShutdownWithContext(ctx) + } + + // Wait for all active client handlers to finish + h.wg.Wait() + + // Close unregister channel after all clients have finished + close(h.unregister) + + // Close all client channels + h.clientsMu.Lock() + for _, ch := range h.clients { + close(ch) + } + h.clients = make(map[uint64]chan core.LogEntry) + h.clientsMu.Unlock() + + // Stop session manager + if h.sessionManager != nil { + h.sessionManager.Stop() + } + + h.logger.Info("msg", "HTTP sink stopped") +} + +// GetStats returns the sink's statistics. +func (h *HTTPSink) GetStats() SinkStats { + lastProc, _ := h.lastProcessed.Load().(time.Time) + + var netLimitStats map[string]any + if h.netLimiter != nil { + netLimitStats = h.netLimiter.GetStats() + } + + var sessionStats map[string]any + if h.sessionManager != nil { + sessionStats = h.sessionManager.GetStats() + } + + var tlsStats map[string]any + if h.tlsManager != nil { + tlsStats = h.tlsManager.GetStats() + } + + return SinkStats{ + Type: "http", + TotalProcessed: h.totalProcessed.Load(), + ActiveConnections: h.activeClients.Load(), + StartTime: h.startTime, + LastProcessed: lastProc, + Details: map[string]any{ + "port": h.config.Port, + "buffer_size": h.config.BufferSize, + "endpoints": map[string]string{ + "stream": h.config.StreamPath, + "status": h.config.StatusPath, + }, + "net_limit": netLimitStats, + "sessions": sessionStats, + "tls": tlsStats, + }, + } +} + +// GetActiveConnections returns the current number of active clients. +func (h *HTTPSink) GetActiveConnections() int64 { + return h.activeClients.Load() +} + +// GetStreamPath returns the configured transport endpoint path. +func (h *HTTPSink) GetStreamPath() string { + return h.config.StreamPath +} + +// GetStatusPath returns the configured status endpoint path. +func (h *HTTPSink) GetStatusPath() string { + return h.config.StatusPath +} + +// GetHost returns the configured host. +func (h *HTTPSink) GetHost() string { + return h.config.Host +} + +// brokerLoop manages client connections and broadcasts log entries. func (h *HTTPSink) brokerLoop(ctx context.Context) { defer h.wg.Done() @@ -233,6 +333,11 @@ func (h *HTTPSink) brokerLoop(ctx context.Context) { } h.clientsMu.Unlock() + // Clean up session tracking + h.sessionsMu.Lock() + delete(h.clientSessions, clientID) + h.sessionsMu.Unlock() + case entry, ok := <-h.input: if !ok { h.logger.Debug("msg", "Input channel closed, broker stopping", @@ -248,23 +353,50 @@ func (h *HTTPSink) brokerLoop(ctx context.Context) { clientCount := len(h.clients) if clientCount > 0 { slowClients := 0 + var staleClients []uint64 + for id, ch := range h.clients { - select { - case ch <- entry: - // Successfully sent - default: - // Client buffer full - slowClients++ - if slowClients == 1 { // Log only once per broadcast - h.logger.Debug("msg", "Dropped entry for slow client(s)", - "component", "http_sink", - "client_id", id, - "slow_clients", slowClients, - "total_clients", clientCount) + h.sessionsMu.RLock() + sessionID, hasSession := h.clientSessions[id] + h.sessionsMu.RUnlock() + + if hasSession { + if !h.sessionManager.IsSessionActive(sessionID) { + staleClients = append(staleClients, id) + continue } + select { + case ch <- entry: + h.sessionManager.UpdateActivity(sessionID) + default: + slowClients++ + if slowClients == 1 { + h.logger.Debug("msg", "Dropped entry for slow client(s)", + "component", "http_sink", + "client_id", id, + "slow_clients", slowClients, + "total_clients", clientCount) + } + } + } else { + delete(h.clients, id) } } + + // Clean up stale clients after broadcast + if len(staleClients) > 0 { + go func() { + for _, clientID := range staleClients { + select { + case h.unregister <- clientID: + case <-h.done: + return + } + } + }() + } } + // If no clients connected, entry is discarded (no buffering) h.clientsMu.RUnlock() @@ -275,91 +407,29 @@ func (h *HTTPSink) brokerLoop(ctx context.Context) { h.clientsMu.RLock() for id, ch := range h.clients { - select { - case ch <- heartbeatEntry: - default: - // Client buffer full, skip heartbeat - h.logger.Debug("msg", "Skipped heartbeat for slow client", - "component", "http_sink", - "client_id", id) + h.sessionsMu.RLock() + sessionID, hasSession := h.clientSessions[id] + h.sessionsMu.RUnlock() + + if hasSession { + select { + case ch <- heartbeatEntry: + // Update session activity on heartbeat + h.sessionManager.UpdateActivity(sessionID) + default: + // Client buffer full, skip heartbeat + h.logger.Debug("msg", "Skipped heartbeat for slow client", + "component", "http_sink", + "client_id", id) + } } } - h.clientsMu.RUnlock() } } } } -func (h *HTTPSink) Stop() { - h.logger.Info("msg", "Stopping HTTP sink") - - // Signal all client handlers to stop - close(h.done) - - // Shutdown HTTP server - if h.server != nil { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - h.server.ShutdownWithContext(ctx) - } - - // Wait for all active client handlers to finish - h.wg.Wait() - - // Close unregister channel after all clients have finished - close(h.unregister) - - // Close all client channels - h.clientsMu.Lock() - for _, ch := range h.clients { - close(ch) - } - h.clients = make(map[uint64]chan core.LogEntry) - h.clientsMu.Unlock() - - h.logger.Info("msg", "HTTP sink stopped") -} - -func (h *HTTPSink) GetStats() SinkStats { - lastProc, _ := h.lastProcessed.Load().(time.Time) - - var netLimitStats map[string]any - if h.netLimiter != nil { - netLimitStats = h.netLimiter.GetStats() - } - - var authStats map[string]any - if h.authenticator != nil { - authStats = h.authenticator.GetStats() - authStats["failures"] = h.authFailures.Load() - authStats["successes"] = h.authSuccesses.Load() - } - - var tlsStats map[string]any - if h.tlsManager != nil { - tlsStats = h.tlsManager.GetStats() - } - - return SinkStats{ - Type: "http", - TotalProcessed: h.totalProcessed.Load(), - ActiveConnections: h.activeClients.Load(), - StartTime: h.startTime, - LastProcessed: lastProc, - Details: map[string]any{ - "port": h.config.Port, - "buffer_size": h.config.BufferSize, - "endpoints": map[string]string{ - "stream": h.config.StreamPath, - "status": h.config.StatusPath, - }, - "net_limit": netLimitStats, - "auth": authStats, - "tls": tlsStats, - }, - } -} - +// requestHandler is the main entry point for all incoming HTTP requests. func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { remoteAddr := ctx.RemoteAddr().String() @@ -380,21 +450,6 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { } } - // Enforce TLS for authentication - if h.authenticator != nil && h.authConfig.Type != "none" { - isTLS := ctx.IsTLS() || h.tlsManager != nil - - if !isTLS { - ctx.SetStatusCode(fasthttp.StatusForbidden) - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(map[string]string{ - "error": "TLS required for authentication", - "hint": "Use HTTPS for authenticated connections", - }) - return - } - } - path := string(ctx.Path()) // Status endpoint doesn't require auth @@ -403,52 +458,14 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { return } - // Authenticate request - var session *auth.Session - if h.authenticator != nil { - authHeader := string(ctx.Request.Header.Peek("Authorization")) - var err error - session, err = h.authenticator.AuthenticateHTTP(authHeader, remoteAddr) - if err != nil { - h.authFailures.Add(1) - h.logger.Warn("msg", "Authentication failed", - "component", "http_sink", - "remote_addr", remoteAddr, - "error", err) - - // Return 401 with WWW-Authenticate header - ctx.SetStatusCode(fasthttp.StatusUnauthorized) - if h.authConfig.Type == "basic" && h.authConfig.Basic != nil { - realm := h.authConfig.Basic.Realm - if realm == "" { - realm = "Restricted" - } - ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm)) - } else if h.authConfig.Type == "token" { - ctx.Response.Header.Set("WWW-Authenticate", "Token") - } - - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(map[string]string{ - "error": "Unauthorized", - }) - return - } - h.authSuccesses.Add(1) - } else { - // Create anonymous session for unauthenticated connections - session = &auth.Session{ - ID: fmt.Sprintf("anon-%d", time.Now().UnixNano()), - Username: "anonymous", - Method: "none", - RemoteAddr: remoteAddr, - CreatedAt: time.Now(), - } - } + // Create anonymous session for all connections + sess := h.sessionManager.CreateSession(remoteAddr, "http_sink", map[string]any{ + "tls": ctx.IsTLS() || h.tlsManager != nil, + }) switch path { case h.config.StreamPath: - h.handleStream(ctx, session) + h.handleStream(ctx, sess) default: ctx.SetStatusCode(fasthttp.StatusNotFound) ctx.SetContentType("application/json") @@ -456,18 +473,11 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) { "error": "Not Found", }) } - // Handle stream endpoint - // if path == h.config.StreamPath { - // h.handleStream(ctx, session) - // return - // } - // - // // Unknown path - // ctx.SetStatusCode(fasthttp.StatusNotFound) - // ctx.SetBody([]byte("Not Found")) + } -func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) { +// handleStream manages a client's Server-Sent Events (SSE) stream. +func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, sess *session.Session) { // Track connection for net limiting remoteAddr := ctx.RemoteAddr().String() if h.netLimiter != nil { @@ -490,14 +500,18 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) h.clients[clientID] = clientChan h.clientsMu.Unlock() + // Register session mapping + h.sessionsMu.Lock() + h.clientSessions[clientID] = sess.ID + h.sessionsMu.Unlock() + // Define the stream writer function streamFunc := func(w *bufio.Writer) { connectCount := h.activeClients.Add(1) h.logger.Debug("msg", "HTTP client connected", "component", "http_sink", "remote_addr", remoteAddr, - "username", session.Username, - "auth_method", session.Method, + "session_id", sess.ID, "client_id", clientID, "active_clients", connectCount) @@ -510,7 +524,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) h.logger.Debug("msg", "HTTP client disconnected", "component", "http_sink", "remote_addr", remoteAddr, - "username", session.Username, + "session_id", sess.ID, "client_id", clientID, "active_clients", disconnectCount) @@ -521,14 +535,16 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) // Shutting down, don't block } + // Remove session + h.sessionManager.RemoveSession(sess.ID) + h.wg.Done() }() // Send initial connected event with metadata connectionInfo := map[string]any{ "client_id": fmt.Sprintf("%d", clientID), - "username": session.Username, - "auth_method": session.Method, + "session_id": sess.ID, "stream_path": h.config.StreamPath, "status_path": h.config.StatusPath, "buffer_size": h.config.BufferSize, @@ -573,20 +589,15 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) return } - case <-tickerChan: - // Validate session is still active - if h.authenticator != nil && session != nil && !h.authenticator.ValidateSession(session.ID) { - fmt.Fprintf(w, "event: disconnect\ndata: {\"reason\":\"session_expired\"}\n\n") - w.Flush() - return - } + // Update session activity + h.sessionManager.UpdateActivity(sess.ID) - // Heartbeat is sent from broker, additional client-specific heartbeat is sent here - // This provides per-client heartbeat validation with session check + case <-tickerChan: + // Client-specific heartbeat sessionHB := map[string]any{ - "type": "session_heartbeat", - "client_id": fmt.Sprintf("%d", clientID), - "session_valid": true, + "type": "heartbeat", + "client_id": fmt.Sprintf("%d", clientID), + "session_id": sess.ID, } hbData, _ := json.Marshal(sessionHB) fmt.Fprintf(w, "event: heartbeat\ndata: %s\n\n", hbData) @@ -607,49 +618,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) ctx.SetBodyStreamWriter(streamFunc) } -func (h *HTTPSink) formatEntryForSSE(w *bufio.Writer, entry core.LogEntry) error { - formatted, err := h.formatter.Format(entry) - if err != nil { - return err - } - - // Remove trailing newline if present (SSE adds its own) - formatted = bytes.TrimSuffix(formatted, []byte{'\n'}) - - // Multi-line content handler - lines := bytes.Split(formatted, []byte{'\n'}) - for _, line := range lines { - // SSE needs "data: " prefix for each line based on W3C spec - fmt.Fprintf(w, "data: %s\n", line) - } - fmt.Fprintf(w, "\n") // Empty line to terminate event - - return nil -} - -func (h *HTTPSink) createHeartbeatEntry() core.LogEntry { - message := "heartbeat" - - // Build fields for heartbeat metadata - fields := make(map[string]any) - fields["type"] = "heartbeat" - - if h.config.Heartbeat.Enabled { - fields["active_clients"] = h.activeClients.Load() - fields["uptime_seconds"] = int(time.Since(h.startTime).Seconds()) - } - - fieldsJSON, _ := json.Marshal(fields) - - return core.LogEntry{ - Time: time.Now(), - Source: "logwisp-http", - Level: "INFO", - Message: message, - Fields: fieldsJSON, - } -} - +// handleStatus provides a JSON status report of the sink. func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) { ctx.SetContentType("application/json") @@ -662,17 +631,6 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) { } } - var authStats any - if h.authenticator != nil { - authStats = h.authenticator.GetStats() - authStats.(map[string]any)["failures"] = h.authFailures.Load() - authStats.(map[string]any)["successes"] = h.authSuccesses.Load() - } else { - authStats = map[string]any{ - "enabled": false, - } - } - var tlsStats any if h.tlsManager != nil { tlsStats = h.tlsManager.GetStats() @@ -682,6 +640,11 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) { } } + var sessionStats any + if h.sessionManager != nil { + sessionStats = h.sessionManager.GetStats() + } + status := map[string]any{ "service": "LogWisp", "version": version.Short(), @@ -703,13 +666,11 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) { "format": h.config.Heartbeat.Format, }, "tls": tlsStats, - "auth": authStats, + "sessions": sessionStats, "net_limit": netLimitStats, }, "statistics": map[string]any{ "total_processed": h.totalProcessed.Load(), - "auth_failures": h.authFailures.Load(), - "auth_successes": h.authSuccesses.Load(), }, } @@ -717,22 +678,71 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) { ctx.SetBody(data) } -// Returns the current number of active clients -func (h *HTTPSink) GetActiveConnections() int64 { - return h.activeClients.Load() +// handleSessionExpiry is the callback for cleaning up expired sessions. +func (h *HTTPSink) handleSessionExpiry(sessionID, remoteAddr string) { + h.sessionsMu.RLock() + defer h.sessionsMu.RUnlock() + + // Find client by session ID + for clientID, sessID := range h.clientSessions { + if sessID == sessionID { + h.logger.Info("msg", "Closing expired session client", + "component", "http_sink", + "session_id", sessionID, + "client_id", clientID, + "remote_addr", remoteAddr) + + // Signal broker to unregister + select { + case h.unregister <- clientID: + case <-h.done: + } + return + } + } } -// Returns the configured transport endpoint path -func (h *HTTPSink) GetStreamPath() string { - return h.config.StreamPath +// createHeartbeatEntry generates a new heartbeat log entry. +func (h *HTTPSink) createHeartbeatEntry() core.LogEntry { + message := "heartbeat" + + // Build fields for heartbeat metadata + fields := make(map[string]any) + fields["type"] = "heartbeat" + + if h.config.Heartbeat.Enabled { + fields["active_clients"] = h.activeClients.Load() + fields["uptime_seconds"] = int(time.Since(h.startTime).Seconds()) + } + + fieldsJSON, _ := json.Marshal(fields) + + return core.LogEntry{ + Time: time.Now(), + Source: "logwisp-http", + Level: "INFO", + Message: message, + Fields: fieldsJSON, + } } -// Returns the configured status endpoint path -func (h *HTTPSink) GetStatusPath() string { - return h.config.StatusPath -} +// formatEntryForSSE formats a log entry into the SSE 'data:' format. +func (h *HTTPSink) formatEntryForSSE(w *bufio.Writer, entry core.LogEntry) error { + formatted, err := h.formatter.Format(entry) + if err != nil { + return err + } -// Returns the configured host -func (h *HTTPSink) GetHost() string { - return h.config.Host + // Remove trailing newline if present (SSE adds its own) + formatted = bytes.TrimSuffix(formatted, []byte{'\n'}) + + // Multi-line content handler + lines := bytes.Split(formatted, []byte{'\n'}) + for _, line := range lines { + // SSE needs "data: " prefix for each line based on W3C spec + fmt.Fprintf(w, "data: %s\n", line) + } + fmt.Fprintf(w, "\n") // Empty line to terminate event + + return nil } \ No newline at end of file diff --git a/src/internal/sink/http_client.go b/src/internal/sink/http_client.go index 327b46c..a8f08d4 100644 --- a/src/internal/sink/http_client.go +++ b/src/internal/sink/http_client.go @@ -5,39 +5,39 @@ import ( "bytes" "context" "crypto/tls" - "crypto/x509" - "encoding/base64" "fmt" - "os" "strings" "sync" "sync/atomic" "time" - "logwisp/src/internal/auth" "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/format" + "logwisp/src/internal/session" + ltls "logwisp/src/internal/tls" "logwisp/src/internal/version" "github.com/lixenwraith/log" "github.com/valyala/fasthttp" ) -// TODO: implement heartbeat for HTTP Client Sink, similar to HTTP Sink -// Forwards log entries to a remote HTTP endpoint +// TODO: add heartbeat +// HTTPClientSink forwards log entries to a remote HTTP endpoint. type HTTPClientSink struct { - input chan core.LogEntry - config *config.HTTPClientSinkOptions - client *fasthttp.Client - batch []core.LogEntry - batchMu sync.Mutex - done chan struct{} - wg sync.WaitGroup - startTime time.Time - logger *log.Logger - formatter format.Formatter - authenticator *auth.Authenticator + input chan core.LogEntry + config *config.HTTPClientSinkOptions + client *fasthttp.Client + batch []core.LogEntry + batchMu sync.Mutex + done chan struct{} + wg sync.WaitGroup + startTime time.Time + logger *log.Logger + formatter format.Formatter + sessionID string + sessionManager *session.Manager + tlsManager *ltls.ClientManager // Statistics totalProcessed atomic.Uint64 @@ -48,20 +48,21 @@ type HTTPClientSink struct { activeConnections atomic.Int64 } -// Creates a new HTTP client sink +// NewHTTPClientSink creates a new HTTP client sink. func NewHTTPClientSink(opts *config.HTTPClientSinkOptions, logger *log.Logger, formatter format.Formatter) (*HTTPClientSink, error) { if opts == nil { return nil, fmt.Errorf("HTTP client sink options cannot be nil") } h := &HTTPClientSink{ - config: opts, - input: make(chan core.LogEntry, opts.BufferSize), - batch: make([]core.LogEntry, 0, opts.BatchSize), - done: make(chan struct{}), - startTime: time.Now(), - logger: logger, - formatter: formatter, + config: opts, + input: make(chan core.LogEntry, opts.BufferSize), + batch: make([]core.LogEntry, 0, opts.BatchSize), + done: make(chan struct{}), + startTime: time.Now(), + logger: logger, + formatter: formatter, + sessionManager: session.NewManager(30 * time.Minute), } h.lastProcessed.Store(time.Time{}) h.lastBatchSent.Store(time.Time{}) @@ -75,54 +76,48 @@ func NewHTTPClientSink(opts *config.HTTPClientSinkOptions, logger *log.Logger, f DisableHeaderNamesNormalizing: true, } - // Configure TLS if using HTTPS + // Configure TLS for HTTPS if strings.HasPrefix(opts.URL, "https://") { - tlsConfig := &tls.Config{ - InsecureSkipVerify: opts.InsecureSkipVerify, - } - - // Use TLS config if provided - if opts.TLS != nil { - // Load custom CA for server verification - if opts.TLS.CAFile != "" { - caCert, err := os.ReadFile(opts.TLS.CAFile) - if err != nil { - return nil, fmt.Errorf("failed to read CA file '%s': %w", opts.TLS.CAFile, err) - } - caCertPool := x509.NewCertPool() - if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, fmt.Errorf("failed to parse CA certificate from '%s'", opts.TLS.CAFile) - } - tlsConfig.RootCAs = caCertPool - logger.Debug("msg", "Custom CA loaded for server verification", - "component", "http_client_sink", - "ca_file", opts.TLS.CAFile) + if opts.TLS != nil && opts.TLS.Enabled { + // Use the new ClientManager with the clear client-specific config + tlsManager, err := ltls.NewClientManager(opts.TLS, logger) + if err != nil { + return nil, fmt.Errorf("failed to create TLS client manager: %w", err) } + h.tlsManager = tlsManager + // Get the generated config + h.client.TLSConfig = tlsManager.GetConfig() - // Load client certificate for mTLS if provided - if opts.TLS.CertFile != "" && opts.TLS.KeyFile != "" { - cert, err := tls.LoadX509KeyPair(opts.TLS.CertFile, opts.TLS.KeyFile) - if err != nil { - return nil, fmt.Errorf("failed to load client certificate: %w", err) - } - tlsConfig.Certificates = []tls.Certificate{cert} - logger.Info("msg", "Client certificate loaded for mTLS", - "component", "http_client_sink", - "cert_file", opts.TLS.CertFile) + logger.Info("msg", "Client TLS configured", + "component", "http_client_sink", + "has_client_cert", opts.TLS.ClientCertFile != "", // Clearer check + "has_server_ca", opts.TLS.ServerCAFile != "", // Clearer check + "min_version", opts.TLS.MinVersion) + } else if opts.InsecureSkipVerify { // Use the new clear field + // TODO: document this behavior + h.client.TLSConfig = &tls.Config{ + InsecureSkipVerify: true, } } - - h.client.TLSConfig = tlsConfig } return h, nil } +// Input returns the channel for sending log entries. func (h *HTTPClientSink) Input() chan<- core.LogEntry { return h.input } +// Start begins the processing and batching loops. func (h *HTTPClientSink) Start(ctx context.Context) error { + // Create session for HTTP client sink lifetime + sess := h.sessionManager.CreateSession(h.config.URL, "http_client_sink", map[string]any{ + "batch_size": h.config.BatchSize, + "timeout": h.config.Timeout, + }) + h.sessionID = sess.ID + h.wg.Add(2) go h.processLoop(ctx) go h.batchTimer(ctx) @@ -131,10 +126,12 @@ func (h *HTTPClientSink) Start(ctx context.Context) error { "component", "http_client_sink", "url", h.config.URL, "batch_size", h.config.BatchSize, - "batch_delay_ms", h.config.BatchDelayMS) + "batch_delay_ms", h.config.BatchDelayMS, + "session_id", h.sessionID) return nil } +// Stop gracefully shuts down the sink, sending any remaining batched entries. func (h *HTTPClientSink) Stop() { h.logger.Info("msg", "Stopping HTTP client sink") close(h.done) @@ -151,12 +148,21 @@ func (h *HTTPClientSink) Stop() { h.batchMu.Unlock() } + // Remove session and stop manager + if h.sessionID != "" { + h.sessionManager.RemoveSession(h.sessionID) + } + if h.sessionManager != nil { + h.sessionManager.Stop() + } + h.logger.Info("msg", "HTTP client sink stopped", "total_processed", h.totalProcessed.Load(), "total_batches", h.totalBatches.Load(), "failed_batches", h.failedBatches.Load()) } +// GetStats returns the sink's statistics. func (h *HTTPClientSink) GetStats() SinkStats { lastProc, _ := h.lastProcessed.Load().(time.Time) lastBatch, _ := h.lastBatchSent.Load().(time.Time) @@ -165,6 +171,23 @@ func (h *HTTPClientSink) GetStats() SinkStats { pendingEntries := len(h.batch) h.batchMu.Unlock() + // Get session information + var sessionInfo map[string]any + if h.sessionID != "" { + if sess, exists := h.sessionManager.GetSession(h.sessionID); exists { + sessionInfo = map[string]any{ + "session_id": sess.ID, + "created_at": sess.CreatedAt, + "last_activity": sess.LastActivity, + } + } + } + + var tlsStats map[string]any + if h.tlsManager != nil { + tlsStats = h.tlsManager.GetStats() + } + return SinkStats{ Type: "http_client", TotalProcessed: h.totalProcessed.Load(), @@ -178,10 +201,13 @@ func (h *HTTPClientSink) GetStats() SinkStats { "total_batches": h.totalBatches.Load(), "failed_batches": h.failedBatches.Load(), "last_batch_sent": lastBatch, + "session": sessionInfo, + "tls": tlsStats, }, } } +// processLoop collects incoming log entries into a batch. func (h *HTTPClientSink) processLoop(ctx context.Context) { defer h.wg.Done() @@ -219,6 +245,7 @@ func (h *HTTPClientSink) processLoop(ctx context.Context) { } } +// batchTimer periodically triggers sending of the current batch. func (h *HTTPClientSink) batchTimer(ctx context.Context) { defer h.wg.Done() @@ -248,6 +275,7 @@ func (h *HTTPClientSink) batchTimer(ctx context.Context) { } } +// sendBatch sends a batch of log entries to the remote endpoint with retry logic. func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { h.activeConnections.Add(1) defer h.activeConnections.Add(-1) @@ -293,7 +321,6 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { var lastErr error retryDelay := time.Duration(h.config.RetryDelayMS) * time.Millisecond - // TODO: verify retry loop placement is correct or should it be after acquiring resources (req :=....) for attempt := int64(0); attempt <= h.config.MaxRetries; attempt++ { if attempt > 0 { // Wait before retry @@ -323,24 +350,6 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { req.Header.Set("User-Agent", fmt.Sprintf("LogWisp/%s", version.Short())) - // Add authentication based on auth type - switch h.config.Auth.Type { - case "basic": - creds := h.config.Auth.Username + ":" + h.config.Auth.Password - encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds)) - req.Header.Set("Authorization", "Basic "+encodedCreds) - - case "token": - req.Header.Set("Authorization", "Token "+h.config.Auth.Token) - - case "mtls": - // mTLS auth is handled at TLS layer via client certificates - // No Authorization header needed - - case "none": - // No authentication - } - // Send request err := h.client.DoTimeout(req, resp, time.Duration(h.config.Timeout)*time.Second) @@ -370,6 +379,12 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) { // Check response status if statusCode >= 200 && statusCode < 300 { // Success + + // Update session activity on successful batch send + if h.sessionID != "" { + h.sessionManager.UpdateActivity(h.sessionID) + } + h.logger.Debug("msg", "Batch sent successfully", "component", "http_client_sink", "batch_size", len(batch), diff --git a/src/internal/sink/sink.go b/src/internal/sink/sink.go index 140b8b4..2592953 100644 --- a/src/internal/sink/sink.go +++ b/src/internal/sink/sink.go @@ -8,22 +8,22 @@ import ( "logwisp/src/internal/core" ) -// Represents an output data stream +// Sink represents an output data stream. type Sink interface { - // Returns the channel for sending log entries to this sink + // Input returns the channel for sending log entries to this sink. Input() chan<- core.LogEntry - // Begins processing log entries + // Start begins processing log entries. Start(ctx context.Context) error - // Gracefully shuts down the sink + // Stop gracefully shuts down the sink. Stop() - // Returns sink statistics + // GetStats returns sink statistics. GetStats() SinkStats } -// Contains statistics about a sink +// SinkStats contains statistics about a sink. type SinkStats struct { Type string TotalProcessed uint64 diff --git a/src/internal/sink/tcp.go b/src/internal/sink/tcp.go index 89bf04b..42988b3 100644 --- a/src/internal/sink/tcp.go +++ b/src/internal/sink/tcp.go @@ -11,31 +11,32 @@ import ( "sync/atomic" "time" - "logwisp/src/internal/auth" "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/format" "logwisp/src/internal/limit" + "logwisp/src/internal/session" "github.com/lixenwraith/log" "github.com/lixenwraith/log/compat" "github.com/panjf2000/gnet/v2" ) -// Streams log entries via TCP +// TCPSink streams log entries to connected TCP clients. type TCPSink struct { - input chan core.LogEntry - config *config.TCPSinkOptions - server *tcpServer - done chan struct{} - activeConns atomic.Int64 - startTime time.Time - engine *gnet.Engine - engineMu sync.Mutex - wg sync.WaitGroup - netLimiter *limit.NetLimiter - logger *log.Logger - formatter format.Formatter + input chan core.LogEntry + config *config.TCPSinkOptions + server *tcpServer + done chan struct{} + activeConns atomic.Int64 + startTime time.Time + engine *gnet.Engine + engineMu sync.Mutex + wg sync.WaitGroup + netLimiter *limit.NetLimiter + logger *log.Logger + formatter format.Formatter + sessionManager *session.Manager // Statistics totalProcessed atomic.Uint64 @@ -47,7 +48,7 @@ type TCPSink struct { errorMu sync.Mutex } -// Holds TCP sink configuration +// TCPConfig holds configuration for the TCPSink. type TCPConfig struct { Host string Port int64 @@ -56,19 +57,21 @@ type TCPConfig struct { NetLimit *config.NetLimitConfig } -// Creates a new TCP streaming sink +// NewTCPSink creates a new TCP streaming sink. func NewTCPSink(opts *config.TCPSinkOptions, logger *log.Logger, formatter format.Formatter) (*TCPSink, error) { if opts == nil { return nil, fmt.Errorf("TCP sink options cannot be nil") } t := &TCPSink{ - config: opts, // Direct reference to config - input: make(chan core.LogEntry, opts.BufferSize), - done: make(chan struct{}), - startTime: time.Now(), - logger: logger, - formatter: formatter, + config: opts, + input: make(chan core.LogEntry, opts.BufferSize), + done: make(chan struct{}), + startTime: time.Now(), + logger: logger, + formatter: formatter, + consecutiveWriteErrors: make(map[gnet.Conn]int), + sessionManager: session.NewManager(30 * time.Minute), } t.lastProcessed.Store(time.Time{}) @@ -82,16 +85,23 @@ func NewTCPSink(opts *config.TCPSinkOptions, logger *log.Logger, formatter forma return t, nil } +// Input returns the channel for sending log entries. func (t *TCPSink) Input() chan<- core.LogEntry { return t.input } +// Start initializes the TCP server and begins the broadcast loop. func (t *TCPSink) Start(ctx context.Context) error { t.server = &tcpServer{ sink: t, clients: make(map[gnet.Conn]*tcpClient), } + // Register expiry callback + t.sessionManager.RegisterExpiryCallback("tcp_sink", func(sessionID, remoteAddr string) { + t.handleSessionExpiry(sessionID, remoteAddr) + }) + // Start log broadcast loop t.wg.Add(1) go func() { @@ -155,8 +165,13 @@ func (t *TCPSink) Start(ctx context.Context) error { } } +// Stop gracefully shuts down the TCP server. func (t *TCPSink) Stop() { t.logger.Info("msg", "Stopping TCP sink") + + // Unregister callback + t.sessionManager.UnregisterExpiryCallback("tcp_sink") + // Signal broadcast loop to stop close(t.done) @@ -174,9 +189,15 @@ func (t *TCPSink) Stop() { // Wait for broadcast loop to finish t.wg.Wait() + // Stop session manager + if t.sessionManager != nil { + t.sessionManager.Stop() + } + t.logger.Info("msg", "TCP sink stopped") } +// GetStats returns the sink's statistics. func (t *TCPSink) GetStats() SinkStats { lastProc, _ := t.lastProcessed.Load().(time.Time) @@ -185,6 +206,11 @@ func (t *TCPSink) GetStats() SinkStats { netLimitStats = t.netLimiter.GetStats() } + var sessionStats map[string]any + if t.sessionManager != nil { + sessionStats = t.sessionManager.GetStats() + } + return SinkStats{ Type: "tcp", TotalProcessed: t.totalProcessed.Load(), @@ -195,11 +221,32 @@ func (t *TCPSink) GetStats() SinkStats { "port": t.config.Port, "buffer_size": t.config.BufferSize, "net_limit": netLimitStats, - "auth": map[string]any{"enabled": false}, + "sessions": sessionStats, }, } } +// GetActiveConnections returns the current number of active connections. +func (t *TCPSink) GetActiveConnections() int64 { + return t.activeConns.Load() +} + +// tcpServer implements the gnet.EventHandler interface for the TCP sink. +type tcpServer struct { + gnet.BuiltinEventEngine + sink *TCPSink + clients map[gnet.Conn]*tcpClient + mu sync.RWMutex +} + +// tcpClient represents a connected TCP client. +type tcpClient struct { + conn gnet.Conn + buffer bytes.Buffer + sessionID string +} + +// broadcastLoop manages the central broadcasting of log entries to all clients. func (t *TCPSink) broadcastLoop(ctx context.Context) { var ticker *time.Ticker var tickerChan <-chan time.Time @@ -248,101 +295,7 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) { } } -func (t *TCPSink) broadcastData(data []byte) { - t.server.mu.RLock() - defer t.server.mu.RUnlock() - - for conn, _ := range t.server.clients { - conn.AsyncWrite(data, func(c gnet.Conn, err error) error { - if err != nil { - t.writeErrors.Add(1) - t.handleWriteError(c, err) - } else { - // Reset consecutive error count on success - t.errorMu.Lock() - delete(t.consecutiveWriteErrors, c) - t.errorMu.Unlock() - } - return nil - }) - } -} - -// Handle write errors with threshold-based connection termination -func (t *TCPSink) handleWriteError(c gnet.Conn, err error) { - t.errorMu.Lock() - defer t.errorMu.Unlock() - - // Track consecutive errors per connection - if t.consecutiveWriteErrors == nil { - t.consecutiveWriteErrors = make(map[gnet.Conn]int) - } - - t.consecutiveWriteErrors[c]++ - errorCount := t.consecutiveWriteErrors[c] - - t.logger.Debug("msg", "AsyncWrite error", - "component", "tcp_sink", - "remote_addr", c.RemoteAddr(), - "error", err, - "consecutive_errors", errorCount) - - // Close connection after 3 consecutive write errors - if errorCount >= 3 { - t.logger.Warn("msg", "Closing connection due to repeated write errors", - "component", "tcp_sink", - "remote_addr", c.RemoteAddr(), - "error_count", errorCount) - delete(t.consecutiveWriteErrors, c) - c.Close() - } -} - -// Create heartbeat as a proper LogEntry -func (t *TCPSink) createHeartbeatEntry() core.LogEntry { - message := "heartbeat" - - // Build fields for heartbeat metadata - fields := make(map[string]any) - fields["type"] = "heartbeat" - - if t.config.Heartbeat.IncludeStats { - fields["active_connections"] = t.activeConns.Load() - fields["uptime_seconds"] = int64(time.Since(t.startTime).Seconds()) - } - - fieldsJSON, _ := json.Marshal(fields) - - return core.LogEntry{ - Time: time.Now(), - Source: "logwisp-tcp", - Level: "INFO", - Message: message, - Fields: fieldsJSON, - } -} - -// Returns the current number of connections -func (t *TCPSink) GetActiveConnections() int64 { - return t.activeConns.Load() -} - -// Represents a connected TCP client with auth state -type tcpClient struct { - conn gnet.Conn - buffer bytes.Buffer - authTimeout time.Time - session *auth.Session -} - -// Handles gnet events with authentication -type tcpServer struct { - gnet.BuiltinEventEngine - sink *TCPSink - clients map[gnet.Conn]*tcpClient - mu sync.RWMutex -} - +// OnBoot is called when the server starts. func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action { // Store engine reference for shutdown s.sink.engineMu.Lock() @@ -355,6 +308,7 @@ func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action { return gnet.None } +// OnOpen is called when a new connection is established. func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { remoteAddr := c.RemoteAddr() s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr) @@ -387,10 +341,14 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { s.sink.netLimiter.AddConnection(remoteStr) } + // Create session for tracking + sess := s.sink.sessionManager.CreateSession(c.RemoteAddr().String(), "tcp_sink", nil) + // TCP Sink accepts all connections without authentication client := &tcpClient{ - conn: c, - buffer: bytes.Buffer{}, + conn: c, + buffer: bytes.Buffer{}, + sessionID: sess.ID, } s.mu.Lock() @@ -400,14 +358,30 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { newCount := s.sink.activeConns.Add(1) s.sink.logger.Debug("msg", "TCP connection opened", "remote_addr", remoteAddr, + "session_id", sess.ID, "active_connections", newCount) return nil, gnet.None } +// OnClose is called when a connection is closed. func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action { remoteAddr := c.RemoteAddr().String() + // Get client to retrieve session ID + s.mu.RLock() + client, exists := s.clients[c] + s.mu.RUnlock() + + if exists && client.sessionID != "" { + // Remove session + s.sink.sessionManager.RemoveSession(client.sessionID) + s.sink.logger.Debug("msg", "Session removed", + "component", "tcp_sink", + "session_id", client.sessionID, + "remote_addr", remoteAddr) + } + // Remove client state s.mu.Lock() delete(s.clients, c) @@ -431,8 +405,141 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action { return gnet.None } +// OnTraffic is called when data is received from a connection. func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action { + s.mu.RLock() + client, exists := s.clients[c] + s.mu.RUnlock() + + // Update session activity when client sends data + if exists && client.sessionID != "" { + s.sink.sessionManager.UpdateActivity(client.sessionID) + } + // TCP Sink doesn't expect any data from clients, discard all c.Discard(-1) return gnet.None +} + +// handleSessionExpiry is the callback for cleaning up expired sessions. +func (t *TCPSink) handleSessionExpiry(sessionID, remoteAddr string) { + t.server.mu.RLock() + defer t.server.mu.RUnlock() + + // Find connection by session ID + for conn, client := range t.server.clients { + if client.sessionID == sessionID { + t.logger.Info("msg", "Closing expired session connection", + "component", "tcp_sink", + "session_id", sessionID, + "remote_addr", remoteAddr) + + // Close connection + conn.Close() + return + } + } +} + +// broadcastData sends a formatted byte slice to all connected clients. +func (t *TCPSink) broadcastData(data []byte) { + t.server.mu.RLock() + defer t.server.mu.RUnlock() + + // Track clients to remove after iteration + var staleClients []gnet.Conn + + for conn, client := range t.server.clients { + // Update session activity before sending data + if client.sessionID != "" { + if !t.sessionManager.IsSessionActive(client.sessionID) { + // Session expired, mark for cleanup + staleClients = append(staleClients, conn) + continue + } + t.sessionManager.UpdateActivity(client.sessionID) + } + + conn.AsyncWrite(data, func(c gnet.Conn, err error) error { + if err != nil { + t.writeErrors.Add(1) + t.handleWriteError(c, err) + } else { + // Reset consecutive error count on success + t.errorMu.Lock() + delete(t.consecutiveWriteErrors, c) + t.errorMu.Unlock() + } + return nil + }) + } + + // Clean up stale connections outside the read lock + if len(staleClients) > 0 { + go t.cleanupStaleConnections(staleClients) + } +} + +// handleWriteError manages errors during async writes, closing faulty connections. +func (t *TCPSink) handleWriteError(c gnet.Conn, err error) { + t.errorMu.Lock() + defer t.errorMu.Unlock() + + // Track consecutive errors per connection + if t.consecutiveWriteErrors == nil { + t.consecutiveWriteErrors = make(map[gnet.Conn]int) + } + + t.consecutiveWriteErrors[c]++ + errorCount := t.consecutiveWriteErrors[c] + + t.logger.Debug("msg", "AsyncWrite error", + "component", "tcp_sink", + "remote_addr", c.RemoteAddr(), + "error", err, + "consecutive_errors", errorCount) + + // Close connection after 3 consecutive write errors + if errorCount >= 3 { + t.logger.Warn("msg", "Closing connection due to repeated write errors", + "component", "tcp_sink", + "remote_addr", c.RemoteAddr(), + "error_count", errorCount) + delete(t.consecutiveWriteErrors, c) + c.Close() + } +} + +// createHeartbeatEntry generates a new heartbeat log entry. +func (t *TCPSink) createHeartbeatEntry() core.LogEntry { + message := "heartbeat" + + // Build fields for heartbeat metadata + fields := make(map[string]any) + fields["type"] = "heartbeat" + + if t.config.Heartbeat.IncludeStats { + fields["active_connections"] = t.activeConns.Load() + fields["uptime_seconds"] = int64(time.Since(t.startTime).Seconds()) + } + + fieldsJSON, _ := json.Marshal(fields) + + return core.LogEntry{ + Time: time.Now(), + Source: "logwisp-tcp", + Level: "INFO", + Message: message, + Fields: fieldsJSON, + } +} + +// cleanupStaleConnections closes connections associated with expired sessions. +func (t *TCPSink) cleanupStaleConnections(staleConns []gnet.Conn) { + for _, conn := range staleConns { + t.logger.Info("msg", "Closing stale connection", + "component", "tcp_sink", + "remote_addr", conn.RemoteAddr()) + conn.Close() + } } \ No newline at end of file diff --git a/src/internal/sink/tcp_client.go b/src/internal/sink/tcp_client.go index 467b94a..2aebe02 100644 --- a/src/internal/sink/tcp_client.go +++ b/src/internal/sink/tcp_client.go @@ -2,28 +2,25 @@ package sink import ( - "bufio" "context" - "encoding/json" "errors" "fmt" "net" "strconv" - "strings" "sync" "sync/atomic" "time" - "logwisp/src/internal/auth" "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/format" + "logwisp/src/internal/session" "github.com/lixenwraith/log" ) -// TODO: implement heartbeat for TCP Client Sink, similar to TCP Sink -// Forwards log entries to a remote TCP endpoint +// TODO: add heartbeat +// TCPClientSink forwards log entries to a remote TCP endpoint. type TCPClientSink struct { input chan core.LogEntry config *config.TCPClientSinkOptions @@ -36,7 +33,9 @@ type TCPClientSink struct { logger *log.Logger formatter format.Formatter - // Reconnection state + // Connection + sessionID string + sessionManager *session.Manager reconnecting atomic.Bool lastConnectErr error connectTime time.Time @@ -49,7 +48,7 @@ type TCPClientSink struct { connectionUptime atomic.Value // time.Duration } -// Creates a new TCP client sink +// NewTCPClientSink creates a new TCP client sink. func NewTCPClientSink(opts *config.TCPClientSinkOptions, logger *log.Logger, formatter format.Formatter) (*TCPClientSink, error) { // Validation and defaults are handled in config package if opts == nil { @@ -57,13 +56,14 @@ func NewTCPClientSink(opts *config.TCPClientSinkOptions, logger *log.Logger, for } t := &TCPClientSink{ - config: opts, - address: opts.Host + ":" + strconv.Itoa(int(opts.Port)), - input: make(chan core.LogEntry, opts.BufferSize), - done: make(chan struct{}), - startTime: time.Now(), - logger: logger, - formatter: formatter, + config: opts, + address: opts.Host + ":" + strconv.Itoa(int(opts.Port)), + input: make(chan core.LogEntry, opts.BufferSize), + done: make(chan struct{}), + startTime: time.Now(), + logger: logger, + formatter: formatter, + sessionManager: session.NewManager(30 * time.Minute), } t.lastProcessed.Store(time.Time{}) t.connectionUptime.Store(time.Duration(0)) @@ -71,10 +71,12 @@ func NewTCPClientSink(opts *config.TCPClientSinkOptions, logger *log.Logger, for return t, nil } +// Input returns the channel for sending log entries. func (t *TCPClientSink) Input() chan<- core.LogEntry { return t.input } +// Start begins the connection and processing loops. func (t *TCPClientSink) Start(ctx context.Context) error { // Start connection manager t.wg.Add(1) @@ -91,6 +93,7 @@ func (t *TCPClientSink) Start(ctx context.Context) error { return nil } +// Stop gracefully shuts down the sink and its connection. func (t *TCPClientSink) Stop() { t.logger.Info("msg", "Stopping TCP client sink") close(t.done) @@ -103,12 +106,21 @@ func (t *TCPClientSink) Stop() { } t.connMu.Unlock() + // Remove session and stop manager + if t.sessionID != "" { + t.sessionManager.RemoveSession(t.sessionID) + } + if t.sessionManager != nil { + t.sessionManager.Stop() + } + t.logger.Info("msg", "TCP client sink stopped", "total_processed", t.totalProcessed.Load(), "total_failed", t.totalFailed.Load(), "total_reconnects", t.totalReconnects.Load()) } +// GetStats returns the sink's statistics. func (t *TCPClientSink) GetStats() SinkStats { lastProc, _ := t.lastProcessed.Load().(time.Time) uptime, _ := t.connectionUptime.Load().(time.Duration) @@ -122,6 +134,19 @@ func (t *TCPClientSink) GetStats() SinkStats { activeConns = 1 } + // Get session stats + var sessionInfo map[string]any + if t.sessionID != "" { + if sess, exists := t.sessionManager.GetSession(t.sessionID); exists { + sessionInfo = map[string]any{ + "session_id": sess.ID, + "created_at": sess.CreatedAt, + "last_activity": sess.LastActivity, + "remote_addr": sess.RemoteAddr, + } + } + } + return SinkStats{ Type: "tcp_client", TotalProcessed: t.totalProcessed.Load(), @@ -136,10 +161,12 @@ func (t *TCPClientSink) GetStats() SinkStats { "total_reconnects": t.totalReconnects.Load(), "connection_uptime": uptime.Seconds(), "last_error": fmt.Sprintf("%v", t.lastConnectErr), + "session": sessionInfo, }, } } +// connectionManager handles the lifecycle of the TCP connection, including reconnections. func (t *TCPClientSink) connectionManager(ctx context.Context) { defer t.wg.Done() @@ -154,6 +181,11 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) { default: } + if t.sessionID != "" { + t.sessionManager.RemoveSession(t.sessionID) + t.sessionID = "" + } + // Attempt to connect t.reconnecting.Store(true) conn, err := t.connect() @@ -190,6 +222,13 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) { t.connectTime = time.Now() t.totalReconnects.Add(1) + // Create session for the connection + sess := t.sessionManager.CreateSession(t.address, "tcp_client_sink", map[string]any{ + "local_addr": conn.LocalAddr().String(), + "sink_type": "tcp_client", + }) + t.sessionID = sess.ID + t.connMu.Lock() t.conn = conn t.connMu.Unlock() @@ -197,7 +236,8 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) { t.logger.Info("msg", "Connected to TCP server", "component", "tcp_client_sink", "address", t.address, - "local_addr", conn.LocalAddr()) + "local_addr", conn.LocalAddr(), + "session_id", t.sessionID) // Monitor connection t.monitorConnection(conn) @@ -214,10 +254,57 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) { t.logger.Warn("msg", "Lost connection to TCP server", "component", "tcp_client_sink", "address", t.address, - "uptime", uptime) + "uptime", uptime, + "session_id", t.sessionID) } } +// processLoop reads entries from the input channel and sends them. +func (t *TCPClientSink) processLoop(ctx context.Context) { + defer t.wg.Done() + + for { + select { + case entry, ok := <-t.input: + if !ok { + return + } + + t.totalProcessed.Add(1) + t.lastProcessed.Store(time.Now()) + + // Send entry + if err := t.sendEntry(entry); err != nil { + t.totalFailed.Add(1) + t.logger.Debug("msg", "Failed to send log entry", + "component", "tcp_client_sink", + "error", err) + } else { + // Update session activity on successful send + if t.sessionID != "" { + t.sessionManager.UpdateActivity(t.sessionID) + } else { + // Close invalid connection without session + t.logger.Warn("msg", "Connection without session detected, forcing reconnection", + "component", "tcp_client_sink") + t.connMu.Lock() + if t.conn != nil { + _ = t.conn.Close() + t.conn = nil + } + t.connMu.Unlock() + } + } + + case <-ctx.Done(): + return + case <-t.done: + return + } + } +} + +// connect attempts to establish a connection to the remote server. func (t *TCPClientSink) connect() (net.Conn, error) { dialer := &net.Dialer{ Timeout: time.Duration(t.config.DialTimeout) * time.Second, @@ -235,129 +322,10 @@ func (t *TCPClientSink) connect() (net.Conn, error) { tcpConn.SetKeepAlivePeriod(time.Duration(t.config.KeepAlive) * time.Second) } - // SCRAM authentication if credentials configured - if t.config.Auth != nil && t.config.Auth.Type == "scram" { - if err := t.performSCRAMAuth(conn); err != nil { - conn.Close() - return nil, fmt.Errorf("SCRAM authentication failed: %w", err) - } - t.logger.Debug("msg", "SCRAM authentication completed", - "component", "tcp_client_sink", - "address", t.address) - } - return conn, nil } -func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error { - reader := bufio.NewReader(conn) - - // Create SCRAM client - scramClient := auth.NewScramClient(t.config.Auth.Username, t.config.Auth.Password) - - // Wait for AUTH_REQUIRED from server - authPrompt, err := reader.ReadString('\n') - if err != nil { - return fmt.Errorf("failed to read auth prompt: %w", err) - } - - if strings.TrimSpace(authPrompt) != "AUTH_REQUIRED" { - return fmt.Errorf("unexpected server greeting: %s", authPrompt) - } - - // Step 1: Send ClientFirst - clientFirst, err := scramClient.StartAuthentication() - if err != nil { - return fmt.Errorf("failed to start SCRAM: %w", err) - } - - msg, err := auth.FormatSCRAMRequest("SCRAM-FIRST", clientFirst) - if err != nil { - return err - } - - if _, err := conn.Write([]byte(msg)); err != nil { - return fmt.Errorf("failed to send SCRAM-FIRST: %w", err) - } - - // Step 2: Receive ServerFirst challenge - response, err := reader.ReadString('\n') - if err != nil { - return fmt.Errorf("failed to read SCRAM challenge: %w", err) - } - - command, data, err := auth.ParseSCRAMResponse(response) - if err != nil { - return err - } - - if command != "SCRAM-CHALLENGE" { - return fmt.Errorf("unexpected server response: %s", command) - } - - var serverFirst auth.ServerFirst - if err := json.Unmarshal([]byte(data), &serverFirst); err != nil { - return fmt.Errorf("failed to parse server challenge: %w", err) - } - - // Step 3: Process challenge and send proof - clientFinal, err := scramClient.ProcessServerFirst(&serverFirst) - if err != nil { - return fmt.Errorf("failed to process challenge: %w", err) - } - - msg, err = auth.FormatSCRAMRequest("SCRAM-PROOF", clientFinal) - if err != nil { - return err - } - - if _, err := conn.Write([]byte(msg)); err != nil { - return fmt.Errorf("failed to send SCRAM-PROOF: %w", err) - } - - // Step 4: Receive ServerFinal - response, err = reader.ReadString('\n') - if err != nil { - return fmt.Errorf("failed to read SCRAM result: %w", err) - } - - command, data, err = auth.ParseSCRAMResponse(response) - if err != nil { - return err - } - - switch command { - case "SCRAM-OK": - var serverFinal auth.ServerFinal - if err := json.Unmarshal([]byte(data), &serverFinal); err != nil { - return fmt.Errorf("failed to parse server signature: %w", err) - } - - // Verify server signature - if err := scramClient.VerifyServerFinal(&serverFinal); err != nil { - return fmt.Errorf("server signature verification failed: %w", err) - } - - t.logger.Info("msg", "SCRAM authentication successful", - "component", "tcp_client_sink", - "address", t.address, - "username", t.config.Auth.Username, - "session_id", serverFinal.SessionID) - - return nil - - case "SCRAM-FAIL": - reason := data - if reason == "" { - reason = "unknown" - } - return fmt.Errorf("authentication failed: %s", reason) - - default: - return fmt.Errorf("unexpected response: %s", command) - } -} - +// monitorConnection checks the health of the connection. func (t *TCPClientSink) monitorConnection(conn net.Conn) { // Simple connection monitoring by periodic zero-byte reads ticker := time.NewTicker(5 * time.Second) @@ -390,35 +358,7 @@ func (t *TCPClientSink) monitorConnection(conn net.Conn) { } } -func (t *TCPClientSink) processLoop(ctx context.Context) { - defer t.wg.Done() - - for { - select { - case entry, ok := <-t.input: - if !ok { - return - } - - t.totalProcessed.Add(1) - t.lastProcessed.Store(time.Now()) - - // Send entry - if err := t.sendEntry(entry); err != nil { - t.totalFailed.Add(1) - t.logger.Debug("msg", "Failed to send log entry", - "component", "tcp_client_sink", - "error", err) - } - - case <-ctx.Done(): - return - case <-t.done: - return - } - } -} - +// sendEntry formats and sends a single log entry over the connection. func (t *TCPClientSink) sendEntry(entry core.LogEntry) error { // Get current connection t.connMu.RLock() diff --git a/src/internal/source/directory.go b/src/internal/source/directory.go index 7d7764e..211b77b 100644 --- a/src/internal/source/directory.go +++ b/src/internal/source/directory.go @@ -19,7 +19,7 @@ import ( "github.com/lixenwraith/log" ) -// Monitors a directory for log files +// DirectorySource monitors a directory for log files and tails them. type DirectorySource struct { config *config.DirectorySourceOptions subscribers []chan core.LogEntry @@ -35,7 +35,7 @@ type DirectorySource struct { logger *log.Logger } -// Creates a new directory monitoring source +// NewDirectorySource creates a new directory monitoring source. func NewDirectorySource(opts *config.DirectorySourceOptions, logger *log.Logger) (*DirectorySource, error) { if opts == nil { return nil, fmt.Errorf("directory source options cannot be nil") @@ -52,6 +52,7 @@ func NewDirectorySource(opts *config.DirectorySourceOptions, logger *log.Logger) return ds, nil } +// Subscribe returns a channel for receiving log entries. func (ds *DirectorySource) Subscribe() <-chan core.LogEntry { ds.mu.Lock() defer ds.mu.Unlock() @@ -61,6 +62,7 @@ func (ds *DirectorySource) Subscribe() <-chan core.LogEntry { return ch } +// Start begins the directory monitoring loop. func (ds *DirectorySource) Start() error { ds.ctx, ds.cancel = context.WithCancel(context.Background()) ds.wg.Add(1) @@ -74,6 +76,7 @@ func (ds *DirectorySource) Start() error { return nil } +// Stop gracefully shuts down the directory source and all file watchers. func (ds *DirectorySource) Stop() { if ds.cancel != nil { ds.cancel() @@ -82,7 +85,7 @@ func (ds *DirectorySource) Stop() { ds.mu.Lock() for _, w := range ds.watchers { - w.close() + w.stop() } for _, ch := range ds.subscribers { close(ch) @@ -94,6 +97,7 @@ func (ds *DirectorySource) Stop() { "path", ds.config.Path) } +// GetStats returns the source's statistics, including active watchers. func (ds *DirectorySource) GetStats() SourceStats { lastEntry, _ := ds.lastEntryTime.Load().(time.Time) @@ -128,24 +132,7 @@ func (ds *DirectorySource) GetStats() SourceStats { } } -func (ds *DirectorySource) publish(entry core.LogEntry) { - ds.mu.RLock() - defer ds.mu.RUnlock() - - ds.totalEntries.Add(1) - ds.lastEntryTime.Store(entry.Time) - - for _, ch := range ds.subscribers { - select { - case ch <- entry: - default: - ds.droppedEntries.Add(1) - ds.logger.Debug("msg", "Dropped log entry - subscriber buffer full", - "component", "directory_source") - } - } -} - +// monitorLoop periodically scans the directory for new or changed files. func (ds *DirectorySource) monitorLoop() { defer ds.wg.Done() @@ -164,6 +151,7 @@ func (ds *DirectorySource) monitorLoop() { } } +// checkTargets finds matching files and ensures watchers are running for them. func (ds *DirectorySource) checkTargets() { files, err := ds.scanDirectory() if err != nil { @@ -182,34 +170,7 @@ func (ds *DirectorySource) checkTargets() { ds.cleanupWatchers() } -func (ds *DirectorySource) scanDirectory() ([]string, error) { - entries, err := os.ReadDir(ds.config.Path) - if err != nil { - return nil, err - } - - // Convert glob pattern to regex - regexPattern := globToRegex(ds.config.Pattern) - re, err := regexp.Compile(regexPattern) - if err != nil { - return nil, fmt.Errorf("invalid pattern regex: %w", err) - } - - var files []string - for _, entry := range entries { - if entry.IsDir() { - continue - } - - name := entry.Name() - if re.MatchString(name) { - files = append(files, filepath.Join(ds.config.Path, name)) - } - } - - return files, nil -} - +// ensureWatcher creates and starts a new file watcher if one doesn't exist for the given path. func (ds *DirectorySource) ensureWatcher(path string) { ds.mu.Lock() defer ds.mu.Unlock() @@ -247,6 +208,7 @@ func (ds *DirectorySource) ensureWatcher(path string) { }() } +// cleanupWatchers stops and removes watchers for files that no longer exist. func (ds *DirectorySource) cleanupWatchers() { ds.mu.Lock() defer ds.mu.Unlock() @@ -262,6 +224,55 @@ func (ds *DirectorySource) cleanupWatchers() { } } +// publish sends a log entry to all subscribers. +func (ds *DirectorySource) publish(entry core.LogEntry) { + ds.mu.RLock() + defer ds.mu.RUnlock() + + ds.totalEntries.Add(1) + ds.lastEntryTime.Store(entry.Time) + + for _, ch := range ds.subscribers { + select { + case ch <- entry: + default: + ds.droppedEntries.Add(1) + ds.logger.Debug("msg", "Dropped log entry - subscriber buffer full", + "component", "directory_source") + } + } +} + +// scanDirectory finds all files in the configured path that match the pattern. +func (ds *DirectorySource) scanDirectory() ([]string, error) { + entries, err := os.ReadDir(ds.config.Path) + if err != nil { + return nil, err + } + + // Convert glob pattern to regex + regexPattern := globToRegex(ds.config.Pattern) + re, err := regexp.Compile(regexPattern) + if err != nil { + return nil, fmt.Errorf("invalid pattern regex: %w", err) + } + + var files []string + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + if re.MatchString(name) { + files = append(files, filepath.Join(ds.config.Path, name)) + } + } + + return files, nil +} + +// globToRegex converts a simple glob pattern to a regular expression. func globToRegex(glob string) string { regex := regexp.QuoteMeta(glob) regex = strings.ReplaceAll(regex, `\*`, `.*`) diff --git a/src/internal/source/file_watcher.go b/src/internal/source/file_watcher.go index 3d9ed35..efede44 100644 --- a/src/internal/source/file_watcher.go +++ b/src/internal/source/file_watcher.go @@ -20,7 +20,7 @@ import ( "github.com/lixenwraith/log" ) -// Contains information about a file watcher +// WatcherInfo contains snapshot information about a file watcher's state. type WatcherInfo struct { Path string Size int64 @@ -31,6 +31,7 @@ type WatcherInfo struct { Rotations int64 } +// fileWatcher tails a single file, handles rotations, and sends new lines to a callback. type fileWatcher struct { path string callback func(core.LogEntry) @@ -46,6 +47,7 @@ type fileWatcher struct { logger *log.Logger } +// newFileWatcher creates a new watcher for a specific file path. func newFileWatcher(path string, callback func(core.LogEntry), logger *log.Logger) *fileWatcher { w := &fileWatcher{ path: path, @@ -57,6 +59,7 @@ func newFileWatcher(path string, callback func(core.LogEntry), logger *log.Logge return w } +// watch starts the main monitoring loop for the file. func (w *fileWatcher) watch(ctx context.Context) error { if err := w.seekToEnd(); err != nil { return fmt.Errorf("seekToEnd failed: %w", err) @@ -81,49 +84,34 @@ func (w *fileWatcher) watch(ctx context.Context) error { } } -func (w *fileWatcher) seekToEnd() error { - file, err := os.Open(w.path) - if err != nil { - if os.IsNotExist(err) { - w.mu.Lock() - w.position = 0 - w.size = 0 - w.modTime = time.Now() - w.inode = 0 - w.mu.Unlock() - return nil - } - return err - } - defer file.Close() - - info, err := file.Stat() - if err != nil { - return err - } - +// stop signals the watcher to terminate its loop. +func (w *fileWatcher) stop() { w.mu.Lock() - defer w.mu.Unlock() - - // Keep existing position (including 0) - // First time initialization seeks to the end of the file - if w.position == -1 { - pos, err := file.Seek(0, io.SeekEnd) - if err != nil { - return err - } - w.position = pos - } - - w.size = info.Size() - w.modTime = info.ModTime() - if stat, ok := info.Sys().(*syscall.Stat_t); ok { - w.inode = stat.Ino - } - - return nil + w.stopped = true + w.mu.Unlock() } +// getInfo returns a snapshot of the watcher's current statistics. +func (w *fileWatcher) getInfo() WatcherInfo { + w.mu.Lock() + info := WatcherInfo{ + Path: w.path, + Size: w.size, + Position: w.position, + ModTime: w.modTime, + EntriesRead: w.entriesRead.Load(), + Rotations: w.rotationSeq, + } + w.mu.Unlock() + + if lastRead, ok := w.lastReadTime.Load().(time.Time); ok { + info.LastReadTime = lastRead + } + + return info +} + +// checkFile examines the file for changes, rotations, or new content. func (w *fileWatcher) checkFile() error { file, err := os.Open(w.path) if err != nil { @@ -310,6 +298,58 @@ func (w *fileWatcher) checkFile() error { return nil } +// seekToEnd sets the initial read position to the end of the file. +func (w *fileWatcher) seekToEnd() error { + file, err := os.Open(w.path) + if err != nil { + if os.IsNotExist(err) { + w.mu.Lock() + w.position = 0 + w.size = 0 + w.modTime = time.Now() + w.inode = 0 + w.mu.Unlock() + return nil + } + return err + } + defer file.Close() + + info, err := file.Stat() + if err != nil { + return err + } + + w.mu.Lock() + defer w.mu.Unlock() + + // Keep existing position (including 0) + // First time initialization seeks to the end of the file + if w.position == -1 { + pos, err := file.Seek(0, io.SeekEnd) + if err != nil { + return err + } + w.position = pos + } + + w.size = info.Size() + w.modTime = info.ModTime() + if stat, ok := info.Sys().(*syscall.Stat_t); ok { + w.inode = stat.Ino + } + + return nil +} + +// isStopped checks if the watcher has been instructed to stop. +func (w *fileWatcher) isStopped() bool { + w.mu.Lock() + defer w.mu.Unlock() + return w.stopped +} + +// parseLine attempts to parse a line as JSON, falling back to plain text. func (w *fileWatcher) parseLine(line string) core.LogEntry { var jsonLog struct { Time string `json:"time"` @@ -343,6 +383,7 @@ func (w *fileWatcher) parseLine(line string) core.LogEntry { } } +// extractLogLevel heuristically determines the log level from a line of text. func extractLogLevel(line string) string { patterns := []struct { patterns []string @@ -365,39 +406,4 @@ func extractLogLevel(line string) string { } return "" -} - -func (w *fileWatcher) getInfo() WatcherInfo { - w.mu.Lock() - info := WatcherInfo{ - Path: w.path, - Size: w.size, - Position: w.position, - ModTime: w.modTime, - EntriesRead: w.entriesRead.Load(), - Rotations: w.rotationSeq, - } - w.mu.Unlock() - - if lastRead, ok := w.lastReadTime.Load().(time.Time); ok { - info.LastReadTime = lastRead - } - - return info -} - -func (w *fileWatcher) close() { - w.stop() -} - -func (w *fileWatcher) stop() { - w.mu.Lock() - w.stopped = true - w.mu.Unlock() -} - -func (w *fileWatcher) isStopped() bool { - w.mu.Lock() - defer w.mu.Unlock() - return w.stopped } \ No newline at end of file diff --git a/src/internal/source/http.go b/src/internal/source/http.go index 076d064..932ced7 100644 --- a/src/internal/source/http.go +++ b/src/internal/source/http.go @@ -2,6 +2,7 @@ package source import ( + "crypto/tls" "encoding/json" "fmt" "net" @@ -9,17 +10,17 @@ import ( "sync/atomic" "time" - "logwisp/src/internal/auth" "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/limit" - "logwisp/src/internal/tls" + "logwisp/src/internal/session" + ltls "logwisp/src/internal/tls" "github.com/lixenwraith/log" "github.com/valyala/fasthttp" ) -// Receives log entries via HTTP POST requests +// HTTPSource receives log entries via HTTP POST requests. type HTTPSource struct { config *config.HTTPSourceOptions @@ -35,10 +36,10 @@ type HTTPSource struct { wg sync.WaitGroup // Security - authenticator *auth.Authenticator - authFailures atomic.Uint64 - authSuccesses atomic.Uint64 - tlsManager *tls.Manager + httpSessions sync.Map + sessionManager *session.Manager + tlsManager *ltls.ServerManager + tlsStates sync.Map // remoteAddr -> *tls.ConnectionState // Statistics totalEntries atomic.Uint64 @@ -48,7 +49,7 @@ type HTTPSource struct { lastEntryTime atomic.Value // time.Time } -// Creates a new HTTP server source +// NewHTTPSource creates a new HTTP server source. func NewHTTPSource(opts *config.HTTPSourceOptions, logger *log.Logger) (*HTTPSource, error) { // Validation done in config package if opts == nil { @@ -56,10 +57,11 @@ func NewHTTPSource(opts *config.HTTPSourceOptions, logger *log.Logger) (*HTTPSou } h := &HTTPSource{ - config: opts, - done: make(chan struct{}), - startTime: time.Now(), - logger: logger, + config: opts, + done: make(chan struct{}), + startTime: time.Now(), + logger: logger, + sessionManager: session.NewManager(core.MaxSessionTime), } h.lastEntryTime.Store(time.Time{}) @@ -72,34 +74,17 @@ func NewHTTPSource(opts *config.HTTPSourceOptions, logger *log.Logger) (*HTTPSou // Initialize TLS manager if configured if opts.TLS != nil && opts.TLS.Enabled { - tlsManager, err := tls.NewManager(opts.TLS, logger) + tlsManager, err := ltls.NewServerManager(opts.TLS, logger) if err != nil { return nil, fmt.Errorf("failed to create TLS manager: %w", err) } h.tlsManager = tlsManager } - // Initialize authenticator if configured - if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" { - // Verify TLS is enabled for auth (validation should have caught this) - if h.tlsManager == nil { - return nil, fmt.Errorf("authentication requires TLS to be enabled") - } - - authenticator, err := auth.NewAuthenticator(opts.Auth, logger) - if err != nil { - return nil, fmt.Errorf("failed to create authenticator: %w", err) - } - h.authenticator = authenticator - - logger.Info("msg", "Authentication configured for HTTP source", - "component", "http_source", - "auth_type", opts.Auth.Type) - } - return h, nil } +// Subscribe returns a channel for receiving log entries. func (h *HTTPSource) Subscribe() <-chan core.LogEntry { h.mu.Lock() defer h.mu.Unlock() @@ -109,7 +94,13 @@ func (h *HTTPSource) Subscribe() <-chan core.LogEntry { return ch } +// Start initializes and starts the HTTP server. func (h *HTTPSource) Start() error { + // Register expiry callback + h.sessionManager.RegisterExpiryCallback("http_source", func(sessionID, remoteAddr string) { + h.handleSessionExpiry(sessionID, remoteAddr) + }) + h.server = &fasthttp.Server{ Handler: h.requestHandler, DisableKeepalive: false, @@ -120,6 +111,20 @@ func (h *HTTPSource) Start() error { MaxRequestBodySize: int(h.config.MaxRequestBodySize), } + // TLS and mTLS configuration + if h.tlsManager != nil { + h.server.TLSConfig = h.tlsManager.GetHTTPConfig() + + // Enforce mTLS configuration from the TLSServerConfig struct. + if h.config.TLS.ClientAuth { + if h.config.TLS.VerifyClientCert { + h.server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + } else { + h.server.TLSConfig.ClientAuth = tls.RequireAnyClientCert + } + } + } + // Use configured host and port addr := fmt.Sprintf("%s:%d", h.config.Host, h.config.Port) @@ -133,12 +138,22 @@ func (h *HTTPSource) Start() error { "port", h.config.Port, "ingest_path", h.config.IngestPath, "tls_enabled", h.tlsManager != nil, - "auth_enabled", h.authenticator != nil) + "mtls_enabled", h.config.TLS != nil && h.config.TLS.ClientAuth, + ) var err error if h.tlsManager != nil { - // HTTPS server h.server.TLSConfig = h.tlsManager.GetHTTPConfig() + + // Add certificate verification callback + if h.config.TLS.ClientAuth { + h.server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + if h.config.TLS.ClientCAFile != "" { + // ClientCAs already set by tls.Manager + } + } + + // HTTPS server err = h.server.ListenAndServeTLS(addr, h.config.TLS.CertFile, h.config.TLS.KeyFile) } else { // HTTP server @@ -163,8 +178,13 @@ func (h *HTTPSource) Start() error { } } +// Stop gracefully shuts down the HTTP server. func (h *HTTPSource) Stop() { h.logger.Info("msg", "Stopping HTTP source") + + // Unregister callback + h.sessionManager.UnregisterExpiryCallback("http_source") + close(h.done) if h.server != nil { @@ -189,9 +209,15 @@ func (h *HTTPSource) Stop() { } h.mu.Unlock() + // Stop session manager + if h.sessionManager != nil { + h.sessionManager.Stop() + } + h.logger.Info("msg", "HTTP source stopped") } +// GetStats returns the source's statistics. func (h *HTTPSource) GetStats() SourceStats { lastEntry, _ := h.lastEntryTime.Load().(time.Time) @@ -200,14 +226,9 @@ func (h *HTTPSource) GetStats() SourceStats { netLimitStats = h.netLimiter.GetStats() } - var authStats map[string]any - if h.authenticator != nil { - authStats = map[string]any{ - "enabled": true, - "type": h.config.Auth.Type, - "failures": h.authFailures.Load(), - "successes": h.authSuccesses.Load(), - } + var sessionStats map[string]any + if h.sessionManager != nil { + sessionStats = h.sessionManager.GetStats() } var tlsStats map[string]any @@ -227,12 +248,13 @@ func (h *HTTPSource) GetStats() SourceStats { "path": h.config.IngestPath, "invalid_entries": h.invalidEntries.Load(), "net_limit": netLimitStats, - "auth": authStats, + "sessions": sessionStats, "tls": tlsStats, }, } } +// requestHandler is the main entry point for all incoming HTTP requests. func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) { remoteAddr := ctx.RemoteAddr().String() @@ -262,42 +284,26 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) { } } - // 3. Check TLS requirement for auth - if h.authenticator != nil { - isTLS := ctx.IsTLS() || h.tlsManager != nil - if !isTLS { - ctx.SetStatusCode(fasthttp.StatusForbidden) - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(map[string]string{ - "error": "TLS required for authentication", - "hint": "Use HTTPS to submit authenticated requests", - }) - return + // 3. Create session for connections + var sess *session.Session + if savedID, exists := h.httpSessions.Load(remoteAddr); exists { + if s, found := h.sessionManager.GetSession(savedID.(string)); found { + sess = s + h.sessionManager.UpdateActivity(savedID.(string)) } + } - // Authenticate request - authHeader := string(ctx.Request.Header.Peek("Authorization")) - session, err := h.authenticator.AuthenticateHTTP(authHeader, remoteAddr) - if err != nil { - h.authFailures.Add(1) - h.logger.Warn("msg", "Authentication failed", - "component", "http_source", - "remote_addr", remoteAddr, - "error", err) + if sess == nil { + // New connection + sess = h.sessionManager.CreateSession(remoteAddr, "http_source", map[string]any{ + "tls": ctx.IsTLS() || h.tlsManager != nil, + "mtls_enabled": h.config.TLS != nil && h.config.TLS.ClientAuth, + }) + h.httpSessions.Store(remoteAddr, sess.ID) - ctx.SetStatusCode(fasthttp.StatusUnauthorized) - if h.config.Auth.Type == "basic" && h.config.Auth.Basic != nil && h.config.Auth.Basic.Realm != "" { - ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, h.config.Auth.Basic.Realm)) - } - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(map[string]string{ - "error": "Authentication failed", - }) - return - } - - h.authSuccesses.Add(1) - _ = session // Session can be used for audit logging + // Setup connection close handler + ctx.SetConnectionClose() + go h.cleanupHTTPSession(remoteAddr, sess.ID) } // 4. Path check @@ -359,14 +365,58 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) { // Publish to subscribers h.publish(entry) + // Update session activity after successful processing + h.sessionManager.UpdateActivity(sess.ID) + // Success response ctx.SetStatusCode(fasthttp.StatusAccepted) ctx.SetContentType("application/json") json.NewEncoder(ctx).Encode(map[string]string{ - "status": "accepted", + "status": "accepted", + "session_id": sess.ID, }) } +// publish sends a log entry to all subscribers. +func (h *HTTPSource) publish(entry core.LogEntry) { + h.mu.RLock() + defer h.mu.RUnlock() + + h.totalEntries.Add(1) + h.lastEntryTime.Store(entry.Time) + + for _, ch := range h.subscribers { + select { + case ch <- entry: + default: + h.droppedEntries.Add(1) + h.logger.Debug("msg", "Dropped log entry - subscriber buffer full", + "component", "http_source") + } + } +} + +// handleSessionExpiry is the callback for cleaning up expired sessions. +func (h *HTTPSource) handleSessionExpiry(sessionID, remoteAddr string) { + h.logger.Info("msg", "Removing expired HTTP session", + "component", "http_source", + "session_id", sessionID, + "remote_addr", remoteAddr) + + // Remove from mapping + h.httpSessions.Delete(remoteAddr) +} + +// cleanupHTTPSession removes a session when a client connection is closed. +func (h *HTTPSource) cleanupHTTPSession(addr, sessionID string) { + // Wait for connection to actually close + time.Sleep(100 * time.Millisecond) + + h.httpSessions.CompareAndDelete(addr, sessionID) + h.sessionManager.RemoveSession(sessionID) +} + +// parseEntries attempts to parse a request body as a single JSON object, a JSON array, or newline-delimited JSON. func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) { var entries []core.LogEntry @@ -442,25 +492,7 @@ func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) { return entries, nil } -func (h *HTTPSource) publish(entry core.LogEntry) { - h.mu.RLock() - defer h.mu.RUnlock() - - h.totalEntries.Add(1) - h.lastEntryTime.Store(entry.Time) - - for _, ch := range h.subscribers { - select { - case ch <- entry: - default: - h.droppedEntries.Add(1) - h.logger.Debug("msg", "Dropped log entry - subscriber buffer full", - "component", "http_source") - } - } -} - -// Splits bytes into lines, handling both \n and \r\n +// splitLines splits a byte slice into lines, handling both \n and \r\n. func splitLines(data []byte) [][]byte { var lines [][]byte start := 0 diff --git a/src/internal/source/source.go b/src/internal/source/source.go index 4fe5d64..ddb8fdf 100644 --- a/src/internal/source/source.go +++ b/src/internal/source/source.go @@ -7,22 +7,22 @@ import ( "logwisp/src/internal/core" ) -// Represents an input data stream +// Source represents an input data stream for log entries. type Source interface { - // Returns a channel that receives log entries + // Subscribe returns a channel that receives log entries from the source. Subscribe() <-chan core.LogEntry - // Begins reading from the source + // Start begins reading from the source. Start() error - // Gracefully shuts down the source + // Stop gracefully shuts down the source. Stop() - // Returns source statistics + // SourceStats contains statistics about a source. GetStats() SourceStats } -// Contains statistics about a source +// SourceStats contains statistics about a source. type SourceStats struct { Type string TotalEntries uint64 diff --git a/src/internal/source/stdin.go b/src/internal/source/stdin.go index 826b5cc..d4e9984 100644 --- a/src/internal/source/stdin.go +++ b/src/internal/source/stdin.go @@ -13,7 +13,7 @@ import ( "github.com/lixenwraith/log" ) -// Reads log entries from standard input +// StdinSource reads log entries from the standard input stream. type StdinSource struct { config *config.StdinSourceOptions subscribers []chan core.LogEntry @@ -25,6 +25,7 @@ type StdinSource struct { logger *log.Logger } +// NewStdinSource creates a new stdin source. func NewStdinSource(opts *config.StdinSourceOptions, logger *log.Logger) (*StdinSource, error) { if opts == nil { opts = &config.StdinSourceOptions{ @@ -43,18 +44,21 @@ func NewStdinSource(opts *config.StdinSourceOptions, logger *log.Logger) (*Stdin return source, nil } +// Subscribe returns a channel for receiving log entries. func (s *StdinSource) Subscribe() <-chan core.LogEntry { ch := make(chan core.LogEntry, s.config.BufferSize) s.subscribers = append(s.subscribers, ch) return ch } +// Start begins reading from the standard input. func (s *StdinSource) Start() error { go s.readLoop() s.logger.Info("msg", "Stdin source started", "component", "stdin_source") return nil } +// Stop signals the source to stop reading. func (s *StdinSource) Stop() { close(s.done) for _, ch := range s.subscribers { @@ -63,6 +67,7 @@ func (s *StdinSource) Stop() { s.logger.Info("msg", "Stdin source stopped", "component", "stdin_source") } +// GetStats returns the source's statistics. func (s *StdinSource) GetStats() SourceStats { lastEntry, _ := s.lastEntryTime.Load().(time.Time) @@ -76,6 +81,7 @@ func (s *StdinSource) GetStats() SourceStats { } } +// readLoop continuously reads lines from stdin and publishes them. func (s *StdinSource) readLoop() { scanner := bufio.NewScanner(os.Stdin) for scanner.Scan() { @@ -107,6 +113,7 @@ func (s *StdinSource) readLoop() { } } +// publish sends a log entry to all subscribers. func (s *StdinSource) publish(entry core.LogEntry) { s.totalEntries.Add(1) s.lastEntryTime.Store(entry.Time) diff --git a/src/internal/source/tcp.go b/src/internal/source/tcp.go index bfcd633..c603396 100644 --- a/src/internal/source/tcp.go +++ b/src/internal/source/tcp.go @@ -7,15 +7,14 @@ import ( "encoding/json" "fmt" "net" - "strings" "sync" "sync/atomic" "time" - "logwisp/src/internal/auth" "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/limit" + "logwisp/src/internal/session" "github.com/lixenwraith/log" "github.com/lixenwraith/log/compat" @@ -27,21 +26,19 @@ const ( maxLineLength = 1 * 1024 * 1024 // 1MB max per log line ) -// Receives log entries via TCP connections +// TCPSource receives log entries via TCP connections. type TCPSource struct { - config *config.TCPSourceOptions - server *tcpSourceServer - subscribers []chan core.LogEntry - mu sync.RWMutex - done chan struct{} - engine *gnet.Engine - engineMu sync.Mutex - wg sync.WaitGroup - authenticator *auth.Authenticator - netLimiter *limit.NetLimiter - logger *log.Logger - scramManager *auth.ScramManager - scramProtocolHandler *auth.ScramProtocolHandler + config *config.TCPSourceOptions + server *tcpSourceServer + subscribers []chan core.LogEntry + mu sync.RWMutex + done chan struct{} + engine *gnet.Engine + engineMu sync.Mutex + wg sync.WaitGroup + sessionManager *session.Manager + netLimiter *limit.NetLimiter + logger *log.Logger // Statistics totalEntries atomic.Uint64 @@ -50,11 +47,9 @@ type TCPSource struct { activeConns atomic.Int64 startTime time.Time lastEntryTime atomic.Value // time.Time - authFailures atomic.Uint64 - authSuccesses atomic.Uint64 } -// Creates a new TCP server source +// NewTCPSource creates a new TCP server source. func NewTCPSource(opts *config.TCPSourceOptions, logger *log.Logger) (*TCPSource, error) { // Accept typed config - validation done in config package if opts == nil { @@ -62,10 +57,11 @@ func NewTCPSource(opts *config.TCPSourceOptions, logger *log.Logger) (*TCPSource } t := &TCPSource{ - config: opts, - done: make(chan struct{}), - startTime: time.Now(), - logger: logger, + config: opts, + done: make(chan struct{}), + startTime: time.Now(), + logger: logger, + sessionManager: session.NewManager(core.MaxSessionTime), } t.lastEntryTime.Store(time.Time{}) @@ -76,20 +72,10 @@ func NewTCPSource(opts *config.TCPSourceOptions, logger *log.Logger) (*TCPSource t.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger) } - // Initialize SCRAM - if opts.Auth != nil && opts.Auth.Type == "scram" && opts.Auth.Scram != nil { - t.scramManager = auth.NewScramManager(opts.Auth.Scram) - t.scramProtocolHandler = auth.NewScramProtocolHandler(t.scramManager, logger) - logger.Info("msg", "SCRAM authentication configured for TCP source", - "component", "tcp_source", - "users", len(opts.Auth.Scram.Users)) - } else if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" { - return nil, fmt.Errorf("TCP source only supports 'none' or 'scram' auth") - } - return t, nil } +// Subscribe returns a channel for receiving log entries. func (t *TCPSource) Subscribe() <-chan core.LogEntry { t.mu.Lock() defer t.mu.Unlock() @@ -99,12 +85,18 @@ func (t *TCPSource) Subscribe() <-chan core.LogEntry { return ch } +// Start initializes and starts the TCP server. func (t *TCPSource) Start() error { t.server = &tcpSourceServer{ source: t, clients: make(map[gnet.Conn]*tcpClient), } + // Register expiry callback + t.sessionManager.RegisterExpiryCallback("tcp_source", func(sessionID, remoteAddr string) { + t.handleSessionExpiry(sessionID, remoteAddr) + }) + // Use configured host and port addr := fmt.Sprintf("tcp://%s:%d", t.config.Host, t.config.Port) @@ -119,7 +111,7 @@ func (t *TCPSource) Start() error { t.logger.Info("msg", "TCP source server starting", "component", "tcp_source", "port", t.config.Port, - "auth_enabled", t.authenticator != nil) + ) err := gnet.Run(t.server, addr, gnet.WithLogger(gnetLogger), @@ -150,8 +142,13 @@ func (t *TCPSource) Start() error { } } +// Stop gracefully shuts down the TCP server. func (t *TCPSource) Stop() { t.logger.Info("msg", "Stopping TCP source") + + // Unregister callback + t.sessionManager.UnregisterExpiryCallback("tcp_source") + close(t.done) // Stop gnet engine if running @@ -182,6 +179,7 @@ func (t *TCPSource) Stop() { t.logger.Info("msg", "TCP source stopped") } +// GetStats returns the source's statistics. func (t *TCPSource) GetStats() SourceStats { lastEntry, _ := t.lastEntryTime.Load().(time.Time) @@ -190,14 +188,9 @@ func (t *TCPSource) GetStats() SourceStats { netLimitStats = t.netLimiter.GetStats() } - var authStats map[string]any - if t.authenticator != nil { - authStats = map[string]any{ - "enabled": true, - "type": t.config.Auth.Type, - "failures": t.authFailures.Load(), - "successes": t.authSuccesses.Load(), - } + var sessionStats map[string]any + if t.sessionManager != nil { + sessionStats = t.sessionManager.GetStats() } return SourceStats{ @@ -211,40 +204,12 @@ func (t *TCPSource) GetStats() SourceStats { "active_connections": t.activeConns.Load(), "invalid_entries": t.invalidEntries.Load(), "net_limit": netLimitStats, - "auth": authStats, + "sessions": sessionStats, }, } } -func (t *TCPSource) publish(entry core.LogEntry) { - t.mu.RLock() - defer t.mu.RUnlock() - - t.totalEntries.Add(1) - t.lastEntryTime.Store(entry.Time) - - for _, ch := range t.subscribers { - select { - case ch <- entry: - default: - t.droppedEntries.Add(1) - t.logger.Debug("msg", "Dropped log entry - subscriber buffer full", - "component", "tcp_source") - } - } -} - -// Represents a connected TCP client -type tcpClient struct { - conn gnet.Conn - buffer *bytes.Buffer - authenticated bool - authTimeout time.Time - session *auth.Session - maxBufferSeen int -} - -// Handles gnet events +// tcpSourceServer implements the gnet.EventHandler interface for the source. type tcpSourceServer struct { gnet.BuiltinEventEngine source *TCPSource @@ -252,6 +217,15 @@ type tcpSourceServer struct { mu sync.RWMutex } +// tcpClient represents a connected TCP client and its state. +type tcpClient struct { + conn gnet.Conn + buffer *bytes.Buffer + sessionID string + maxBufferSeen int +} + +// OnBoot is called when the server starts. func (s *tcpSourceServer) OnBoot(eng gnet.Engine) gnet.Action { // Store engine reference for shutdown s.source.engineMu.Lock() @@ -264,6 +238,7 @@ func (s *tcpSourceServer) OnBoot(eng gnet.Engine) gnet.Action { return gnet.None } +// OnOpen is called when a new connection is established. func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { remoteAddr := c.RemoteAddr().String() s.source.logger.Debug("msg", "TCP connection attempt", @@ -299,7 +274,6 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { } // Track connection - // s.source.netLimiter.AddConnection(remoteAddr) if !s.source.netLimiter.TrackConnection(ip.String(), "", "") { s.source.logger.Warn("msg", "TCP connection limit exceeded", "component", "tcp_source", @@ -308,21 +282,14 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { } } + // Create session + sess := s.source.sessionManager.CreateSession(remoteAddr, "tcp_source", nil) + // Create client state client := &tcpClient{ - conn: c, - buffer: bytes.NewBuffer(nil), - authenticated: s.source.authenticator == nil, // No auth = auto authenticated - } - - if s.source.authenticator != nil { - // Set auth timeout - client.authTimeout = time.Now().Add(10 * time.Second) - - // Send auth challenge for SCRAM - if s.source.config.Auth.Type == "scram" { - out = []byte("AUTH_REQUIRED\n") - } + conn: c, + buffer: bytes.NewBuffer(nil), + sessionID: sess.ID, } s.mu.Lock() @@ -333,19 +300,29 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { s.source.logger.Debug("msg", "TCP connection opened", "component", "tcp_source", "remote_addr", remoteAddr, - "auth_enabled", s.source.authenticator != nil) + "session_id", sess.ID) return out, gnet.None } +// OnClose is called when a connection is closed. func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action { remoteAddr := c.RemoteAddr().String() + // Get client to retrieve session ID + s.mu.RLock() + client, exists := s.clients[c] + s.mu.RUnlock() + + if exists && client.sessionID != "" { + // Remove session + s.source.sessionManager.RemoveSession(client.sessionID) + } + // Untrack connection if s.source.netLimiter != nil { if tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr); err == nil { s.source.netLimiter.ReleaseConnection(tcpAddr.IP.String(), "", "") - // s.source.netLimiter.RemoveConnection(remoteAddr) } } @@ -363,6 +340,7 @@ func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action { return gnet.None } +// OnTraffic is called when data is received from a connection. func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action { s.mu.RLock() client, exists := s.clients[c] @@ -372,6 +350,11 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action { return gnet.Close } + // Update session activity when client sends data + if client.sessionID != "" { + s.source.sessionManager.UpdateActivity(client.sessionID) + } + // Read all available data data, err := c.Next(-1) if err != nil { @@ -381,76 +364,10 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action { return gnet.Close } - // SCRAM Authentication phase - if !client.authenticated && s.source.scramManager != nil { - // Check auth timeout - if !client.authTimeout.IsZero() && time.Now().After(client.authTimeout) { - s.source.logger.Warn("msg", "Authentication timeout", - "component", "tcp_source", - "remote_addr", c.RemoteAddr().String()) - s.source.authFailures.Add(1) - c.AsyncWrite([]byte("AUTH_TIMEOUT\n"), nil) - return gnet.Close - } - - if len(data) == 0 { - return gnet.None - } - - client.buffer.Write(data) - - // Use centralized SCRAM protocol handler - if s.source.scramProtocolHandler == nil { - s.source.scramProtocolHandler = auth.NewScramProtocolHandler(s.source.scramManager, s.source.logger) - } - - // Look for complete auth line - for { - idx := bytes.IndexByte(client.buffer.Bytes(), '\n') - if idx < 0 { - break - } - - line := client.buffer.Bytes()[:idx] - client.buffer.Next(idx + 1) - - // Process auth message through handler - authenticated, session, err := s.source.scramProtocolHandler.HandleAuthMessage(line, c) - if err != nil { - s.source.logger.Warn("msg", "SCRAM authentication failed", - "component", "tcp_source", - "remote_addr", c.RemoteAddr().String(), - "error", err) - - if strings.Contains(err.Error(), "unknown command") { - return gnet.Close - } - // Continue for other errors (might be multi-step auth) - } - - if authenticated && session != nil { - // Authentication successful - s.mu.Lock() - client.authenticated = true - client.session = session - s.mu.Unlock() - - s.source.logger.Info("msg", "Client authenticated via SCRAM", - "component", "tcp_source", - "remote_addr", c.RemoteAddr().String(), - "session_id", session.ID) - - // Clear auth buffer - client.buffer.Reset() - break - } - } - return gnet.None - } - return s.processLogData(c, client, data) } +// processLogData processes raw data from a client, parsing and publishing log entries. func (s *tcpSourceServer) processLogData(c gnet.Conn, client *tcpClient, data []byte) gnet.Action { // Check if appending the new data would exceed the client buffer limit. if client.buffer.Len()+len(data) > maxClientBufferSize { @@ -542,4 +459,43 @@ func (s *tcpSourceServer) processLogData(c gnet.Conn, client *tcpClient, data [] } return gnet.None +} + +// publish sends a log entry to all subscribers. +func (t *TCPSource) publish(entry core.LogEntry) { + t.mu.RLock() + defer t.mu.RUnlock() + + t.totalEntries.Add(1) + t.lastEntryTime.Store(entry.Time) + + for _, ch := range t.subscribers { + select { + case ch <- entry: + default: + t.droppedEntries.Add(1) + t.logger.Debug("msg", "Dropped log entry - subscriber buffer full", + "component", "tcp_source") + } + } +} + +// handleSessionExpiry is the callback for cleaning up expired sessions. +func (t *TCPSource) handleSessionExpiry(sessionID, remoteAddr string) { + t.server.mu.RLock() + defer t.server.mu.RUnlock() + + // Find connection by session ID + for conn, client := range t.server.clients { + if client.sessionID == sessionID { + t.logger.Info("msg", "Closing expired session connection", + "component", "tcp_source", + "session_id", sessionID, + "remote_addr", remoteAddr) + + // Close connection + conn.Close() + return + } + } } \ No newline at end of file diff --git a/src/internal/tls/client.go b/src/internal/tls/client.go new file mode 100644 index 0000000..d661862 --- /dev/null +++ b/src/internal/tls/client.go @@ -0,0 +1,94 @@ +// FILE: src/internal/tls/client.go +package tls + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + + "logwisp/src/internal/config" + + "github.com/lixenwraith/log" +) + +// ClientManager handles TLS configuration for client components. +type ClientManager struct { + config *config.TLSClientConfig + tlsConfig *tls.Config + logger *log.Logger +} + +// NewClientManager creates a TLS manager for clients (HTTP Client Sink). +func NewClientManager(cfg *config.TLSClientConfig, logger *log.Logger) (*ClientManager, error) { + if cfg == nil || !cfg.Enabled { + return nil, nil + } + + m := &ClientManager{ + config: cfg, + logger: logger, + tlsConfig: &tls.Config{ + MinVersion: parseTLSVersion(cfg.MinVersion, tls.VersionTLS12), + MaxVersion: parseTLSVersion(cfg.MaxVersion, tls.VersionTLS13), + }, + } + + // Cipher suite configuration + if cfg.CipherSuites != "" { + m.tlsConfig.CipherSuites = parseCipherSuites(cfg.CipherSuites) + } + + // Load client certificate for mTLS, if provided. + if cfg.ClientCertFile != "" && cfg.ClientKeyFile != "" { + clientCert, err := tls.LoadX509KeyPair(cfg.ClientCertFile, cfg.ClientKeyFile) + if err != nil { + return nil, fmt.Errorf("failed to load client cert/key: %w", err) + } + m.tlsConfig.Certificates = []tls.Certificate{clientCert} + } else if cfg.ClientCertFile != "" || cfg.ClientKeyFile != "" { + return nil, fmt.Errorf("both client_cert_file and client_key_file must be provided for mTLS") + } + + // Load server CA for verification. + if cfg.ServerCAFile != "" { + caCert, err := os.ReadFile(cfg.ServerCAFile) + if err != nil { + return nil, fmt.Errorf("failed to read server CA file: %w", err) + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse server CA certificate") + } + m.tlsConfig.RootCAs = caCertPool + } + + m.tlsConfig.InsecureSkipVerify = cfg.InsecureSkipVerify + m.tlsConfig.ServerName = cfg.ServerName + + logger.Info("msg", "TLS Client Manager initialized", "component", "tls") + return m, nil +} + +// GetConfig returns the client's TLS configuration. +func (m *ClientManager) GetConfig() *tls.Config { + if m == nil { + return nil + } + return m.tlsConfig.Clone() +} + +// GetStats returns statistics about the current client TLS configuration. +func (m *ClientManager) GetStats() map[string]any { + if m == nil { + return map[string]any{"enabled": false} + } + return map[string]any{ + "enabled": true, + "min_version": tlsVersionString(m.tlsConfig.MinVersion), + "max_version": tlsVersionString(m.tlsConfig.MaxVersion), + "has_client_cert": m.config.ClientCertFile != "", + "has_server_ca": m.config.ServerCAFile != "", + "insecure_skip_verify": m.config.InsecureSkipVerify, + } +} \ No newline at end of file diff --git a/src/internal/tls/manager.go b/src/internal/tls/manager.go deleted file mode 100644 index 28cf979..0000000 --- a/src/internal/tls/manager.go +++ /dev/null @@ -1,236 +0,0 @@ -// FILE: logwisp/src/internal/tls/manager.go -package tls - -import ( - "crypto/tls" - "crypto/x509" - "fmt" - "os" - "strings" - - "logwisp/src/internal/config" - - "github.com/lixenwraith/log" -) - -// Handles TLS configuration for servers -type Manager struct { - config *config.TLSConfig - tlsConfig *tls.Config - logger *log.Logger -} - -// Creates a TLS configuration from TLS config -func NewManager(cfg *config.TLSConfig, logger *log.Logger) (*Manager, error) { - if cfg == nil || !cfg.Enabled { - return nil, nil - } - - m := &Manager{ - config: cfg, - logger: logger, - } - - // Load certificate and key - cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) - if err != nil { - return nil, fmt.Errorf("failed to load cert/key: %w", err) - } - - // Create base TLS config - m.tlsConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: parseTLSVersion(cfg.MinVersion, tls.VersionTLS12), - MaxVersion: parseTLSVersion(cfg.MaxVersion, tls.VersionTLS13), - } - - // Configure cipher suites if specified - if cfg.CipherSuites != "" { - m.tlsConfig.CipherSuites = parseCipherSuites(cfg.CipherSuites) - } else { - // Use secure defaults - m.tlsConfig.CipherSuites = []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - } - } - - // Configure client authentication (mTLS) - if cfg.ClientAuth { - if cfg.VerifyClientCert { - m.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - } else { - m.tlsConfig.ClientAuth = tls.RequireAnyClientCert - } - - // Load client CA if specified - if cfg.ClientCAFile != "" { - caCert, err := os.ReadFile(cfg.ClientCAFile) - if err != nil { - return nil, fmt.Errorf("failed to read client CA: %w", err) - } - - caCertPool := x509.NewCertPool() - if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, fmt.Errorf("failed to parse client CA certificate") - } - m.tlsConfig.ClientCAs = caCertPool - } - } - - // Set secure defaults - m.tlsConfig.SessionTicketsDisabled = false - m.tlsConfig.Renegotiation = tls.RenegotiateNever - - logger.Info("msg", "TLS manager initialized", - "component", "tls", - "min_version", cfg.MinVersion, - "max_version", cfg.MaxVersion, - "client_auth", cfg.ClientAuth, - "cipher_count", len(m.tlsConfig.CipherSuites)) - - return m, nil -} - -// Returns the TLS configuration -func (m *Manager) GetConfig() *tls.Config { - if m == nil { - return nil - } - // Return a clone to prevent modification - return m.tlsConfig.Clone() -} - -// Returns TLS config suitable for HTTP servers -func (m *Manager) GetHTTPConfig() *tls.Config { - if m == nil { - return nil - } - - cfg := m.tlsConfig.Clone() - // Enable HTTP/2 - cfg.NextProtos = []string{"h2", "http/1.1"} - return cfg -} - -// Validates a client certificate for mTLS -func (m *Manager) ValidateClientCert(rawCerts [][]byte) error { - if m == nil || !m.config.ClientAuth { - return nil - } - - if len(rawCerts) == 0 { - return fmt.Errorf("no client certificate provided") - } - - cert, err := x509.ParseCertificate(rawCerts[0]) - if err != nil { - return fmt.Errorf("failed to parse client certificate: %w", err) - } - - // Verify against CA if configured - if m.tlsConfig.ClientCAs != nil { - opts := x509.VerifyOptions{ - Roots: m.tlsConfig.ClientCAs, - Intermediates: x509.NewCertPool(), - KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - } - - // Add any intermediate certs - for i := 1; i < len(rawCerts); i++ { - intermediate, err := x509.ParseCertificate(rawCerts[i]) - if err != nil { - continue - } - opts.Intermediates.AddCert(intermediate) - } - - if _, err := cert.Verify(opts); err != nil { - return fmt.Errorf("client certificate verification failed: %w", err) - } - } - - m.logger.Debug("msg", "Client certificate validated", - "component", "tls", - "subject", cert.Subject.String(), - "serial", cert.SerialNumber.String()) - - return nil -} - -// Returns TLS statistics -func (m *Manager) GetStats() map[string]any { - if m == nil { - return map[string]any{"enabled": false} - } - - return map[string]any{ - "enabled": true, - "min_version": tlsVersionString(m.tlsConfig.MinVersion), - "max_version": tlsVersionString(m.tlsConfig.MaxVersion), - "client_auth": m.config.ClientAuth, - "cipher_suites": len(m.tlsConfig.CipherSuites), - } -} - -func parseTLSVersion(version string, defaultVersion uint16) uint16 { - switch strings.ToUpper(version) { - case "TLS1.0", "TLS10": - return tls.VersionTLS10 - case "TLS1.1", "TLS11": - return tls.VersionTLS11 - case "TLS1.2", "TLS12": - return tls.VersionTLS12 - case "TLS1.3", "TLS13": - return tls.VersionTLS13 - default: - return defaultVersion - } -} - -func parseCipherSuites(suites string) []uint16 { - var result []uint16 - - // Map of cipher suite names to IDs - suiteMap := map[string]uint16{ - // TLS 1.2 ECDHE suites (preferred) - "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - - // RSA suites (less preferred) - "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, - "TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, - } - - for _, suite := range strings.Split(suites, ",") { - suite = strings.TrimSpace(suite) - if id, ok := suiteMap[suite]; ok { - result = append(result, id) - } - } - - return result -} - -func tlsVersionString(version uint16) string { - switch version { - case tls.VersionTLS10: - return "TLS1.0" - case tls.VersionTLS11: - return "TLS1.1" - case tls.VersionTLS12: - return "TLS1.2" - case tls.VersionTLS13: - return "TLS1.3" - default: - return fmt.Sprintf("0x%04x", version) - } -} \ No newline at end of file diff --git a/src/internal/tls/parse.go b/src/internal/tls/parse.go new file mode 100644 index 0000000..74b8db0 --- /dev/null +++ b/src/internal/tls/parse.go @@ -0,0 +1,69 @@ +// FILE: logwisp/src/internal/tls/parse.go +package tls + +import ( + "crypto/tls" + "fmt" + "strings" +) + +// parseTLSVersion converts a string representation (e.g., "TLS1.2") into a Go crypto/tls constant. +func parseTLSVersion(version string, defaultVersion uint16) uint16 { + switch strings.ToUpper(version) { + case "TLS1.0", "TLS10": + return tls.VersionTLS10 + case "TLS1.1", "TLS11": + return tls.VersionTLS11 + case "TLS1.2", "TLS12": + return tls.VersionTLS12 + case "TLS1.3", "TLS13": + return tls.VersionTLS13 + default: + return defaultVersion + } +} + +// parseCipherSuites converts a comma-separated string of cipher suite names into a slice of Go constants. +func parseCipherSuites(suites string) []uint16 { + var result []uint16 + + // Map of cipher suite names to IDs + suiteMap := map[string]uint16{ + // TLS 1.2 ECDHE suites (preferred) + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + + // RSA suites + "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + "TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + } + + for _, suite := range strings.Split(suites, ",") { + suite = strings.TrimSpace(suite) + if id, ok := suiteMap[suite]; ok { + result = append(result, id) + } + } + + return result +} + +// tlsVersionString converts a Go crypto/tls version constant back into a string representation. +func tlsVersionString(version uint16) string { + switch version { + case tls.VersionTLS10: + return "TLS1.0" + case tls.VersionTLS11: + return "TLS1.1" + case tls.VersionTLS12: + return "TLS1.2" + case tls.VersionTLS13: + return "TLS1.3" + default: + return fmt.Sprintf("0x%04x", version) + } +} \ No newline at end of file diff --git a/src/internal/tls/server.go b/src/internal/tls/server.go new file mode 100644 index 0000000..0477275 --- /dev/null +++ b/src/internal/tls/server.go @@ -0,0 +1,99 @@ +// FILE: src/internal/tls/server.go +package tls + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + + "logwisp/src/internal/config" + + "github.com/lixenwraith/log" +) + +// ServerManager handles TLS configuration for server components. +type ServerManager struct { + config *config.TLSServerConfig + tlsConfig *tls.Config + logger *log.Logger +} + +// NewServerManager creates a TLS manager for servers (HTTP Source/Sink). +func NewServerManager(cfg *config.TLSServerConfig, logger *log.Logger) (*ServerManager, error) { + if cfg == nil || !cfg.Enabled { + return nil, nil + } + + m := &ServerManager{ + config: cfg, + logger: logger, + } + + cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) + if err != nil { + return nil, fmt.Errorf("failed to load server cert/key: %w", err) + } + + // Enforce TLS 1.2 / TLS 1.3 + m.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: parseTLSVersion(cfg.MinVersion, tls.VersionTLS12), + MaxVersion: parseTLSVersion(cfg.MaxVersion, tls.VersionTLS13), + } + + if cfg.CipherSuites != "" { + m.tlsConfig.CipherSuites = parseCipherSuites(cfg.CipherSuites) + } else { + // Use secure defaults + m.tlsConfig.CipherSuites = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + } + } + + // Configure client authentication (mTLS) + if cfg.ClientAuth { + if cfg.ClientCAFile == "" { + return nil, fmt.Errorf("client_auth is enabled but client_ca_file is not specified") + } + caCert, err := os.ReadFile(cfg.ClientCAFile) + if err != nil { + return nil, fmt.Errorf("failed to read client CA file: %w", err) + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse client CA certificate") + } + m.tlsConfig.ClientCAs = caCertPool + } + + logger.Info("msg", "TLS Server Manager initialized", "component", "tls") + return m, nil +} + +// GetHTTPConfig returns a TLS configuration suitable for HTTP servers. +func (m *ServerManager) GetHTTPConfig() *tls.Config { + if m == nil { + return nil + } + cfg := m.tlsConfig.Clone() + cfg.NextProtos = []string{"h2", "http/1.1"} + return cfg +} + +// GetStats returns statistics about the current server TLS configuration. +func (m *ServerManager) GetStats() map[string]any { + if m == nil { + return map[string]any{"enabled": false} + } + return map[string]any{ + "enabled": true, + "min_version": tlsVersionString(m.tlsConfig.MinVersion), + "max_version": tlsVersionString(m.tlsConfig.MaxVersion), + "client_auth": m.config.ClientAuth, + "cipher_suites": len(m.tlsConfig.CipherSuites), + } +} \ No newline at end of file diff --git a/src/internal/version/version.go b/src/internal/version/version.go index 640352e..536bdeb 100644 --- a/src/internal/version/version.go +++ b/src/internal/version/version.go @@ -4,13 +4,15 @@ package version import "fmt" var ( - // Version is set at compile time via -ldflags - Version = "dev" + // Version is the application version, set at compile time via -ldflags. + Version = "dev" + // GitCommit is the git commit hash, set at compile time. GitCommit = "unknown" + // BuildTime is the application build time, set at compile time. BuildTime = "unknown" ) -// Returns a formatted version string +// String returns a detailed, formatted version string including commit and build time. func String() string { if Version == "dev" { return fmt.Sprintf("dev (commit: %s, built: %s)", GitCommit, BuildTime) @@ -18,7 +20,7 @@ func String() string { return fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime) } -// Returns just the version tag +// Short returns just the version tag. func Short() string { return Version } \ No newline at end of file diff --git a/test/README.md b/test/README.md deleted file mode 100644 index 5bf2a65..0000000 --- a/test/README.md +++ /dev/null @@ -1,12 +0,0 @@ -### Usage: - -- Copy logwisp executable to the test folder (to compile, in logwisp top directory: `make build`). - -- Run the test script for each scenario. - -### Notes: - -- The tests create configuration files and log files. Most tests set logging at debug level and don't clean up their temp files that are created in the current execution directory. - -- Some tests may need to be run on different hosts (containers can be used). -