v0.8.0: decoupled session management from auth; auth deprecated except mTLS; session management, TLS, and mTLS flows fixed; docs and config are outdated
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,3 +9,4 @@ script
|
||||
build
|
||||
*.log
|
||||
*.toml
|
||||
build.sh
|
||||
|
||||
@ -13,7 +13,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Creates and initializes the log transport service
|
||||
// bootstrapService creates and initializes the main log transport service and its pipelines.
|
||||
func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service, error) {
|
||||
// Create service with logger dependency injection
|
||||
svc := service.NewService(ctx, logger)
|
||||
@ -45,7 +45,7 @@ func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// Sets up the logger based on configuration
|
||||
// initializeLogger sets up the global logger based on the application's configuration.
|
||||
func initializeLogger(cfg *config.Config) error {
|
||||
logger = log.NewLogger()
|
||||
logCfg := log.DefaultConfig()
|
||||
@ -103,7 +103,7 @@ func initializeLogger(cfg *config.Config) error {
|
||||
return logger.ApplyConfig(logCfg)
|
||||
}
|
||||
|
||||
// Sets up file-based logging parameters
|
||||
// configureFileLogging sets up file-based logging parameters from the configuration.
|
||||
func configureFileLogging(logCfg *log.Config, cfg *config.Config) {
|
||||
if cfg.Logging.File != nil {
|
||||
logCfg.Directory = cfg.Logging.File.Directory
|
||||
@ -116,6 +116,7 @@ func configureFileLogging(logCfg *log.Config, cfg *config.Config) {
|
||||
}
|
||||
}
|
||||
|
||||
// parseLogLevel converts a string log level to its corresponding integer value.
|
||||
func parseLogLevel(level string) (int64, error) {
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
|
||||
@ -1,355 +0,0 @@
|
||||
// FILE: src/cmd/logwisp/commands/auth.go
|
||||
package commands
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// AuthCommand implements the 'auth' subcommand, which generates
// authentication credentials (Argon2id password hashes, SCRAM key
// material, and random bearer tokens) for LogWisp.
type AuthCommand struct {
	output io.Writer // destination for generated credential snippets (default os.Stdout)
	errOut io.Writer // destination for prompts, warnings, and usage text (default os.Stderr)
}

// NewAuthCommand creates a new auth command handler writing to the
// process's standard output streams.
func NewAuthCommand() *AuthCommand {
	return &AuthCommand{
		output: os.Stdout,
		errOut: os.Stderr,
	}
}
|
||||
|
||||
// Execute parses the auth subcommand's flags and dispatches to the
// appropriate credential generation routine: PHC-to-SCRAM migration,
// bearer token generation, or password hashing (basic/scram).
// Every option has a short and a long flag; the pairs are merged with
// the coalesce* helpers after parsing.
func (ac *AuthCommand) Execute(args []string) error {
	cmd := flag.NewFlagSet("auth", flag.ContinueOnError)
	cmd.SetOutput(ac.errOut)

	var (
		// User credentials
		username     = cmd.String("u", "", "Username")
		usernameLong = cmd.String("user", "", "Username")
		password     = cmd.String("p", "", "Password (will prompt if not provided)")
		passwordLong = cmd.String("password", "", "Password (will prompt if not provided)")

		// Auth type selection (multiple ways to specify)
		authType     = cmd.String("t", "", "Auth type: basic, scram, or token")
		authTypeLong = cmd.String("type", "", "Auth type: basic, scram, or token")
		useScram     = cmd.Bool("s", false, "Generate SCRAM credentials (TCP)")
		useScramLong = cmd.Bool("scram", false, "Generate SCRAM credentials (TCP)")
		useBasic     = cmd.Bool("b", false, "Generate basic auth credentials (HTTP)")
		useBasicLong = cmd.Bool("basic", false, "Generate basic auth credentials (HTTP)")

		// Token generation
		genToken     = cmd.Bool("k", false, "Generate random bearer token")
		genTokenLong = cmd.Bool("token", false, "Generate random bearer token")
		tokenLen     = cmd.Int("l", 32, "Token length in bytes")
		tokenLenLong = cmd.Int("length", 32, "Token length in bytes")

		// Migration option
		migrate     = cmd.Bool("m", false, "Convert basic auth PHC to SCRAM")
		migrateLong = cmd.Bool("migrate", false, "Convert basic auth PHC to SCRAM")
		phcHash     = cmd.String("phc", "", "PHC hash to migrate (required with --migrate)")
	)

	// Custom usage message with worked examples for each credential type
	cmd.Usage = func() {
		fmt.Fprintln(ac.errOut, "Generate authentication credentials for LogWisp")
		fmt.Fprintln(ac.errOut, "\nUsage: logwisp auth [options]")
		fmt.Fprintln(ac.errOut, "\nExamples:")
		fmt.Fprintln(ac.errOut, "  # Generate basic auth hash for HTTP sources/sinks")
		fmt.Fprintln(ac.errOut, "  logwisp auth -u admin -b")
		fmt.Fprintln(ac.errOut, "  logwisp auth --user=admin --basic")
		fmt.Fprintln(ac.errOut, "  ")
		fmt.Fprintln(ac.errOut, "  # Generate SCRAM credentials for TCP")
		fmt.Fprintln(ac.errOut, "  logwisp auth -u tcpuser -s")
		fmt.Fprintln(ac.errOut, "  logwisp auth --user=tcpuser --scram")
		fmt.Fprintln(ac.errOut, "  ")
		fmt.Fprintln(ac.errOut, "  # Generate bearer token")
		fmt.Fprintln(ac.errOut, "  logwisp auth -k -l 64")
		fmt.Fprintln(ac.errOut, "  logwisp auth --token --length=64")
		fmt.Fprintln(ac.errOut, "\nOptions:")
		cmd.PrintDefaults()
	}

	if err := cmd.Parse(args); err != nil {
		return err
	}

	// Check for unparsed positional arguments; none are accepted
	if cmd.NArg() > 0 {
		return fmt.Errorf("unexpected argument(s): %s", strings.Join(cmd.Args(), " "))
	}

	// Merge short and long form values
	finalUsername := coalesceString(*username, *usernameLong)
	finalPassword := coalesceString(*password, *passwordLong)
	finalAuthType := coalesceString(*authType, *authTypeLong)
	finalGenToken := coalesceBool(*genToken, *genTokenLong)
	// NOTE(review): the flag default is the literal 32 while coalesceInt
	// compares against core.DefaultTokenLength — confirm the two agree,
	// otherwise an explicit "-l 32" could be treated as "not set".
	finalTokenLen := coalesceInt(*tokenLen, *tokenLenLong, core.DefaultTokenLength)
	finalUseScram := coalesceBool(*useScram, *useScramLong)
	finalUseBasic := coalesceBool(*useBasic, *useBasicLong)
	finalMigrate := coalesceBool(*migrate, *migrateLong)

	// Handle migration mode first: it requires all three inputs
	if finalMigrate {
		if *phcHash == "" || finalUsername == "" || finalPassword == "" {
			return fmt.Errorf("--migrate requires --user, --password, and --phc flags")
		}
		return ac.migrateToScram(finalUsername, finalPassword, *phcHash)
	}

	// Token generation requires no username or password
	if finalGenToken || finalAuthType == "token" {
		return ac.generateToken(finalTokenLen)
	}

	// Determine credential type
	credType := "basic" // default

	// Check explicit type flags; the boolean flags win over --type
	if finalUseScram || finalAuthType == "scram" {
		credType = "scram"
	} else if finalUseBasic || finalAuthType == "basic" {
		credType = "basic"
	} else if finalAuthType != "" {
		return fmt.Errorf("invalid auth type: %s (valid: basic, scram, token)", finalAuthType)
	}

	// Username required for password-based auth
	if finalUsername == "" {
		cmd.Usage()
		return fmt.Errorf("username required for %s auth generation", credType)
	}

	return ac.generatePasswordHash(finalUsername, finalPassword, credType)
}
|
||||
|
||||
// Description returns a brief one-line description of the auth command.
func (ac *AuthCommand) Description() string {
	return "Generate authentication credentials (passwords, tokens, SCRAM)"
}
|
||||
|
||||
func (ac *AuthCommand) Help() string {
|
||||
return `Auth Command - Generate authentication credentials for LogWisp
|
||||
|
||||
Usage:
|
||||
logwisp auth [options]
|
||||
|
||||
Authentication Types:
|
||||
HTTP/HTTPS Sources & Sinks (TLS required):
|
||||
- Basic Auth: Username/password with Argon2id hashing
|
||||
- Bearer Token: Random cryptographic tokens
|
||||
|
||||
TCP Sources & Sinks (No TLS):
|
||||
- SCRAM: Argon2-SCRAM-SHA256 for plaintext connections
|
||||
|
||||
Options:
|
||||
-u, --user <name> Username for credential generation
|
||||
-p, --password <pass> Password (will prompt if not provided)
|
||||
-t, --type <type> Auth type: "basic", "scram", or "token"
|
||||
-b, --basic Generate basic auth credentials (HTTP/HTTPS)
|
||||
-s, --scram Generate SCRAM credentials (TCP)
|
||||
-k, --token Generate random bearer token
|
||||
-l, --length <bytes> Token length in bytes (default: 32)
|
||||
|
||||
Examples:
|
||||
Examples:
|
||||
# Generate basic auth hash for HTTP/HTTPS (with TLS)
|
||||
logwisp auth -u admin -b
|
||||
logwisp auth --user=admin --basic
|
||||
|
||||
# Generate SCRAM credentials for TCP (without TLS)
|
||||
logwisp auth -u tcpuser -s
|
||||
logwisp auth --user=tcpuser --type=scram
|
||||
|
||||
# Generate 64-byte bearer token
|
||||
logwisp auth -k -l 64
|
||||
logwisp auth --token --length=64
|
||||
|
||||
# Convert existing basic auth to SCRAM (HTTPS to TCP conversion)
|
||||
logwisp auth -u admin -m --phc='$argon2id$v=19$m=65536...' --password='secret'
|
||||
|
||||
Output:
|
||||
The command outputs configuration snippets ready to paste into logwisp.toml
|
||||
and the raw credential values for external auth files.
|
||||
|
||||
Security Notes:
|
||||
- Basic auth and tokens require TLS encryption for HTTP connections
|
||||
- SCRAM provides authentication but NOT encryption for TCP connections
|
||||
- Use strong passwords (12+ characters with mixed case, numbers, symbols)
|
||||
- Store credentials securely and never commit them to version control
|
||||
`
|
||||
}
|
||||
|
||||
func (ac *AuthCommand) generatePasswordHash(username, password, credType string) error {
|
||||
// Get password if not provided
|
||||
if password == "" {
|
||||
var err error
|
||||
password, err = ac.promptForPassword()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
switch credType {
|
||||
case "basic":
|
||||
return ac.generateBasicAuth(username, password)
|
||||
case "scram":
|
||||
return ac.generateScramAuth(username, password)
|
||||
default:
|
||||
return fmt.Errorf("invalid credential type: %s", credType)
|
||||
}
|
||||
}
|
||||
|
||||
// promptForPassword handles password prompting with confirmation
|
||||
func (ac *AuthCommand) promptForPassword() (string, error) {
|
||||
pass1 := ac.promptPassword("Enter password: ")
|
||||
pass2 := ac.promptPassword("Confirm password: ")
|
||||
if pass1 != pass2 {
|
||||
return "", fmt.Errorf("passwords don't match")
|
||||
}
|
||||
return pass1, nil
|
||||
}
|
||||
|
||||
func (ac *AuthCommand) promptPassword(prompt string) string {
|
||||
fmt.Fprint(ac.errOut, prompt)
|
||||
password, err := term.ReadPassword(syscall.Stdin)
|
||||
fmt.Fprintln(ac.errOut)
|
||||
if err != nil {
|
||||
fmt.Fprintf(ac.errOut, "Failed to read password: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return string(password)
|
||||
}
|
||||
|
||||
// generateBasicAuth creates Argon2id hash for HTTP basic auth
|
||||
func (ac *AuthCommand) generateBasicAuth(username, password string) error {
|
||||
// Generate salt
|
||||
salt := make([]byte, core.Argon2SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Generate Argon2id hash
|
||||
cred, err := auth.DeriveCredential(username, password, salt,
|
||||
core.Argon2Time, core.Argon2Memory, core.Argon2Threads)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to derive credential: %w", err)
|
||||
}
|
||||
|
||||
// Output configuration snippets
|
||||
fmt.Fprintln(ac.output, "\n# Basic Auth Configuration (HTTP sources/sinks)")
|
||||
fmt.Fprintln(ac.output, "# REQUIRES HTTPS/TLS for security")
|
||||
fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:")
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[pipelines.auth]")
|
||||
fmt.Fprintln(ac.output, `type = "basic"`)
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[[pipelines.auth.basic_auth.users]]")
|
||||
fmt.Fprintf(ac.output, "username = %q\n", username)
|
||||
fmt.Fprintf(ac.output, "password_hash = %q\n\n", cred.PHCHash)
|
||||
|
||||
fmt.Fprintln(ac.output, "# For external users file:")
|
||||
fmt.Fprintf(ac.output, "%s:%s\n", username, cred.PHCHash)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateScramAuth creates Argon2id-SCRAM-SHA256 credentials for TCP
|
||||
func (ac *AuthCommand) generateScramAuth(username, password string) error {
|
||||
// Generate salt
|
||||
salt := make([]byte, core.Argon2SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Use internal auth package to derive SCRAM credentials
|
||||
cred, err := auth.DeriveCredential(username, password, salt,
|
||||
core.Argon2Time, core.Argon2Memory, core.Argon2Threads)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to derive SCRAM credential: %w", err)
|
||||
}
|
||||
|
||||
// Output SCRAM configuration
|
||||
fmt.Fprintln(ac.output, "\n# SCRAM Auth Configuration (TCP sources/sinks)")
|
||||
fmt.Fprintln(ac.output, "# Provides authentication but NOT encryption")
|
||||
fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:")
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[pipelines.auth]")
|
||||
fmt.Fprintln(ac.output, `type = "scram"`)
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[[pipelines.auth.scram_auth.users]]")
|
||||
fmt.Fprintf(ac.output, "username = %q\n", username)
|
||||
fmt.Fprintf(ac.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey))
|
||||
fmt.Fprintf(ac.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey))
|
||||
fmt.Fprintf(ac.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt))
|
||||
fmt.Fprintf(ac.output, "argon_time = %d\n", cred.ArgonTime)
|
||||
fmt.Fprintf(ac.output, "argon_memory = %d\n", cred.ArgonMemory)
|
||||
fmt.Fprintf(ac.output, "argon_threads = %d\n\n", cred.ArgonThreads)
|
||||
|
||||
fmt.Fprintln(ac.output, "# Note: SCRAM provides authentication only.")
|
||||
fmt.Fprintln(ac.output, "# Use TLS/mTLS for encryption if needed.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ac *AuthCommand) generateToken(length int) error {
|
||||
if length < 16 {
|
||||
fmt.Fprintln(ac.errOut, "Warning: tokens < 16 bytes are cryptographically weak")
|
||||
}
|
||||
if length > 512 {
|
||||
return fmt.Errorf("token length exceeds maximum (512 bytes)")
|
||||
}
|
||||
|
||||
token := make([]byte, length)
|
||||
if _, err := rand.Read(token); err != nil {
|
||||
return fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token)
|
||||
hex := fmt.Sprintf("%x", token)
|
||||
|
||||
fmt.Fprintln(ac.output, "\n# Token Configuration")
|
||||
fmt.Fprintln(ac.output, "# Add to logwisp.toml:")
|
||||
fmt.Fprintf(ac.output, "tokens = [%q]\n\n", b64)
|
||||
|
||||
fmt.Fprintln(ac.output, "# Generated Token:")
|
||||
fmt.Fprintf(ac.output, "Base64: %s\n", b64)
|
||||
fmt.Fprintf(ac.output, "Hex: %s\n", hex)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateToScram converts basic auth PHC hash to SCRAM credentials
|
||||
func (ac *AuthCommand) migrateToScram(username, password, phcHash string) error {
|
||||
// CHANGED: Moved from internal/auth to CLI command layer
|
||||
cred, err := auth.MigrateFromPHC(username, password, phcHash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
|
||||
// Output SCRAM configuration (reuse format from generateScramAuth)
|
||||
fmt.Fprintln(ac.output, "\n# Migrated SCRAM Credentials")
|
||||
fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:")
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[pipelines.auth]")
|
||||
fmt.Fprintln(ac.output, `type = "scram"`)
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[[pipelines.auth.scram_auth.users]]")
|
||||
fmt.Fprintf(ac.output, "username = %q\n", username)
|
||||
fmt.Fprintf(ac.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey))
|
||||
fmt.Fprintf(ac.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey))
|
||||
fmt.Fprintf(ac.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt))
|
||||
fmt.Fprintf(ac.output, "argon_time = %d\n", cred.ArgonTime)
|
||||
fmt.Fprintf(ac.output, "argon_memory = %d\n", cred.ArgonMemory)
|
||||
fmt.Fprintf(ac.output, "argon_threads = %d\n", cred.ArgonThreads)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// generalHelpTemplate is the default help message shown when no specific command is requested.
|
||||
const generalHelpTemplate = `LogWisp: A flexible log transport and processing tool.
|
||||
|
||||
Usage:
|
||||
@ -37,9 +38,6 @@ Configuration Sources (Precedence: CLI > Env > File > Defaults):
|
||||
- TOML configuration file is the primary method
|
||||
|
||||
Examples:
|
||||
# Generate password for admin user
|
||||
logwisp auth -u admin
|
||||
|
||||
# Start service with custom config
|
||||
logwisp -c /etc/logwisp/prod.toml
|
||||
|
||||
@ -49,17 +47,17 @@ Examples:
|
||||
For detailed configuration options, please refer to the documentation.
|
||||
`
|
||||
|
||||
// HelpCommand handles help display
|
||||
// HelpCommand handles the display of general or command-specific help messages.
|
||||
type HelpCommand struct {
|
||||
router *CommandRouter
|
||||
}
|
||||
|
||||
// NewHelpCommand creates a new help command
|
||||
// NewHelpCommand creates a new help command handler.
|
||||
func NewHelpCommand(router *CommandRouter) *HelpCommand {
|
||||
return &HelpCommand{router: router}
|
||||
}
|
||||
|
||||
// Execute displays help information
|
||||
// Execute displays the appropriate help message based on the provided arguments.
|
||||
func (c *HelpCommand) Execute(args []string) error {
|
||||
// Check if help is requested for a specific command
|
||||
if len(args) > 0 && args[0] != "" {
|
||||
@ -78,7 +76,27 @@ func (c *HelpCommand) Execute(args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatCommandList creates a formatted list of available commands
|
||||
// Description returns a brief one-line description of the command.
|
||||
func (c *HelpCommand) Description() string {
|
||||
return "Display help information"
|
||||
}
|
||||
|
||||
// Help returns the detailed help text for the 'help' command itself.
|
||||
func (c *HelpCommand) Help() string {
|
||||
return `Help Command - Display help information
|
||||
|
||||
Usage:
|
||||
logwisp help Show general help
|
||||
logwisp help <command> Show help for a specific command
|
||||
|
||||
Examples:
|
||||
logwisp help # Show general help
|
||||
logwisp help auth # Show auth command help
|
||||
logwisp auth --help # Alternative way to get command help
|
||||
`
|
||||
}
|
||||
|
||||
// formatCommandList creates a formatted and aligned list of all available commands.
|
||||
func (c *HelpCommand) formatCommandList() string {
|
||||
commands := c.router.GetCommands()
|
||||
|
||||
@ -103,21 +121,3 @@ func (c *HelpCommand) formatCommandList() string {
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// Description returns a brief one-line description of the help command.
func (c *HelpCommand) Description() string {
	return "Display help information"
}
|
||||
|
||||
// Help returns the detailed help text for the 'help' command itself.
func (c *HelpCommand) Help() string {
	return `Help Command - Display help information

Usage:
  logwisp help              Show general help
  logwisp help <command>    Show help for a specific command

Examples:
  logwisp help              # Show general help
  logwisp help auth         # Show auth command help
  logwisp auth --help       # Alternative way to get command help
`
}
|
||||
@ -6,26 +6,25 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// Handler defines the interface for subcommands
|
||||
// Handler defines the interface required for all subcommands.
|
||||
type Handler interface {
|
||||
Execute(args []string) error
|
||||
Description() string
|
||||
Help() string
|
||||
}
|
||||
|
||||
// CommandRouter handles subcommand routing before main app initialization
|
||||
// CommandRouter handles the routing of CLI arguments to the appropriate subcommand handler.
|
||||
type CommandRouter struct {
|
||||
commands map[string]Handler
|
||||
}
|
||||
|
||||
// NewCommandRouter creates and initializes the command router
|
||||
// NewCommandRouter creates and initializes the command router with all available commands.
|
||||
func NewCommandRouter() *CommandRouter {
|
||||
router := &CommandRouter{
|
||||
commands: make(map[string]Handler),
|
||||
}
|
||||
|
||||
// Register available commands
|
||||
router.commands["auth"] = NewAuthCommand()
|
||||
router.commands["tls"] = NewTLSCommand()
|
||||
router.commands["version"] = NewVersionCommand()
|
||||
router.commands["help"] = NewHelpCommand(router)
|
||||
@ -33,7 +32,7 @@ func NewCommandRouter() *CommandRouter {
|
||||
return router
|
||||
}
|
||||
|
||||
// Route checks for and executes subcommands
|
||||
// Route checks for and executes a subcommand based on the provided CLI arguments.
|
||||
func (r *CommandRouter) Route(args []string) (bool, error) {
|
||||
if len(args) < 2 {
|
||||
return false, nil // No command specified, let main app continue
|
||||
@ -69,18 +68,18 @@ func (r *CommandRouter) Route(args []string) (bool, error) {
|
||||
return true, handler.Execute(args[2:])
|
||||
}
|
||||
|
||||
// GetCommand returns a command handler by name
|
||||
// GetCommand returns a specific command handler by its name.
|
||||
func (r *CommandRouter) GetCommand(name string) (Handler, bool) {
|
||||
cmd, exists := r.commands[name]
|
||||
return cmd, exists
|
||||
}
|
||||
|
||||
// GetCommands returns all registered commands
|
||||
// GetCommands returns a map of all registered commands.
|
||||
func (r *CommandRouter) GetCommands() map[string]Handler {
|
||||
return r.commands
|
||||
}
|
||||
|
||||
// ShowCommands displays available subcommands
|
||||
// ShowCommands displays a list of available subcommands to stderr.
|
||||
func (r *CommandRouter) ShowCommands() {
|
||||
for name, handler := range r.commands {
|
||||
fmt.Fprintf(os.Stderr, " %-10s %s\n", name, handler.Description())
|
||||
@ -88,7 +87,7 @@ func (r *CommandRouter) ShowCommands() {
|
||||
fmt.Fprintln(os.Stderr, "\nUse 'logwisp <command> --help' for command-specific help")
|
||||
}
|
||||
|
||||
// Helper functions to merge short and long options
|
||||
// coalesceString returns the first non-empty string from a list of arguments.
|
||||
func coalesceString(values ...string) string {
|
||||
for _, v := range values {
|
||||
if v != "" {
|
||||
@ -98,6 +97,7 @@ func coalesceString(values ...string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// coalesceInt returns the first non-default integer from a list of arguments.
|
||||
func coalesceInt(primary, secondary, defaultVal int) int {
|
||||
if primary != defaultVal {
|
||||
return primary
|
||||
@ -108,6 +108,7 @@ func coalesceInt(primary, secondary, defaultVal int) int {
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// coalesceBool returns true if any of the boolean arguments is true.
|
||||
func coalesceBool(values ...bool) bool {
|
||||
for _, v := range values {
|
||||
if v {
|
||||
|
||||
@ -17,11 +17,13 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// TLSCommand handles the generation of TLS certificates.
|
||||
type TLSCommand struct {
|
||||
output io.Writer
|
||||
errOut io.Writer
|
||||
}
|
||||
|
||||
// NewTLSCommand creates a new TLS command handler.
|
||||
func NewTLSCommand() *TLSCommand {
|
||||
return &TLSCommand{
|
||||
output: os.Stdout,
|
||||
@ -29,6 +31,7 @@ func NewTLSCommand() *TLSCommand {
|
||||
}
|
||||
}
|
||||
|
||||
// Execute parses flags and routes to the appropriate certificate generation function.
|
||||
func (tc *TLSCommand) Execute(args []string) error {
|
||||
cmd := flag.NewFlagSet("tls", flag.ContinueOnError)
|
||||
cmd.SetOutput(tc.errOut)
|
||||
@ -133,10 +136,12 @@ func (tc *TLSCommand) Execute(args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Description returns a brief one-line description of the command.
|
||||
func (tc *TLSCommand) Description() string {
|
||||
return "Generate TLS certificates (CA, server, client, self-signed)"
|
||||
}
|
||||
|
||||
// Help returns the detailed help text for the command.
|
||||
func (tc *TLSCommand) Help() string {
|
||||
return `TLS Command - Generate TLS certificates for LogWisp
|
||||
|
||||
@ -195,7 +200,7 @@ Security Notes:
|
||||
`
|
||||
}
|
||||
|
||||
// Create and manage private CA
|
||||
// generateCA creates a new Certificate Authority (CA) certificate and private key.
|
||||
func (tc *TLSCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error {
|
||||
// Generate RSA key
|
||||
priv, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
@ -250,28 +255,7 @@ func (tc *TLSCommand) generateCA(cn, org, country string, days, bits int, certFi
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseHosts(hostList string) ([]string, []net.IP) {
|
||||
var dnsNames []string
|
||||
var ipAddrs []net.IP
|
||||
|
||||
if hostList == "" {
|
||||
return dnsNames, ipAddrs
|
||||
}
|
||||
|
||||
hosts := strings.Split(hostList, ",")
|
||||
for _, h := range hosts {
|
||||
h = strings.TrimSpace(h)
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
ipAddrs = append(ipAddrs, ip)
|
||||
} else {
|
||||
dnsNames = append(dnsNames, h)
|
||||
}
|
||||
}
|
||||
|
||||
return dnsNames, ipAddrs
|
||||
}
|
||||
|
||||
// Generate self-signed certificate
|
||||
// generateSelfSigned creates a new self-signed server certificate and private key.
|
||||
func (tc *TLSCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error {
|
||||
// 1. Generate an RSA private key with the specified bit size
|
||||
priv, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
@ -338,7 +322,7 @@ func (tc *TLSCommand) generateSelfSigned(cn, org, country, hosts string, days, b
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate server cert with CA
|
||||
// generateServerCert creates a new server certificate signed by a provided CA.
|
||||
func (tc *TLSCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error {
|
||||
caCert, caKey, err := loadCA(caFile, caKeyFile)
|
||||
if err != nil {
|
||||
@ -401,7 +385,7 @@ func (tc *TLSCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyF
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate client cert with CA
|
||||
// generateClientCert creates a new client certificate signed by a provided CA for mTLS.
|
||||
func (tc *TLSCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error {
|
||||
caCert, caKey, err := loadCA(caFile, caKeyFile)
|
||||
if err != nil {
|
||||
@ -458,7 +442,7 @@ func (tc *TLSCommand) generateClientCert(cn, org, country, caFile, caKeyFile str
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load cert with CA
|
||||
// loadCA reads and parses a CA certificate and its corresponding private key from files.
|
||||
func loadCA(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) {
|
||||
// Load CA certificate
|
||||
certPEM, err := os.ReadFile(certFile)
|
||||
@ -517,6 +501,7 @@ func loadCA(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error
|
||||
return caCert, caKey, nil
|
||||
}
|
||||
|
||||
// saveCert saves a DER-encoded certificate to a file in PEM format.
|
||||
func saveCert(filename string, certDER []byte) error {
|
||||
certFile, err := os.Create(filename)
|
||||
if err != nil {
|
||||
@ -539,6 +524,7 @@ func saveCert(filename string, certDER []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveKey saves an RSA private key to a file in PEM format with restricted permissions.
|
||||
func saveKey(filename string, key *rsa.PrivateKey) error {
|
||||
keyFile, err := os.Create(filename)
|
||||
if err != nil {
|
||||
@ -561,3 +547,25 @@ func saveKey(filename string, key *rsa.PrivateKey) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseHosts splits a comma-separated host list into DNS names and IP
// addresses. Entries that parse as IP literals (v4 or v6) go into the IP
// slice; everything else is kept as a DNS name. Surrounding whitespace is
// trimmed from each entry. An empty input yields two nil slices.
func parseHosts(hostList string) ([]string, []net.IP) {
	var names []string
	var addrs []net.IP

	if hostList == "" {
		return names, addrs
	}

	for _, entry := range strings.Split(hostList, ",") {
		trimmed := strings.TrimSpace(entry)
		ip := net.ParseIP(trimmed)
		if ip == nil {
			names = append(names, trimmed)
			continue
		}
		addrs = append(addrs, ip)
	}

	return names, addrs
}
|
||||
@ -7,23 +7,26 @@ import (
|
||||
"logwisp/src/internal/version"
|
||||
)
|
||||
|
||||
// VersionCommand handles the display of the application's version information.
type VersionCommand struct{}

// NewVersionCommand creates a new version command handler.
func NewVersionCommand() *VersionCommand {
	return &VersionCommand{}
}
||||
|
||||
// Execute prints the detailed version string to stdout.
// Arguments are ignored and the command always succeeds.
func (c *VersionCommand) Execute(args []string) error {
	fmt.Println(version.String())
	return nil
}
|
||||
|
||||
// Description returns a brief one-line description of the command.
func (c *VersionCommand) Description() string {
	return "Show version information"
}
|
||||
|
||||
// Help returns the detailed help text for the command.
|
||||
func (c *VersionCommand) Help() string {
|
||||
return `Version Command - Show LogWisp version information
|
||||
|
||||
|
||||
@ -18,8 +18,10 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// logger is the global logger instance for the application.
|
||||
var logger *log.Logger
|
||||
|
||||
// main is the entry point for the LogWisp application.
|
||||
func main() {
|
||||
// Handle subcommands before any config loading
|
||||
// This prevents flag conflicts with lixenwraith/config
|
||||
@ -185,6 +187,7 @@ func main() {
|
||||
logger.Info("msg", "Shutdown complete")
|
||||
}
|
||||
|
||||
// shutdownLogger gracefully shuts down the global logger.
|
||||
func shutdownLogger() {
|
||||
if logger != nil {
|
||||
if err := logger.Shutdown(2 * time.Second); err != nil {
|
||||
|
||||
@ -8,7 +8,7 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Manages all application output respecting quiet mode
|
||||
// OutputHandler manages all application output, respecting the global quiet mode.
|
||||
type OutputHandler struct {
|
||||
quiet bool
|
||||
mu sync.RWMutex
|
||||
@ -16,10 +16,10 @@ type OutputHandler struct {
|
||||
stderr io.Writer
|
||||
}
|
||||
|
||||
// Global output handler instance
|
||||
// output is the global instance of the OutputHandler.
|
||||
var output *OutputHandler
|
||||
|
||||
// Initializes the global output handler
|
||||
// InitOutputHandler initializes the global output handler.
|
||||
func InitOutputHandler(quiet bool) {
|
||||
output = &OutputHandler{
|
||||
quiet: quiet,
|
||||
@ -28,59 +28,21 @@ func InitOutputHandler(quiet bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// Print writes a formatted string to stdout unless quiet mode is enabled.
// The read lock guards the quiet flag against concurrent SetQuiet calls.
func (o *OutputHandler) Print(format string, args ...any) {
	o.mu.RLock()
	defer o.mu.RUnlock()

	if !o.quiet {
		fmt.Fprintf(o.stdout, format, args...)
	}
}
|
||||
|
||||
// Error writes a formatted string to stderr unless quiet mode is enabled.
// The read lock guards the quiet flag against concurrent SetQuiet calls.
func (o *OutputHandler) Error(format string, args ...any) {
	o.mu.RLock()
	defer o.mu.RUnlock()

	if !o.quiet {
		fmt.Fprintf(o.stderr, format, args...)
	}
}
|
||||
|
||||
// Writes to stderr and exits (respects quiet mode)
|
||||
func (o *OutputHandler) FatalError(code int, format string, args ...any) {
|
||||
o.Error(format, args...)
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// Returns the current quiet mode status
|
||||
func (o *OutputHandler) IsQuiet() bool {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
return o.quiet
|
||||
}
|
||||
|
||||
// Updates quiet mode (useful for testing)
|
||||
func (o *OutputHandler) SetQuiet(quiet bool) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
o.quiet = quiet
|
||||
}
|
||||
|
||||
// Helper functions for global output handler
|
||||
// Print writes to stdout.
|
||||
func Print(format string, args ...any) {
|
||||
if output != nil {
|
||||
output.Print(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Error writes to stderr.
|
||||
func Error(format string, args ...any) {
|
||||
if output != nil {
|
||||
output.Error(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// FatalError writes to stderr and exits the application.
|
||||
func FatalError(code int, format string, args ...any) {
|
||||
if output != nil {
|
||||
output.FatalError(code, format, args...)
|
||||
@ -90,3 +52,43 @@ func FatalError(code int, format string, args ...any) {
|
||||
os.Exit(code)
|
||||
}
|
||||
}
|
||||
|
||||
// Print writes a formatted string to stdout if not in quiet mode.
|
||||
func (o *OutputHandler) Print(format string, args ...any) {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
|
||||
if !o.quiet {
|
||||
fmt.Fprintf(o.stdout, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Error writes a formatted string to stderr if not in quiet mode.
|
||||
func (o *OutputHandler) Error(format string, args ...any) {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
|
||||
if !o.quiet {
|
||||
fmt.Fprintf(o.stderr, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// FatalError writes a formatted string to stderr and exits with the given code.
// NOTE(review): Error respects quiet mode, so in quiet mode the process exits
// with the given code but prints nothing — confirm this is intended.
func (o *OutputHandler) FatalError(code int, format string, args ...any) {
	o.Error(format, args...)
	// os.Exit terminates immediately; deferred functions elsewhere do not run.
	os.Exit(code)
}
|
||||
|
||||
// IsQuiet returns the current quiet mode status.
|
||||
func (o *OutputHandler) IsQuiet() bool {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
return o.quiet
|
||||
}
|
||||
|
||||
// SetQuiet updates the quiet mode status.
|
||||
func (o *OutputHandler) SetQuiet(quiet bool) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
o.quiet = quiet
|
||||
}
|
||||
@ -17,7 +17,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Handles configuration hot reload
|
||||
// ReloadManager handles the configuration hot-reloading functionality.
|
||||
type ReloadManager struct {
|
||||
configPath string
|
||||
service *service.Service
|
||||
@ -35,7 +35,7 @@ type ReloadManager struct {
|
||||
statusReporterMu sync.Mutex
|
||||
}
|
||||
|
||||
// Creates a new reload manager
|
||||
// NewReloadManager creates a new reload manager.
|
||||
func NewReloadManager(configPath string, initialCfg *config.Config, logger *log.Logger) *ReloadManager {
|
||||
return &ReloadManager{
|
||||
configPath: configPath,
|
||||
@ -45,7 +45,7 @@ func NewReloadManager(configPath string, initialCfg *config.Config, logger *log.
|
||||
}
|
||||
}
|
||||
|
||||
// Begins watching for configuration changes
|
||||
// Start bootstraps the initial service and begins watching for configuration changes.
|
||||
func (rm *ReloadManager) Start(ctx context.Context) error {
|
||||
// Bootstrap initial service
|
||||
svc, err := bootstrapService(ctx, rm.cfg)
|
||||
@ -90,7 +90,75 @@ func (rm *ReloadManager) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the reload manager and the currently active service.
// Order matters: the status reporter and watch loop stop first so no reload can
// swap in a new service while shutdown is in progress.
func (rm *ReloadManager) Shutdown() {
	rm.logger.Info("msg", "Shutting down reload manager")

	// Stop status reporter
	rm.stopStatusReporter()

	// Stop watching; wait for the watch goroutine to exit before touching state.
	close(rm.shutdownCh)
	rm.wg.Wait()

	// Stop config watching
	if rm.lcfg != nil {
		rm.lcfg.StopAutoUpdate()
	}

	// Shutdown current services. Take a snapshot under the read lock so the
	// lock is not held across the (potentially slow) service shutdown.
	rm.mu.RLock()
	currentService := rm.service
	rm.mu.RUnlock()

	if currentService != nil {
		rm.logger.Info("msg", "Shutting down service")
		currentService.Shutdown()
	}
}
|
||||
|
||||
// GetService returns the currently active service instance in a thread-safe manner.
|
||||
func (rm *ReloadManager) GetService() *service.Service {
|
||||
rm.mu.RLock()
|
||||
defer rm.mu.RUnlock()
|
||||
return rm.service
|
||||
}
|
||||
|
||||
// triggerReload initiates the configuration reload process.
// Safe to call concurrently: only one reload runs at a time; overlapping
// triggers are dropped rather than queued.
func (rm *ReloadManager) triggerReload(ctx context.Context) {
	// Prevent concurrent reloads
	rm.reloadingMu.Lock()
	if rm.isReloading {
		rm.reloadingMu.Unlock()
		rm.logger.Debug("msg", "Reload already in progress, skipping")
		return
	}
	rm.isReloading = true
	rm.reloadingMu.Unlock()

	// Always clear the in-progress flag, even if performReload panics.
	defer func() {
		rm.reloadingMu.Lock()
		rm.isReloading = false
		rm.reloadingMu.Unlock()
	}()

	rm.logger.Info("msg", "Starting configuration hot reload")

	// Create reload context with timeout so a wedged reload cannot hang forever.
	reloadCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	// On failure the currently running configuration and services stay in place.
	if err := rm.performReload(reloadCtx); err != nil {
		rm.logger.Error("msg", "Hot reload failed",
			"error", err,
			"action", "keeping current configuration and services")
		return
	}

	rm.logger.Info("msg", "Configuration hot reload completed successfully")
}
|
||||
|
||||
// watchLoop is the main goroutine that monitors for configuration file changes.
|
||||
func (rm *ReloadManager) watchLoop(ctx context.Context) {
|
||||
defer rm.wg.Done()
|
||||
|
||||
@ -144,91 +212,7 @@ func (rm *ReloadManager) watchLoop(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verify file permissions for security
|
||||
func verifyFilePermissions(path string) error {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stat config file: %w", err)
|
||||
}
|
||||
|
||||
// Extract file mode and system stats
|
||||
mode := info.Mode()
|
||||
stat, ok := info.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return fmt.Errorf("unable to get file ownership info")
|
||||
}
|
||||
|
||||
// Check ownership - must be current user or root
|
||||
currentUID := uint32(os.Getuid())
|
||||
if stat.Uid != currentUID && stat.Uid != 0 {
|
||||
return fmt.Errorf("config file owned by uid %d, expected %d or 0", stat.Uid, currentUID)
|
||||
}
|
||||
|
||||
// Check permissions - must not be writable by group or other
|
||||
perm := mode.Perm()
|
||||
if perm&0022 != 0 {
|
||||
// Group or other has write permission
|
||||
return fmt.Errorf("insecure permissions %04o - file must not be writable by group/other", perm)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Determines if a config change requires service reload
|
||||
func (rm *ReloadManager) shouldReload(path string) bool {
|
||||
// Pipeline changes always require reload
|
||||
if strings.HasPrefix(path, "pipelines.") || path == "pipelines" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Logging changes don't require service reload
|
||||
if strings.HasPrefix(path, "logging.") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Status reporter changes
|
||||
if path == "disable_status_reporter" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Performs the actual reload
|
||||
func (rm *ReloadManager) triggerReload(ctx context.Context) {
|
||||
// Prevent concurrent reloads
|
||||
rm.reloadingMu.Lock()
|
||||
if rm.isReloading {
|
||||
rm.reloadingMu.Unlock()
|
||||
rm.logger.Debug("msg", "Reload already in progress, skipping")
|
||||
return
|
||||
}
|
||||
rm.isReloading = true
|
||||
rm.reloadingMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
rm.reloadingMu.Lock()
|
||||
rm.isReloading = false
|
||||
rm.reloadingMu.Unlock()
|
||||
}()
|
||||
|
||||
rm.logger.Info("msg", "Starting configuration hot reload")
|
||||
|
||||
// Create reload context with timeout
|
||||
reloadCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := rm.performReload(reloadCtx); err != nil {
|
||||
rm.logger.Error("msg", "Hot reload failed",
|
||||
"error", err,
|
||||
"action", "keeping current configuration and services")
|
||||
return
|
||||
}
|
||||
|
||||
rm.logger.Info("msg", "Configuration hot reload completed successfully")
|
||||
}
|
||||
|
||||
// Executes the reload process
|
||||
// performReload executes the steps to validate and apply a new configuration.
|
||||
func (rm *ReloadManager) performReload(ctx context.Context) error {
|
||||
// Get updated config from lconfig
|
||||
updatedCfg, err := rm.lcfg.AsStruct()
|
||||
@ -272,7 +256,57 @@ func (rm *ReloadManager) performReload(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Gracefully shuts down old services
|
||||
// shouldReload determines if a given configuration change requires a full service reload.
|
||||
func (rm *ReloadManager) shouldReload(path string) bool {
|
||||
// Pipeline changes always require reload
|
||||
if strings.HasPrefix(path, "pipelines.") || path == "pipelines" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Logging changes don't require service reload
|
||||
if strings.HasPrefix(path, "logging.") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Status reporter changes
|
||||
if path == "disable_status_reporter" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// verifyFilePermissions checks the ownership and permissions of the config
// file for security: it must be owned by the current user or root, and must
// not be writable by group or other.
func verifyFilePermissions(path string) error {
	info, err := os.Stat(path)
	if err != nil {
		return fmt.Errorf("failed to stat config file: %w", err)
	}

	// Ownership data requires the platform-specific stat structure.
	stat, ok := info.Sys().(*syscall.Stat_t)
	if !ok {
		return fmt.Errorf("unable to get file ownership info")
	}

	// Ownership: current user or root only.
	if uid := uint32(os.Getuid()); stat.Uid != uid && stat.Uid != 0 {
		return fmt.Errorf("config file owned by uid %d, expected %d or 0", stat.Uid, uid)
	}

	// Reject group/other write bits.
	if perm := info.Mode().Perm(); perm&0022 != 0 {
		return fmt.Errorf("insecure permissions %04o - file must not be writable by group/other", perm)
	}

	return nil
}
|
||||
|
||||
// shutdownOldServices gracefully shuts down the previous service instance after a successful reload.
|
||||
func (rm *ReloadManager) shutdownOldServices(svc *service.Service) {
|
||||
// Give connections time to drain
|
||||
rm.logger.Debug("msg", "Draining connections from old services")
|
||||
@ -286,7 +320,7 @@ func (rm *ReloadManager) shutdownOldServices(svc *service.Service) {
|
||||
rm.logger.Debug("msg", "Old services shutdown complete")
|
||||
}
|
||||
|
||||
// Starts a new status reporter
|
||||
// startStatusReporter starts a new status reporter for service.
|
||||
func (rm *ReloadManager) startStatusReporter(ctx context.Context, svc *service.Service) {
|
||||
rm.statusReporterMu.Lock()
|
||||
defer rm.statusReporterMu.Unlock()
|
||||
@ -299,7 +333,19 @@ func (rm *ReloadManager) startStatusReporter(ctx context.Context, svc *service.S
|
||||
rm.logger.Debug("msg", "Started status reporter")
|
||||
}
|
||||
|
||||
// Stops old and starts new status reporter
|
||||
// stopStatusReporter stops the currently running status reporter.
|
||||
func (rm *ReloadManager) stopStatusReporter() {
|
||||
rm.statusReporterMu.Lock()
|
||||
defer rm.statusReporterMu.Unlock()
|
||||
|
||||
if rm.statusReporterCancel != nil {
|
||||
rm.statusReporterCancel()
|
||||
rm.statusReporterCancel = nil
|
||||
rm.logger.Debug("msg", "Stopped status reporter")
|
||||
}
|
||||
}
|
||||
|
||||
// restartStatusReporter stops the old status reporter and starts a new one.
|
||||
func (rm *ReloadManager) restartStatusReporter(ctx context.Context, newService *service.Service) {
|
||||
if rm.cfg.DisableStatusReporter {
|
||||
// Just stop the old one if disabled
|
||||
@ -323,49 +369,3 @@ func (rm *ReloadManager) restartStatusReporter(ctx context.Context, newService *
|
||||
go statusReporter(newService, reporterCtx)
|
||||
rm.logger.Debug("msg", "Started new status reporter")
|
||||
}
|
||||
|
||||
// Stops the status reporter
|
||||
func (rm *ReloadManager) stopStatusReporter() {
|
||||
rm.statusReporterMu.Lock()
|
||||
defer rm.statusReporterMu.Unlock()
|
||||
|
||||
if rm.statusReporterCancel != nil {
|
||||
rm.statusReporterCancel()
|
||||
rm.statusReporterCancel = nil
|
||||
rm.logger.Debug("msg", "Stopped status reporter")
|
||||
}
|
||||
}
|
||||
|
||||
// Stops the reload manager
|
||||
func (rm *ReloadManager) Shutdown() {
|
||||
rm.logger.Info("msg", "Shutting down reload manager")
|
||||
|
||||
// Stop status reporter
|
||||
rm.stopStatusReporter()
|
||||
|
||||
// Stop watching
|
||||
close(rm.shutdownCh)
|
||||
rm.wg.Wait()
|
||||
|
||||
// Stop config watching
|
||||
if rm.lcfg != nil {
|
||||
rm.lcfg.StopAutoUpdate()
|
||||
}
|
||||
|
||||
// Shutdown current services
|
||||
rm.mu.RLock()
|
||||
currentService := rm.service
|
||||
rm.mu.RUnlock()
|
||||
|
||||
if currentService != nil {
|
||||
rm.logger.Info("msg", "Shutting down service")
|
||||
currentService.Shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the current service (thread-safe)
|
||||
func (rm *ReloadManager) GetService() *service.Service {
|
||||
rm.mu.RLock()
|
||||
defer rm.mu.RUnlock()
|
||||
return rm.service
|
||||
}
|
||||
@ -10,14 +10,14 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Manages OS signals
|
||||
// SignalHandler manages OS signals for shutdown and configuration reloads.
|
||||
type SignalHandler struct {
|
||||
reloadManager *ReloadManager
|
||||
logger *log.Logger
|
||||
sigChan chan os.Signal
|
||||
}
|
||||
|
||||
// Creates a signal handler
|
||||
// NewSignalHandler creates a new signal handler.
|
||||
func NewSignalHandler(rm *ReloadManager, logger *log.Logger) *SignalHandler {
|
||||
sh := &SignalHandler{
|
||||
reloadManager: rm,
|
||||
@ -36,7 +36,7 @@ func NewSignalHandler(rm *ReloadManager, logger *log.Logger) *SignalHandler {
|
||||
return sh
|
||||
}
|
||||
|
||||
// Processes signals
|
||||
// Handle blocks and processes incoming OS signals.
|
||||
func (sh *SignalHandler) Handle(ctx context.Context) os.Signal {
|
||||
for {
|
||||
select {
|
||||
@ -58,7 +58,7 @@ func (sh *SignalHandler) Handle(ctx context.Context) os.Signal {
|
||||
}
|
||||
}
|
||||
|
||||
// Cleans up signal handling
|
||||
// Stop cleans up the signal handling channel.
|
||||
func (sh *SignalHandler) Stop() {
|
||||
signal.Stop(sh.sigChan)
|
||||
close(sh.sigChan)
|
||||
|
||||
@ -10,7 +10,7 @@ import (
|
||||
"logwisp/src/internal/service"
|
||||
)
|
||||
|
||||
// Periodically logs service status
|
||||
// statusReporter is a goroutine that periodically logs the health and statistics of the service.
|
||||
func statusReporter(service *service.Service, ctx context.Context) {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@ -60,55 +60,7 @@ func statusReporter(service *service.Service, ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Logs the status of an individual pipeline
|
||||
func logPipelineStatus(name string, stats map[string]any) {
|
||||
statusFields := []any{
|
||||
"msg", "Pipeline status",
|
||||
"pipeline", name,
|
||||
}
|
||||
|
||||
// Add processing statistics
|
||||
if totalProcessed, ok := stats["total_processed"].(uint64); ok {
|
||||
statusFields = append(statusFields, "entries_processed", totalProcessed)
|
||||
}
|
||||
if totalFiltered, ok := stats["total_filtered"].(uint64); ok {
|
||||
statusFields = append(statusFields, "entries_filtered", totalFiltered)
|
||||
}
|
||||
|
||||
// Add source count
|
||||
if sourceCount, ok := stats["source_count"].(int); ok {
|
||||
statusFields = append(statusFields, "sources", sourceCount)
|
||||
}
|
||||
|
||||
// Add sink statistics
|
||||
if sinks, ok := stats["sinks"].([]map[string]any); ok {
|
||||
tcpConns := int64(0)
|
||||
httpConns := int64(0)
|
||||
|
||||
for _, sink := range sinks {
|
||||
sinkType := sink["type"].(string)
|
||||
if activeConns, ok := sink["active_connections"].(int64); ok {
|
||||
switch sinkType {
|
||||
case "tcp":
|
||||
tcpConns += activeConns
|
||||
case "http":
|
||||
httpConns += activeConns
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tcpConns > 0 {
|
||||
statusFields = append(statusFields, "tcp_connections", tcpConns)
|
||||
}
|
||||
if httpConns > 0 {
|
||||
statusFields = append(statusFields, "http_connections", httpConns)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug(statusFields...)
|
||||
}
|
||||
|
||||
// Logs the configured endpoints for a pipeline
|
||||
// displayPipelineEndpoints logs the configured source and sink endpoints for a pipeline at startup.
|
||||
func displayPipelineEndpoints(cfg config.PipelineConfig) {
|
||||
// Display sink endpoints
|
||||
for i, sinkCfg := range cfg.Sinks {
|
||||
@ -257,3 +209,51 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
|
||||
"filter_count", len(cfg.Filters))
|
||||
}
|
||||
}
|
||||
|
||||
// logPipelineStatus logs the detailed status and statistics of an individual pipeline.
|
||||
func logPipelineStatus(name string, stats map[string]any) {
|
||||
statusFields := []any{
|
||||
"msg", "Pipeline status",
|
||||
"pipeline", name,
|
||||
}
|
||||
|
||||
// Add processing statistics
|
||||
if totalProcessed, ok := stats["total_processed"].(uint64); ok {
|
||||
statusFields = append(statusFields, "entries_processed", totalProcessed)
|
||||
}
|
||||
if totalFiltered, ok := stats["total_filtered"].(uint64); ok {
|
||||
statusFields = append(statusFields, "entries_filtered", totalFiltered)
|
||||
}
|
||||
|
||||
// Add source count
|
||||
if sourceCount, ok := stats["source_count"].(int); ok {
|
||||
statusFields = append(statusFields, "sources", sourceCount)
|
||||
}
|
||||
|
||||
// Add sink statistics
|
||||
if sinks, ok := stats["sinks"].([]map[string]any); ok {
|
||||
tcpConns := int64(0)
|
||||
httpConns := int64(0)
|
||||
|
||||
for _, sink := range sinks {
|
||||
sinkType := sink["type"].(string)
|
||||
if activeConns, ok := sink["active_connections"].(int64); ok {
|
||||
switch sinkType {
|
||||
case "tcp":
|
||||
tcpConns += activeConns
|
||||
case "http":
|
||||
httpConns += activeConns
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tcpConns > 0 {
|
||||
statusFields = append(statusFields, "tcp_connections", tcpConns)
|
||||
}
|
||||
if httpConns > 0 {
|
||||
statusFields = append(statusFields, "http_connections", httpConns)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug(statusFields...)
|
||||
}
|
||||
@ -1,213 +0,0 @@
|
||||
// FILE: logwisp/src/internal/auth/authenticator.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// maxAuthTrackedIPs bounds per-IP tracking state to prevent unbounded map growth.
// NOTE(review): this constant is not referenced anywhere in this file —
// confirm it is still needed or wire it into the session tracking logic.
const maxAuthTrackedIPs = 10000

// Authenticator handles all authentication methods for a pipeline.
type Authenticator struct {
	config *config.ServerAuthConfig
	logger *log.Logger
	tokens map[string]bool // token -> valid; populated once at construction
	mu     sync.RWMutex    // guards tokens

	// Session tracking
	sessions  map[string]*Session // session ID -> session
	sessionMu sync.RWMutex        // guards sessions and fields of stored *Session values
}
|
||||
|
||||
// TODO: only one connection per user, token, mtls
// TODO: implement tracker logic

// Session represents an authenticated connection.
type Session struct {
	ID           string
	Username     string // empty for token- and none-authenticated sessions
	Method       string // basic, token, mtls, or "none" when auth is disabled
	RemoteAddr   string
	CreatedAt    time.Time
	LastActivity time.Time // refreshed by ValidateSession; drives idle expiry
}
|
||||
|
||||
// NewAuthenticator creates a new authenticator from config.
// It returns (nil, nil) when no authenticator is needed: a nil config,
// auth type "none", or "scram" (SCRAM is handled by ScramManager in
// sources). Callers must therefore nil-check the returned value.
func NewAuthenticator(cfg *config.ServerAuthConfig, logger *log.Logger) (*Authenticator, error) {
	// SCRAM is handled by ScramManager in sources
	if cfg == nil || cfg.Type == "none" || cfg.Type == "scram" {
		return nil, nil
	}

	a := &Authenticator{
		config:   cfg,
		logger:   logger,
		tokens:   make(map[string]bool),
		sessions: make(map[string]*Session),
	}

	// Initialize tokens from the static token list.
	if cfg.Type == "token" && cfg.Token != nil {
		for _, token := range cfg.Token.Tokens {
			a.tokens[token] = true
		}
	}

	// Start session cleanup.
	// NOTE(review): this goroutine has no stop signal and runs for the process
	// lifetime; fine for a singleton, a leak if authenticators are recreated
	// (e.g. on config hot reload) — confirm.
	go a.sessionCleanup()

	logger.Info("msg", "Authenticator initialized",
		"component", "auth",
		"type", cfg.Type)

	return a, nil
}
|
||||
|
||||
// AuthenticateHTTP handles HTTP authentication headers and returns a Session
// on success. A nil receiver or auth type "none" yields an anonymous,
// untracked session.
func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) {
	if a == nil || a.config.Type == "none" {
		// Anonymous session; note it is not stored in the session map.
		return &Session{
			ID:         generateSessionID(),
			Method:     "none",
			RemoteAddr: remoteAddr,
			CreatedAt:  time.Now(),
		}, nil
	}

	var session *Session
	var err error

	switch a.config.Type {
	case "token":
		session, err = a.authenticateToken(authHeader, remoteAddr)
	default:
		err = fmt.Errorf("unsupported auth type: %s", a.config.Type)
	}

	if err != nil {
		// Fixed delay on any failure to slow down brute-force attempts.
		time.Sleep(500 * time.Millisecond)
		return nil, err
	}

	return session, nil
}
|
||||
|
||||
func (a *Authenticator) authenticateToken(authHeader, remoteAddr string) (*Session, error) {
|
||||
if !strings.HasPrefix(authHeader, "Token") {
|
||||
return nil, fmt.Errorf("invalid token auth header")
|
||||
}
|
||||
|
||||
token := authHeader[7:]
|
||||
return a.validateToken(token, remoteAddr)
|
||||
}
|
||||
|
||||
func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error) {
|
||||
// Check static tokens first
|
||||
a.mu.RLock()
|
||||
isValid := a.tokens[token]
|
||||
a.mu.RUnlock()
|
||||
|
||||
if !isValid {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
session := &Session{
|
||||
ID: generateSessionID(),
|
||||
Method: "token",
|
||||
RemoteAddr: remoteAddr,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivity: time.Now(),
|
||||
}
|
||||
a.storeSession(session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// storeSession registers a session under its ID and logs the creation.
// The log call happens after releasing sessionMu to keep the critical
// section minimal.
func (a *Authenticator) storeSession(session *Session) {
	a.sessionMu.Lock()
	a.sessions[session.ID] = session
	a.sessionMu.Unlock()

	a.logger.Info("msg", "Session created",
		"component", "auth",
		"session_id", session.ID,
		"username", session.Username,
		"method", session.Method,
		"remote_addr", session.RemoteAddr)
}
|
||||
|
||||
// sessionCleanup periodically evicts sessions idle for more than 30 minutes.
// It runs as a goroutine started by NewAuthenticator.
// NOTE(review): there is no stop channel, so this goroutine never exits;
// confirm the Authenticator lives for the whole process lifetime.
func (a *Authenticator) sessionCleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for range ticker.C {
		a.sessionMu.Lock()
		now := time.Now()
		for id, session := range a.sessions {
			// Idle timeout is based on LastActivity, which ValidateSession refreshes.
			if now.Sub(session.LastActivity) > 30*time.Minute {
				delete(a.sessions, id)
				a.logger.Debug("msg", "Session expired",
					"component", "auth",
					"session_id", id)
			}
		}
		a.sessionMu.Unlock()
	}
}
|
||||
|
||||
// generateSessionID returns a 256-bit random session identifier encoded with
// URL-safe base64. If the system CSPRNG fails it degrades to a timestamp-based
// (non-cryptographic) identifier rather than failing outright.
func generateSessionID() string {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		// Fallback to a less secure method if crypto/rand fails
		return fmt.Sprintf("fallback-%d", time.Now().UnixNano())
	}
	return base64.URLEncoding.EncodeToString(raw[:])
}
|
||||
|
||||
// Checks if a session is still valid
|
||||
func (a *Authenticator) ValidateSession(sessionID string) bool {
|
||||
if a == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
a.sessionMu.RLock()
|
||||
session, exists := a.sessions[sessionID]
|
||||
a.sessionMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Update activity
|
||||
a.sessionMu.Lock()
|
||||
session.LastActivity = time.Now()
|
||||
a.sessionMu.Unlock()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Returns authentication statistics
|
||||
func (a *Authenticator) GetStats() map[string]any {
|
||||
if a == nil {
|
||||
return map[string]any{"enabled": false}
|
||||
}
|
||||
|
||||
a.sessionMu.RLock()
|
||||
sessionCount := len(a.sessions)
|
||||
a.sessionMu.RUnlock()
|
||||
|
||||
return map[string]any{
|
||||
"enabled": true,
|
||||
"type": a.config.Type,
|
||||
"active_sessions": sessionCount,
|
||||
"static_tokens": len(a.tokens),
|
||||
}
|
||||
}
|
||||
@ -1,106 +0,0 @@
|
||||
// FILE: src/internal/auth/scram_client.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
// ScramClient handles SCRAM client-side authentication.
type ScramClient struct {
	Username string
	Password string

	// Handshake state, populated as the exchange progresses.
	clientNonce string       // base64 nonce generated by StartAuthentication
	serverFirst *ServerFirst // server challenge captured by ProcessServerFirst
	authMessage string       // canonical auth message used for both proofs
	serverKey   []byte       // retained to verify the server's final signature
}
|
||||
|
||||
// NewScramClient creates a SCRAM client for the given credentials.
// Handshake state is filled in later by StartAuthentication and
// ProcessServerFirst.
func NewScramClient(username, password string) *ScramClient {
	return &ScramClient{
		Username: username,
		Password: password,
	}
}
|
||||
|
||||
// StartAuthentication generates ClientFirst message
|
||||
func (c *ScramClient) StartAuthentication() (*ClientFirst, error) {
|
||||
// Generate client nonce
|
||||
nonce := make([]byte, 32)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
c.clientNonce = base64.StdEncoding.EncodeToString(nonce)
|
||||
|
||||
return &ClientFirst{
|
||||
Username: c.Username,
|
||||
ClientNonce: c.clientNonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProcessServerFirst handles the server challenge: it derives the client keys
// from the password with Argon2id, builds the canonical auth message, and
// returns the client proof.
func (c *ScramClient) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) {
	c.serverFirst = msg

	// Decode salt
	salt, err := base64.StdEncoding.DecodeString(msg.Salt)
	if err != nil {
		return nil, fmt.Errorf("invalid salt encoding: %w", err)
	}

	// Derive keys using Argon2id with the cost parameters dictated by the server.
	saltedPassword := argon2.IDKey([]byte(c.Password), salt,
		msg.ArgonTime, msg.ArgonMemory, msg.ArgonThreads, 32)

	clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
	serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
	storedKey := sha256.Sum256(clientKey)

	// Build auth message; both sides must derive the identical string for the
	// proofs to match.
	clientFirstBare := fmt.Sprintf("u=%s,n=%s", c.Username, c.clientNonce)
	clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce)
	c.authMessage = clientFirstBare + "," + msg.Marshal() + "," + clientFinalBare

	// Compute client proof: ClientKey XOR HMAC(StoredKey, AuthMessage).
	clientSignature := computeHMAC(storedKey[:], []byte(c.authMessage))
	clientProof := xorBytes(clientKey, clientSignature)

	// Store server key for verification in VerifyServerFinal.
	c.serverKey = serverKey

	return &ClientFinal{
		FullNonce:   msg.FullNonce,
		ClientProof: base64.StdEncoding.EncodeToString(clientProof),
	}, nil
}
|
||||
|
||||
// VerifyServerFinal validates the server's signature over the auth message,
// proving the server also knows the credential (mutual authentication).
// It must be called only after ProcessServerFirst has populated the
// handshake state.
func (c *ScramClient) VerifyServerFinal(msg *ServerFinal) error {
	if c.authMessage == "" || c.serverKey == nil {
		return fmt.Errorf("invalid handshake state")
	}

	// Compute expected server signature
	expectedSig := computeHMAC(c.serverKey, []byte(c.authMessage))

	// Decode received signature
	receivedSig, err := base64.StdEncoding.DecodeString(msg.ServerSignature)
	if err != nil {
		return fmt.Errorf("invalid signature encoding: %w", err)
	}

	// ☢ SECURITY: Constant-time comparison
	if subtle.ConstantTimeCompare(expectedSig, receivedSig) != 1 {
		return fmt.Errorf("server authentication failed")
	}

	return nil
}
|
||||
@ -1,108 +0,0 @@
|
||||
// FILE: src/internal/auth/scram_credential.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
// Credential stores SCRAM authentication data for one user.
type Credential struct {
	Username     string
	Salt         []byte // 16+ bytes
	ArgonTime    uint32 // e.g., 3
	ArgonMemory  uint32 // e.g., 64*1024 KiB
	ArgonThreads uint8  // e.g., 4
	StoredKey    []byte // SHA256(ClientKey)
	ServerKey    []byte // For server auth
	PHCHash      string // PHC-format Argon2id hash, kept for basic-auth compatibility
}
|
||||
|
||||
// DeriveCredential creates SCRAM credential from password
|
||||
func DeriveCredential(username, password string, salt []byte, time, memory uint32, threads uint8) (*Credential, error) {
|
||||
if len(salt) < 16 {
|
||||
return nil, fmt.Errorf("salt must be at least 16 bytes")
|
||||
}
|
||||
|
||||
// Derive salted password using Argon2id
|
||||
saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, core.Argon2KeyLen)
|
||||
|
||||
// Construct PHC format for basic auth compatibility
|
||||
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
||||
hashB64 := base64.RawStdEncoding.EncodeToString(saltedPassword)
|
||||
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, memory, time, threads, saltB64, hashB64)
|
||||
|
||||
// Derive keys
|
||||
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||
storedKey := sha256.Sum256(clientKey)
|
||||
|
||||
return &Credential{
|
||||
Username: username,
|
||||
Salt: salt,
|
||||
ArgonTime: time,
|
||||
ArgonMemory: memory,
|
||||
ArgonThreads: threads,
|
||||
StoredKey: storedKey[:],
|
||||
ServerKey: serverKey,
|
||||
PHCHash: phcHash,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MigrateFromPHC converts existing Argon2 PHC hash to SCRAM credential
|
||||
func MigrateFromPHC(username, password, phcHash string) (*Credential, error) {
|
||||
// Parse PHC: $argon2id$v=19$m=65536,t=3,p=4$salt$hash
|
||||
parts := strings.Split(phcHash, "$")
|
||||
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||
return nil, fmt.Errorf("invalid PHC format")
|
||||
}
|
||||
|
||||
var memory, time uint32
|
||||
var threads uint8
|
||||
fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads)
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid salt encoding: %w", err)
|
||||
}
|
||||
|
||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid hash encoding: %w", err)
|
||||
}
|
||||
|
||||
// Verify password matches
|
||||
computedHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
|
||||
if subtle.ConstantTimeCompare(computedHash, expectedHash) != 1 {
|
||||
return nil, fmt.Errorf("password verification failed")
|
||||
}
|
||||
|
||||
// Now derive SCRAM credential
|
||||
return DeriveCredential(username, password, salt, time, memory, threads)
|
||||
}
|
||||
|
||||
func computeHMAC(key, message []byte) []byte {
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(message)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func xorBytes(a, b []byte) []byte {
|
||||
if len(a) != len(b) {
|
||||
panic("xor length mismatch")
|
||||
}
|
||||
result := make([]byte, len(a))
|
||||
for i := range a {
|
||||
result[i] = a[i] ^ b[i]
|
||||
}
|
||||
return result
|
||||
}
|
||||
@ -1,83 +0,0 @@
|
||||
// FILE: src/internal/auth/scram_manager.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
)
|
||||
|
||||
// ScramManager provides high-level SCRAM operations (user registration and
// the two-phase handshake) on top of a ScramServer.
// NOTE(review): the original comment mentioned rate limiting, but none is
// implemented in this type — confirm whether it lives elsewhere.
type ScramManager struct {
	// server performs credential storage and handshake verification.
	server *ScramServer
}
|
||||
|
||||
// NewScramManager creates SCRAM manager
|
||||
func NewScramManager(scramAuthCfg *config.ScramAuthConfig) *ScramManager {
|
||||
manager := &ScramManager{
|
||||
server: NewScramServer(),
|
||||
}
|
||||
|
||||
// Load users from SCRAM config
|
||||
for _, user := range scramAuthCfg.Users {
|
||||
storedKey, err := base64.StdEncoding.DecodeString(user.StoredKey)
|
||||
if err != nil {
|
||||
// Skip user with invalid stored key
|
||||
continue
|
||||
}
|
||||
|
||||
serverKey, err := base64.StdEncoding.DecodeString(user.ServerKey)
|
||||
if err != nil {
|
||||
// Skip user with invalid server key
|
||||
continue
|
||||
}
|
||||
|
||||
salt, err := base64.StdEncoding.DecodeString(user.Salt)
|
||||
if err != nil {
|
||||
// Skip user with invalid salt
|
||||
continue
|
||||
}
|
||||
|
||||
cred := &Credential{
|
||||
Username: user.Username,
|
||||
StoredKey: storedKey,
|
||||
ServerKey: serverKey,
|
||||
Salt: salt,
|
||||
ArgonTime: user.ArgonTime,
|
||||
ArgonMemory: user.ArgonMemory,
|
||||
ArgonThreads: user.ArgonThreads,
|
||||
}
|
||||
manager.server.AddCredential(cred)
|
||||
}
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
// RegisterUser creates new user credential
|
||||
func (sm *ScramManager) RegisterUser(username, password string) error {
|
||||
salt := make([]byte, 16)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return fmt.Errorf("salt generation failed: %w", err)
|
||||
}
|
||||
|
||||
cred, err := DeriveCredential(username, password, salt,
|
||||
sm.server.DefaultTime, sm.server.DefaultMemory, sm.server.DefaultThreads)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sm.server.AddCredential(cred)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleClientFirst forwards the SCRAM client-first message to the
// underlying ScramServer and returns its challenge (or an error).
func (sm *ScramManager) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
	return sm.server.HandleClientFirst(msg)
}
|
||||
|
||||
// HandleClientFinal forwards the SCRAM client-final (proof) message to the
// underlying ScramServer and returns its verdict.
func (sm *ScramManager) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
	return sm.server.HandleClientFinal(msg)
}
|
||||
@ -1,38 +0,0 @@
|
||||
// FILE: src/internal/auth/scram_message.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ClientFirst initiates authentication; it carries the username and the
// client-generated nonce.
type ClientFirst struct {
	Username    string `json:"u"`
	ClientNonce string `json:"n"`
}

// ServerFirst contains the server challenge: the combined nonce, the user's
// salt, and the Argon2 cost parameters the client must use for key derivation.
type ServerFirst struct {
	FullNonce    string `json:"r"` // client_nonce + server_nonce
	Salt         string `json:"s"` // base64
	ArgonTime    uint32 `json:"t"`
	ArgonMemory  uint32 `json:"m"`
	ArgonThreads uint8  `json:"p"`
}

// ClientFinal contains the client proof for the handshake identified by
// FullNonce.
type ClientFinal struct {
	FullNonce   string `json:"r"`
	ClientProof string `json:"p"` // base64
}

// ServerFinal contains the server signature for mutual authentication and,
// on success, an optional session identifier.
type ServerFinal struct {
	ServerSignature string `json:"v"` // base64
	SessionID       string `json:"sid,omitempty"`
}
|
||||
|
||||
func (sf *ServerFirst) Marshal() string {
|
||||
return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d",
|
||||
sf.FullNonce, sf.Salt, sf.ArgonTime, sf.ArgonMemory, sf.ArgonThreads)
|
||||
}
|
||||
@ -1,117 +0,0 @@
|
||||
// FILE: src/internal/auth/scram_protocol.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/panjf2000/gnet/v2"
|
||||
)
|
||||
|
||||
// ScramProtocolHandler handles the SCRAM message exchange for TCP,
// translating wire-format lines into ScramManager calls.
type ScramProtocolHandler struct {
	// manager performs the actual SCRAM handshake processing.
	manager *ScramManager
	// logger is injected at construction.
	// NOTE(review): not referenced by HandleAuthMessage — confirm it is
	// used elsewhere or intended for future diagnostics.
	logger *log.Logger
}
|
||||
|
||||
// NewScramProtocolHandler creates protocol handler
|
||||
func NewScramProtocolHandler(manager *ScramManager, logger *log.Logger) *ScramProtocolHandler {
|
||||
return &ScramProtocolHandler{
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAuthMessage processes a complete auth line from buffer
|
||||
func (sph *ScramProtocolHandler) HandleAuthMessage(line []byte, conn gnet.Conn) (authenticated bool, session *Session, err error) {
|
||||
// Parse SCRAM messages
|
||||
parts := strings.Fields(string(line))
|
||||
if len(parts) < 2 {
|
||||
conn.AsyncWrite([]byte("SCRAM-FAIL Invalid message format\n"), nil)
|
||||
return false, nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
|
||||
switch parts[0] {
|
||||
case "SCRAM-FIRST":
|
||||
// Parse ClientFirst JSON
|
||||
var clientFirst ClientFirst
|
||||
if err := json.Unmarshal([]byte(parts[1]), &clientFirst); err != nil {
|
||||
conn.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
|
||||
return false, nil, fmt.Errorf("invalid JSON")
|
||||
}
|
||||
|
||||
// Process with SCRAM server
|
||||
serverFirst, err := sph.manager.HandleClientFirst(&clientFirst)
|
||||
if err != nil {
|
||||
// Still send challenge to prevent user enumeration
|
||||
response, _ := json.Marshal(serverFirst)
|
||||
conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
// Send ServerFirst challenge
|
||||
response, _ := json.Marshal(serverFirst)
|
||||
conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
|
||||
return false, nil, nil // Not authenticated yet
|
||||
|
||||
case "SCRAM-PROOF":
|
||||
// Parse ClientFinal JSON
|
||||
var clientFinal ClientFinal
|
||||
if err := json.Unmarshal([]byte(parts[1]), &clientFinal); err != nil {
|
||||
conn.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
|
||||
return false, nil, fmt.Errorf("invalid JSON")
|
||||
}
|
||||
|
||||
// Verify proof
|
||||
serverFinal, err := sph.manager.HandleClientFinal(&clientFinal)
|
||||
if err != nil {
|
||||
conn.AsyncWrite([]byte("SCRAM-FAIL Authentication failed\n"), nil)
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
// Authentication successful
|
||||
session = &Session{
|
||||
ID: serverFinal.SessionID,
|
||||
Method: "scram-sha-256",
|
||||
RemoteAddr: conn.RemoteAddr().String(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Send ServerFinal with signature
|
||||
response, _ := json.Marshal(serverFinal)
|
||||
conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-OK %s\n", response)), nil)
|
||||
|
||||
return true, session, nil
|
||||
|
||||
default:
|
||||
conn.AsyncWrite([]byte("SCRAM-FAIL Unknown command\n"), nil)
|
||||
return false, nil, fmt.Errorf("unknown command: %s", parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
// FormatSCRAMRequest serializes data as JSON and prefixes it with the
// protocol command, producing a single newline-terminated wire line.
func FormatSCRAMRequest(command string, data interface{}) (string, error) {
	payload, err := json.Marshal(data)
	if err != nil {
		return "", fmt.Errorf("failed to marshal %s: %w", command, err)
	}
	return command + " " + string(payload) + "\n", nil
}
|
||||
|
||||
// ParseSCRAMResponse splits a SCRAM wire line into its command and optional
// payload (everything after the first space).
//
// BUG FIX: the original checked len(parts) < 1, which is dead code —
// strings.SplitN never returns an empty slice — so an empty or
// whitespace-only response yielded command "" with a nil error. Reject it
// explicitly instead.
func ParseSCRAMResponse(response string) (command string, data string, err error) {
	response = strings.TrimSpace(response)
	if response == "" {
		return "", "", fmt.Errorf("empty response")
	}

	parts := strings.SplitN(response, " ", 2)
	command = parts[0]
	if len(parts) > 1 {
		data = parts[1]
	}
	return command, data, nil
}
|
||||
@ -1,174 +0,0 @@
|
||||
// FILE: src/internal/auth/scram_server.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/core"
|
||||
)
|
||||
|
||||
// ScramServer handles SCRAM authentication: it stores user credentials and
// tracks in-flight handshakes keyed by the combined client+server nonce.
type ScramServer struct {
	credentials map[string]*Credential     // registered users, keyed by username
	handshakes  map[string]*HandshakeState // in-flight handshakes, keyed by full nonce
	mu          sync.RWMutex               // guards both maps

	// TODO: configurability useful? to be included in config or refactor to use core.const directly for simplicity
	// Default Argon2 params for new registrations
	DefaultTime    uint32
	DefaultMemory  uint32
	DefaultThreads uint8
}
|
||||
|
||||
// HandshakeState tracks an ongoing authentication exchange between the
// client-first and client-final rounds. CreatedAt is used to expire stale
// handshakes (see cleanupHandshakes and the timeout check in HandleClientFinal).
type HandshakeState struct {
	Username    string
	ClientNonce string
	ServerNonce string
	FullNonce   string      // ClientNonce + ServerNonce; map key in ScramServer
	Credential  *Credential // credential of the user being authenticated
	CreatedAt   time.Time
}
|
||||
|
||||
// NewScramServer creates SCRAM server
|
||||
func NewScramServer() *ScramServer {
|
||||
return &ScramServer{
|
||||
credentials: make(map[string]*Credential),
|
||||
handshakes: make(map[string]*HandshakeState),
|
||||
DefaultTime: core.Argon2Time,
|
||||
DefaultMemory: core.Argon2Memory,
|
||||
DefaultThreads: core.Argon2Threads,
|
||||
}
|
||||
}
|
||||
|
||||
// AddCredential registers (or replaces) the credential stored under
// cred.Username. Safe for concurrent use.
func (s *ScramServer) AddCredential(cred *Credential) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.credentials[cred.Username] = cred
}
|
||||
|
||||
// HandleClientFirst processes the initial auth request. For a known user it
// records a HandshakeState keyed by the combined nonce and returns the real
// salt and Argon2 parameters; for an unknown user it returns a decoy
// challenge together with an error so the response shape does not reveal
// whether the username exists.
func (s *ScramServer) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Check if user exists
	cred, exists := s.credentials[msg.Username]
	if !exists {
		// Prevent user enumeration - still generate response.
		// NOTE(review): the decoy salt is freshly random on every request,
		// while a real user's salt is stable across requests — a repeated
		// probe could distinguish the two. A keyed, per-username decoy salt
		// would close this; needs a server secret field to do properly.
		// NOTE(review): rand.Read error is ignored here; crypto/rand is
		// documented to always succeed on supported platforms, and a zeroed
		// salt only weakens the dummy response.
		salt := make([]byte, 16)
		rand.Read(salt)
		serverNonce := generateNonce()

		return &ServerFirst{
			FullNonce:    msg.ClientNonce + serverNonce,
			Salt:         base64.StdEncoding.EncodeToString(salt),
			ArgonTime:    s.DefaultTime,
			ArgonMemory:  s.DefaultMemory,
			ArgonThreads: s.DefaultThreads,
		}, fmt.Errorf("invalid credentials")
	}

	// Generate server nonce
	serverNonce := generateNonce()
	fullNonce := msg.ClientNonce + serverNonce

	// Store handshake state so the proof can be verified in client-final.
	state := &HandshakeState{
		Username:    msg.Username,
		ClientNonce: msg.ClientNonce,
		ServerNonce: serverNonce,
		FullNonce:   fullNonce,
		Credential:  cred,
		CreatedAt:   time.Now(),
	}
	s.handshakes[fullNonce] = state

	// Cleanup old handshakes. Safe under s.mu: cleanupHandshakes takes no
	// lock of its own.
	s.cleanupHandshakes()

	return &ServerFirst{
		FullNonce:    fullNonce,
		Salt:         base64.StdEncoding.EncodeToString(cred.Salt),
		ArgonTime:    cred.ArgonTime,
		ArgonMemory:  cred.ArgonMemory,
		ArgonThreads: cred.ArgonThreads,
	}, nil
}
|
||||
|
||||
// HandleClientFinal verifies client proof
|
||||
func (s *ScramServer) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
state, exists := s.handshakes[msg.FullNonce]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("invalid nonce or expired handshake")
|
||||
}
|
||||
defer delete(s.handshakes, msg.FullNonce)
|
||||
|
||||
// Check timeout
|
||||
if time.Since(state.CreatedAt) > 60*time.Second {
|
||||
return nil, fmt.Errorf("handshake timeout")
|
||||
}
|
||||
|
||||
// Decode client proof
|
||||
clientProof, err := base64.StdEncoding.DecodeString(msg.ClientProof)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proof encoding")
|
||||
}
|
||||
|
||||
// Build auth message
|
||||
clientFirstBare := fmt.Sprintf("u=%s,n=%s", state.Username, state.ClientNonce)
|
||||
serverFirst := &ServerFirst{
|
||||
FullNonce: state.FullNonce,
|
||||
Salt: base64.StdEncoding.EncodeToString(state.Credential.Salt),
|
||||
ArgonTime: state.Credential.ArgonTime,
|
||||
ArgonMemory: state.Credential.ArgonMemory,
|
||||
ArgonThreads: state.Credential.ArgonThreads,
|
||||
}
|
||||
clientFinalBare := fmt.Sprintf("r=%s", msg.FullNonce)
|
||||
authMessage := clientFirstBare + "," + serverFirst.Marshal() + "," + clientFinalBare
|
||||
|
||||
// Compute client signature
|
||||
clientSignature := computeHMAC(state.Credential.StoredKey, []byte(authMessage))
|
||||
|
||||
// XOR to get ClientKey
|
||||
clientKey := xorBytes(clientProof, clientSignature)
|
||||
|
||||
// Verify by computing StoredKey
|
||||
computedStoredKey := sha256.Sum256(clientKey)
|
||||
if subtle.ConstantTimeCompare(computedStoredKey[:], state.Credential.StoredKey) != 1 {
|
||||
return nil, fmt.Errorf("authentication failed")
|
||||
}
|
||||
|
||||
// Generate server signature for mutual auth
|
||||
serverSignature := computeHMAC(state.Credential.ServerKey, []byte(authMessage))
|
||||
|
||||
return &ServerFinal{
|
||||
ServerSignature: base64.StdEncoding.EncodeToString(serverSignature),
|
||||
SessionID: generateSessionID(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ScramServer) cleanupHandshakes() {
|
||||
cutoff := time.Now().Add(-60 * time.Second)
|
||||
for nonce, state := range s.handshakes {
|
||||
if state.CreatedAt.Before(cutoff) {
|
||||
delete(s.handshakes, nonce)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func generateNonce() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
@ -3,6 +3,7 @@ package config
|
||||
|
||||
// --- LogWisp Configuration Options ---
|
||||
|
||||
// Config is the top-level configuration structure for the LogWisp application.
|
||||
type Config struct {
|
||||
// Top-level flags for application control
|
||||
Background bool `toml:"background"`
|
||||
@ -26,7 +27,7 @@ type Config struct {
|
||||
|
||||
// --- Logging Options ---
|
||||
|
||||
// Represents logging configuration for LogWisp
|
||||
// LogConfig represents the logging configuration for the LogWisp application itself.
|
||||
type LogConfig struct {
|
||||
// Output mode: "file", "stdout", "stderr", "split", "all", "none"
|
||||
Output string `toml:"output"`
|
||||
@ -41,6 +42,7 @@ type LogConfig struct {
|
||||
Console *LogConsoleConfig `toml:"console"`
|
||||
}
|
||||
|
||||
// LogFileConfig defines settings for file-based application logging.
|
||||
type LogFileConfig struct {
|
||||
// Directory for log files
|
||||
Directory string `toml:"directory"`
|
||||
@ -58,6 +60,7 @@ type LogFileConfig struct {
|
||||
RetentionHours float64 `toml:"retention_hours"`
|
||||
}
|
||||
|
||||
// LogConsoleConfig defines settings for console-based application logging.
|
||||
type LogConsoleConfig struct {
|
||||
// Target for console output: "stdout", "stderr", "split"
|
||||
// "split": info/debug to stdout, warn/error to stderr
|
||||
@ -69,19 +72,19 @@ type LogConsoleConfig struct {
|
||||
|
||||
// --- Pipeline Options ---
|
||||
|
||||
// PipelineConfig defines a complete data flow from sources to sinks.
|
||||
type PipelineConfig struct {
|
||||
Name string `toml:"name"`
|
||||
Sources []SourceConfig `toml:"sources"`
|
||||
RateLimit *RateLimitConfig `toml:"rate_limit"`
|
||||
Filters []FilterConfig `toml:"filters"`
|
||||
Format *FormatConfig `toml:"format"`
|
||||
|
||||
Sinks []SinkConfig `toml:"sinks"`
|
||||
// Auth *ServerAuthConfig `toml:"auth"` // Global auth for pipeline
|
||||
Sinks []SinkConfig `toml:"sinks"`
|
||||
}
|
||||
|
||||
// Common configuration structs used across components
|
||||
|
||||
// NetLimitConfig defines network-level access control and rate limiting rules.
|
||||
type NetLimitConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
MaxConnections int64 `toml:"max_connections"`
|
||||
@ -95,27 +98,37 @@ type NetLimitConfig struct {
|
||||
IPBlacklist []string `toml:"ip_blacklist"`
|
||||
}
|
||||
|
||||
type TLSConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
CertFile string `toml:"cert_file"`
|
||||
KeyFile string `toml:"key_file"`
|
||||
CAFile string `toml:"ca_file"`
|
||||
ServerName string `toml:"server_name"` // for client verification
|
||||
SkipVerify bool `toml:"skip_verify"`
|
||||
// TLSServerConfig defines TLS settings for a server (HTTP Source, HTTP Sink).
|
||||
type TLSServerConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
CertFile string `toml:"cert_file"` // Server's certificate file.
|
||||
KeyFile string `toml:"key_file"` // Server's private key file.
|
||||
ClientAuth bool `toml:"client_auth"` // Enable/disable mTLS.
|
||||
ClientCAFile string `toml:"client_ca_file"` // CA for verifying client certificates.
|
||||
VerifyClientCert bool `toml:"verify_client_cert"` // Require and verify client certs.
|
||||
|
||||
// Client certificate authentication
|
||||
ClientAuth bool `toml:"client_auth"`
|
||||
ClientCAFile string `toml:"client_ca_file"`
|
||||
VerifyClientCert bool `toml:"verify_client_cert"`
|
||||
|
||||
// TLS version constraints
|
||||
MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3"
|
||||
MaxVersion string `toml:"max_version"`
|
||||
|
||||
// Cipher suites (comma-separated list)
|
||||
// Common TLS settings
|
||||
MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3"
|
||||
MaxVersion string `toml:"max_version"`
|
||||
CipherSuites string `toml:"cipher_suites"`
|
||||
}
|
||||
|
||||
// TLSClientConfig defines TLS settings for a client (HTTP Client Sink).
|
||||
type TLSClientConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
ServerCAFile string `toml:"server_ca_file"` // CA for verifying the remote server's certificate.
|
||||
ClientCertFile string `toml:"client_cert_file"` // Client's certificate for mTLS.
|
||||
ClientKeyFile string `toml:"client_key_file"` // Client's private key for mTLS.
|
||||
ServerName string `toml:"server_name"` // For server certificate validation (SNI).
|
||||
InsecureSkipVerify bool `toml:"insecure_skip_verify"` // Use with caution.
|
||||
|
||||
// Common TLS settings
|
||||
MinVersion string `toml:"min_version"`
|
||||
MaxVersion string `toml:"max_version"`
|
||||
CipherSuites string `toml:"cipher_suites"`
|
||||
}
|
||||
|
||||
// HeartbeatConfig defines settings for periodic keep-alive or status messages.
|
||||
type HeartbeatConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
IntervalMS int64 `toml:"interval_ms"`
|
||||
@ -124,15 +137,15 @@ type HeartbeatConfig struct {
|
||||
Format string `toml:"format"`
|
||||
}
|
||||
|
||||
// TODO: Future implementation
|
||||
// ClientAuthConfig defines settings for client-side authentication.
|
||||
type ClientAuthConfig struct {
|
||||
Type string `toml:"type"` // "none", "basic", "token", "scram"
|
||||
Username string `toml:"username"`
|
||||
Password string `toml:"password"`
|
||||
Token string `toml:"token"`
|
||||
Type string `toml:"type"` // "none"
|
||||
}
|
||||
|
||||
// --- Source Options ---
|
||||
|
||||
// SourceConfig is a polymorphic struct representing a single data source.
|
||||
type SourceConfig struct {
|
||||
Type string `toml:"type"`
|
||||
|
||||
@ -143,6 +156,7 @@ type SourceConfig struct {
|
||||
TCP *TCPSourceOptions `toml:"tcp,omitempty"`
|
||||
}
|
||||
|
||||
// DirectorySourceOptions defines settings for a directory-based source.
|
||||
type DirectorySourceOptions struct {
|
||||
Path string `toml:"path"`
|
||||
Pattern string `toml:"pattern"` // glob pattern
|
||||
@ -150,10 +164,12 @@ type DirectorySourceOptions struct {
|
||||
Recursive bool `toml:"recursive"` // TODO: implement logic
|
||||
}
|
||||
|
||||
// StdinSourceOptions defines settings for a stdin-based source.
|
||||
type StdinSourceOptions struct {
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
}
|
||||
|
||||
// HTTPSourceOptions defines settings for an HTTP server source.
|
||||
type HTTPSourceOptions struct {
|
||||
Host string `toml:"host"`
|
||||
Port int64 `toml:"port"`
|
||||
@ -163,10 +179,11 @@ type HTTPSourceOptions struct {
|
||||
ReadTimeout int64 `toml:"read_timeout_ms"`
|
||||
WriteTimeout int64 `toml:"write_timeout_ms"`
|
||||
NetLimit *NetLimitConfig `toml:"net_limit"`
|
||||
TLS *TLSConfig `toml:"tls"`
|
||||
TLS *TLSServerConfig `toml:"tls"`
|
||||
Auth *ServerAuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
// TCPSourceOptions defines settings for a TCP server source.
|
||||
type TCPSourceOptions struct {
|
||||
Host string `toml:"host"`
|
||||
Port int64 `toml:"port"`
|
||||
@ -180,6 +197,7 @@ type TCPSourceOptions struct {
|
||||
|
||||
// --- Sink Options ---
|
||||
|
||||
// SinkConfig is a polymorphic struct representing a single data sink.
|
||||
type SinkConfig struct {
|
||||
Type string `toml:"type"`
|
||||
|
||||
@ -192,12 +210,14 @@ type SinkConfig struct {
|
||||
TCPClient *TCPClientSinkOptions `toml:"tcp_client,omitempty"`
|
||||
}
|
||||
|
||||
// ConsoleSinkOptions defines settings for a console-based sink.
|
||||
type ConsoleSinkOptions struct {
|
||||
Target string `toml:"target"` // "stdout", "stderr", "split"
|
||||
Colorize bool `toml:"colorize"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
}
|
||||
|
||||
// FileSinkOptions defines settings for a file-based sink.
|
||||
type FileSinkOptions struct {
|
||||
Directory string `toml:"directory"`
|
||||
Name string `toml:"name"`
|
||||
@ -209,6 +229,7 @@ type FileSinkOptions struct {
|
||||
FlushInterval int64 `toml:"flush_interval_ms"`
|
||||
}
|
||||
|
||||
// HTTPSinkOptions defines settings for an HTTP server sink.
|
||||
type HTTPSinkOptions struct {
|
||||
Host string `toml:"host"`
|
||||
Port int64 `toml:"port"`
|
||||
@ -218,10 +239,11 @@ type HTTPSinkOptions struct {
|
||||
WriteTimeout int64 `toml:"write_timeout_ms"`
|
||||
Heartbeat *HeartbeatConfig `toml:"heartbeat"`
|
||||
NetLimit *NetLimitConfig `toml:"net_limit"`
|
||||
TLS *TLSConfig `toml:"tls"`
|
||||
TLS *TLSServerConfig `toml:"tls"`
|
||||
Auth *ServerAuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
// TCPSinkOptions defines settings for a TCP server sink.
|
||||
type TCPSinkOptions struct {
|
||||
Host string `toml:"host"`
|
||||
Port int64 `toml:"port"`
|
||||
@ -234,6 +256,7 @@ type TCPSinkOptions struct {
|
||||
Auth *ServerAuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
// HTTPClientSinkOptions defines settings for an HTTP client sink.
|
||||
type HTTPClientSinkOptions struct {
|
||||
URL string `toml:"url"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
@ -244,10 +267,11 @@ type HTTPClientSinkOptions struct {
|
||||
RetryDelayMS int64 `toml:"retry_delay_ms"`
|
||||
RetryBackoff float64 `toml:"retry_backoff"`
|
||||
InsecureSkipVerify bool `toml:"insecure_skip_verify"`
|
||||
TLS *TLSConfig `toml:"tls"`
|
||||
TLS *TLSClientConfig `toml:"tls"`
|
||||
Auth *ClientAuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
// TCPClientSinkOptions defines settings for a TCP client sink.
|
||||
type TCPClientSinkOptions struct {
|
||||
Host string `toml:"host"`
|
||||
Port int64 `toml:"port"`
|
||||
@ -264,7 +288,7 @@ type TCPClientSinkOptions struct {
|
||||
|
||||
// --- Rate Limit Options ---
|
||||
|
||||
// Defines the action to take when a rate limit is exceeded.
|
||||
// RateLimitPolicy defines the action to take when a rate limit is exceeded.
|
||||
type RateLimitPolicy int
|
||||
|
||||
const (
|
||||
@ -274,7 +298,7 @@ const (
|
||||
PolicyDrop
|
||||
)
|
||||
|
||||
// Defines the configuration for pipeline-level rate limiting.
|
||||
// RateLimitConfig defines the configuration for pipeline-level rate limiting.
|
||||
type RateLimitConfig struct {
|
||||
// Rate is the number of log entries allowed per second. Default: 0 (disabled).
|
||||
Rate float64 `toml:"rate"`
|
||||
@ -288,23 +312,27 @@ type RateLimitConfig struct {
|
||||
|
||||
// --- Filter Options ---
|
||||
|
||||
// Represents the filter type
|
||||
// FilterType represents the filter's behavior (include or exclude).
|
||||
type FilterType string
|
||||
|
||||
const (
|
||||
// FilterTypeInclude specifies that only matching logs will pass.
|
||||
FilterTypeInclude FilterType = "include" // Whitelist - only matching logs pass
|
||||
// FilterTypeExclude specifies that matching logs will be dropped.
|
||||
FilterTypeExclude FilterType = "exclude" // Blacklist - matching logs are dropped
|
||||
)
|
||||
|
||||
// Represents how multiple patterns are combined
|
||||
// FilterLogic represents how multiple filter patterns are combined.
|
||||
type FilterLogic string
|
||||
|
||||
const (
|
||||
FilterLogicOr FilterLogic = "or" // Match any pattern
|
||||
// FilterLogicOr specifies that a match on any pattern is sufficient.
|
||||
FilterLogicOr FilterLogic = "or" // Match any pattern
|
||||
// FilterLogicAnd specifies that all patterns must match.
|
||||
FilterLogicAnd FilterLogic = "and" // Match all patterns
|
||||
)
|
||||
|
||||
// Represents filter configuration
|
||||
// FilterConfig represents the configuration for a single filter.
|
||||
type FilterConfig struct {
|
||||
Type FilterType `toml:"type"`
|
||||
Logic FilterLogic `toml:"logic"`
|
||||
@ -313,6 +341,7 @@ type FilterConfig struct {
|
||||
|
||||
// --- Formatter Options ---
|
||||
|
||||
// FormatConfig is a polymorphic struct representing log entry formatting options.
|
||||
type FormatConfig struct {
|
||||
// Format configuration - polymorphic like sources/sinks
|
||||
Type string `toml:"type"` // "json", "txt", "raw"
|
||||
@ -323,6 +352,7 @@ type FormatConfig struct {
|
||||
RawFormatOptions *RawFormatterOptions `toml:"raw,omitempty"`
|
||||
}
|
||||
|
||||
// JSONFormatterOptions defines settings for the JSON formatter.
|
||||
type JSONFormatterOptions struct {
|
||||
Pretty bool `toml:"pretty"`
|
||||
TimestampField string `toml:"timestamp_field"`
|
||||
@ -331,49 +361,21 @@ type JSONFormatterOptions struct {
|
||||
SourceField string `toml:"source_field"`
|
||||
}
|
||||
|
||||
// TxtFormatterOptions defines settings for the text template formatter.
|
||||
type TxtFormatterOptions struct {
|
||||
Template string `toml:"template"`
|
||||
TimestampFormat string `toml:"timestamp_format"`
|
||||
}
|
||||
|
||||
// RawFormatterOptions defines settings for the raw pass-through formatter.
|
||||
type RawFormatterOptions struct {
|
||||
AddNewLine bool `toml:"add_new_line"`
|
||||
}
|
||||
|
||||
// --- Server-side Auth (for sources) ---
|
||||
|
||||
type BasicAuthConfig struct {
|
||||
Users []BasicAuthUser `toml:"users"`
|
||||
Realm string `toml:"realm"`
|
||||
}
|
||||
|
||||
type BasicAuthUser struct {
|
||||
Username string `toml:"username"`
|
||||
PasswordHash string `toml:"password_hash"` // Argon2
|
||||
}
|
||||
|
||||
type ScramAuthConfig struct {
|
||||
Users []ScramUser `toml:"users"`
|
||||
}
|
||||
|
||||
type ScramUser struct {
|
||||
Username string `toml:"username"`
|
||||
StoredKey string `toml:"stored_key"` // base64
|
||||
ServerKey string `toml:"server_key"` // base64
|
||||
Salt string `toml:"salt"` // base64
|
||||
ArgonTime uint32 `toml:"argon_time"`
|
||||
ArgonMemory uint32 `toml:"argon_memory"`
|
||||
ArgonThreads uint8 `toml:"argon_threads"`
|
||||
}
|
||||
|
||||
type TokenAuthConfig struct {
|
||||
Tokens []string `toml:"tokens"`
|
||||
}
|
||||
|
||||
// Server auth wrapper (for sources accepting connections)
|
||||
// TODO: future implementation
|
||||
// ServerAuthConfig defines settings for server-side authentication.
|
||||
type ServerAuthConfig struct {
|
||||
Type string `toml:"type"` // "none", "basic", "token", "scram"
|
||||
Basic *BasicAuthConfig `toml:"basic,omitempty"`
|
||||
Token *TokenAuthConfig `toml:"token,omitempty"`
|
||||
Scram *ScramAuthConfig `toml:"scram,omitempty"`
|
||||
Type string `toml:"type"` // "none"
|
||||
}
|
||||
@ -11,13 +11,66 @@ import (
|
||||
lconfig "github.com/lixenwraith/config"
|
||||
)
|
||||
|
||||
// configManager holds the global instance of the configuration manager.
|
||||
var configManager *lconfig.Config
|
||||
|
||||
// Hot reload access
|
||||
// Load is the single entry point for loading all application configuration.
|
||||
func Load(args []string) (*Config, error) {
|
||||
configPath, isExplicit := resolveConfigPath(args)
|
||||
// Build configuration with all sources
|
||||
|
||||
// Create target config instance that will be populated
|
||||
finalConfig := &Config{}
|
||||
|
||||
// Builder handles loading, populating the target struct, and validation
|
||||
cfg, err := lconfig.NewBuilder().
|
||||
WithTarget(finalConfig). // Typed target struct
|
||||
WithDefaults(defaults()). // Default values
|
||||
WithSources(
|
||||
lconfig.SourceCLI,
|
||||
lconfig.SourceEnv,
|
||||
lconfig.SourceFile,
|
||||
lconfig.SourceDefault,
|
||||
).
|
||||
WithEnvTransform(customEnvTransform). // Convert '.' to '_' in env separation
|
||||
WithEnvPrefix("LOGWISP_"). // Environment variable prefix
|
||||
WithArgs(args). // Command-line arguments
|
||||
WithFile(configPath). // TOML config file
|
||||
WithFileFormat("toml"). // Explicit format
|
||||
WithTypedValidator(ValidateConfig). // Centralized validation
|
||||
WithSecurityOptions(lconfig.SecurityOptions{
|
||||
PreventPathTraversal: true,
|
||||
MaxFileSize: 10 * 1024 * 1024, // 10MB max config
|
||||
}).
|
||||
Build()
|
||||
|
||||
if err != nil {
|
||||
// Handle file not found errors - maintain existing behavior
|
||||
if errors.Is(err, lconfig.ErrConfigNotFound) {
|
||||
if isExplicit {
|
||||
return nil, fmt.Errorf("config file not found: %s", configPath)
|
||||
}
|
||||
// If the default config file is not found, it's not an error, default/cli/env will be used
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to load or validate config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Store the config file path for hot reload
|
||||
finalConfig.ConfigFile = configPath
|
||||
|
||||
// Store the manager for hot reload
|
||||
configManager = cfg
|
||||
|
||||
return finalConfig, nil
|
||||
}
|
||||
|
||||
// GetConfigManager returns the global configuration manager instance for hot-reloading.
|
||||
func GetConfigManager() *lconfig.Config {
|
||||
return configManager
|
||||
}
|
||||
|
||||
// defaults provides the default configuration values for the application.
|
||||
func defaults() *Config {
|
||||
return &Config{
|
||||
// Top-level flag defaults
|
||||
@ -76,58 +129,7 @@ func defaults() *Config {
|
||||
}
|
||||
}
|
||||
|
||||
// Single entry point for loading all configuration
|
||||
func Load(args []string) (*Config, error) {
|
||||
configPath, isExplicit := resolveConfigPath(args)
|
||||
// Build configuration with all sources
|
||||
|
||||
// Create target config instance that will be populated
|
||||
finalConfig := &Config{}
|
||||
|
||||
// Builder handles loading, populating the target struct, and validation
|
||||
cfg, err := lconfig.NewBuilder().
|
||||
WithTarget(finalConfig). // Typed target struct
|
||||
WithDefaults(defaults()). // Default values
|
||||
WithSources(
|
||||
lconfig.SourceCLI,
|
||||
lconfig.SourceEnv,
|
||||
lconfig.SourceFile,
|
||||
lconfig.SourceDefault,
|
||||
).
|
||||
WithEnvTransform(customEnvTransform). // Convert '.' to '_' in env separation
|
||||
WithEnvPrefix("LOGWISP_"). // Environment variable prefix
|
||||
WithArgs(args). // Command-line arguments
|
||||
WithFile(configPath). // TOML config file
|
||||
WithFileFormat("toml"). // Explicit format
|
||||
WithTypedValidator(ValidateConfig). // Centralized validation
|
||||
WithSecurityOptions(lconfig.SecurityOptions{
|
||||
PreventPathTraversal: true,
|
||||
MaxFileSize: 10 * 1024 * 1024, // 10MB max config
|
||||
}).
|
||||
Build()
|
||||
|
||||
if err != nil {
|
||||
// Handle file not found errors - maintain existing behavior
|
||||
if errors.Is(err, lconfig.ErrConfigNotFound) {
|
||||
if isExplicit {
|
||||
return nil, fmt.Errorf("config file not found: %s", configPath)
|
||||
}
|
||||
// If the default config file is not found, it's not an error, default/cli/env will be used
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to load or validate config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Store the config file path for hot reload
|
||||
finalConfig.ConfigFile = configPath
|
||||
|
||||
// Store the manager for hot reload
|
||||
configManager = cfg
|
||||
|
||||
return finalConfig, nil
|
||||
}
|
||||
|
||||
// Returns the configuration file path
|
||||
// resolveConfigPath determines the configuration file path based on CLI args, env vars, and default locations.
|
||||
func resolveConfigPath(args []string) (path string, isExplicit bool) {
|
||||
// 1. Check for --config flag in command-line arguments (highest precedence)
|
||||
for i, arg := range args {
|
||||
@ -163,6 +165,7 @@ func resolveConfigPath(args []string) (path string, isExplicit bool) {
|
||||
return "logwisp.toml", false
|
||||
}
|
||||
|
||||
// customEnvTransform converts TOML-style config paths (e.g., logging.level) to environment variable format (LOGGING_LEVEL).
|
||||
func customEnvTransform(path string) string {
|
||||
env := strings.ReplaceAll(path, ".", "_")
|
||||
env = strings.ToUpper(env)
|
||||
|
||||
@ -11,8 +11,7 @@ import (
|
||||
lconfig "github.com/lixenwraith/config"
|
||||
)
|
||||
|
||||
// validateConfig is the centralized validator for the entire configuration
|
||||
// This replaces the old (c *Config) validate() method
|
||||
// ValidateConfig is the centralized validator for the entire configuration structure.
|
||||
func ValidateConfig(cfg *Config) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("config is nil")
|
||||
@ -39,6 +38,7 @@ func ValidateConfig(cfg *Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateLogConfig validates the application's own logging settings.
|
||||
func validateLogConfig(cfg *LogConfig) error {
|
||||
validOutputs := map[string]bool{
|
||||
"file": true, "stdout": true, "stderr": true,
|
||||
@ -74,6 +74,7 @@ func validateLogConfig(cfg *LogConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePipeline validates a single pipeline's configuration.
|
||||
func validatePipeline(index int, p *PipelineConfig, pipelineNames map[string]bool, allPorts map[int64]string) error {
|
||||
// Validate pipeline name
|
||||
if err := lconfig.NonEmpty(p.Name); err != nil {
|
||||
@ -131,7 +132,7 @@ func validatePipeline(index int, p *PipelineConfig, pipelineNames map[string]boo
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateSourceConfig validates typed source configuration
|
||||
// validateSourceConfig validates a polymorphic source configuration.
|
||||
func validateSourceConfig(pipelineName string, index int, s *SourceConfig) error {
|
||||
if err := lconfig.NonEmpty(s.Type); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: missing type", pipelineName, index)
|
||||
@ -186,177 +187,7 @@ func validateSourceConfig(pipelineName string, index int, s *SourceConfig) error
|
||||
}
|
||||
}
|
||||
|
||||
func validateDirectorySource(pipelineName string, index int, opts *DirectorySourceOptions) error {
|
||||
if err := lconfig.NonEmpty(opts.Path); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: directory requires 'path'", pipelineName, index)
|
||||
} else {
|
||||
absPath, err := filepath.Abs(opts.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid path %s: %w", opts.Path, err)
|
||||
}
|
||||
opts.Path = absPath
|
||||
}
|
||||
|
||||
// Check for directory traversal
|
||||
// TODO: traversal check only if optional security settings from cli/env set
|
||||
if strings.Contains(opts.Path, "..") {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: path contains directory traversal", pipelineName, index)
|
||||
}
|
||||
|
||||
// Validate pattern if provided
|
||||
if opts.Pattern != "" {
|
||||
if strings.Count(opts.Pattern, "*") == 0 && strings.Count(opts.Pattern, "?") == 0 {
|
||||
// If no wildcards, ensure valid filename
|
||||
if filepath.Base(opts.Pattern) != opts.Pattern {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: pattern contains path separators", pipelineName, index)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
opts.Pattern = "*"
|
||||
}
|
||||
|
||||
// Validate check interval
|
||||
if opts.CheckIntervalMS < 10 {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: check_interval_ms must be at least 10ms", pipelineName, index)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateStdinSource(pipelineName string, index int, opts *StdinSourceOptions) error {
|
||||
if opts.BufferSize < 0 {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: buffer_size must be positive", pipelineName, index)
|
||||
} else if opts.BufferSize == 0 {
|
||||
opts.BufferSize = 1000
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateHTTPSource(pipelineName string, index int, opts *HTTPSourceOptions) error {
|
||||
// Validate port
|
||||
if err := lconfig.Port(opts.Port); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if opts.Host == "" {
|
||||
opts.Host = "0.0.0.0"
|
||||
}
|
||||
if opts.IngestPath == "" {
|
||||
opts.IngestPath = "/ingest"
|
||||
}
|
||||
if opts.MaxRequestBodySize <= 0 {
|
||||
opts.MaxRequestBodySize = 10 * 1024 * 1024 // 10MB default
|
||||
}
|
||||
if opts.ReadTimeout <= 0 {
|
||||
opts.ReadTimeout = 5000 // 5 seconds
|
||||
}
|
||||
if opts.WriteTimeout <= 0 {
|
||||
opts.WriteTimeout = 5000 // 5 seconds
|
||||
}
|
||||
|
||||
// Validate host if specified
|
||||
if opts.Host != "" && opts.Host != "0.0.0.0" {
|
||||
if err := lconfig.IPAddress(opts.Host); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate paths
|
||||
if !strings.HasPrefix(opts.IngestPath, "/") {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: ingest_path must start with /", pipelineName, index)
|
||||
}
|
||||
|
||||
// Validate auth configuration
|
||||
validHTTPSourceAuthTypes := map[string]bool{"basic": true, "token": true, "mtls": true}
|
||||
if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" {
|
||||
if !validHTTPSourceAuthTypes[opts.Auth.Type] {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %s is not a valid auth type",
|
||||
pipelineName, index, opts.Auth.Type)
|
||||
}
|
||||
// All non-none auth types require TLS for HTTP
|
||||
if opts.TLS == nil || !opts.TLS.Enabled {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %s auth requires TLS to be enabled",
|
||||
pipelineName, index, opts.Auth.Type)
|
||||
}
|
||||
|
||||
// Validate specific auth types
|
||||
if err := validateServerAuth(pipelineName, opts.Auth); err != nil {
|
||||
return fmt.Errorf("source[%d]: %w", index, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate nested configs
|
||||
if opts.NetLimit != nil {
|
||||
if err := validateNetLimit(pipelineName, fmt.Sprintf("source[%d]", index), opts.NetLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if opts.TLS != nil {
|
||||
if err := validateTLS(pipelineName, fmt.Sprintf("source[%d]", index), opts.TLS); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateTCPSource(pipelineName string, index int, opts *TCPSourceOptions) error {
|
||||
// Validate port
|
||||
if err := lconfig.Port(opts.Port); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if opts.Host == "" {
|
||||
opts.Host = "0.0.0.0"
|
||||
}
|
||||
if opts.ReadTimeout <= 0 {
|
||||
opts.ReadTimeout = 5000 // 5 seconds
|
||||
}
|
||||
if !opts.KeepAlive {
|
||||
opts.KeepAlive = true // Default enabled
|
||||
}
|
||||
if opts.KeepAlivePeriod <= 0 {
|
||||
opts.KeepAlivePeriod = 30000 // 30 seconds
|
||||
}
|
||||
|
||||
// Validate host if specified
|
||||
if opts.Host != "" && opts.Host != "0.0.0.0" {
|
||||
if err := lconfig.IPAddress(opts.Host); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TCP source does NOT support TLS
|
||||
// Validate auth configuration - only none and scram are allowed
|
||||
if opts.Auth != nil {
|
||||
switch opts.Auth.Type {
|
||||
case "", "none":
|
||||
// OK
|
||||
case "scram":
|
||||
// SCRAM doesn't require TLS
|
||||
if err := validateServerAuth(pipelineName, opts.Auth); err != nil {
|
||||
return fmt.Errorf("source[%d]: %w", index, err)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: TCP source only supports 'none' or 'scram' auth (got '%s')",
|
||||
pipelineName, index, opts.Auth.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate NetLimit if present
|
||||
if opts.NetLimit != nil {
|
||||
if err := validateNetLimit(pipelineName, fmt.Sprintf("source[%d]", index), opts.NetLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateSinkConfig validates typed sink configuration
|
||||
// validateSinkConfig validates a polymorphic sink configuration.
|
||||
func validateSinkConfig(pipelineName string, index int, s *SinkConfig, allPorts map[int64]string) error {
|
||||
if err := lconfig.NonEmpty(s.Type); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: missing type", pipelineName, index)
|
||||
@ -423,6 +254,268 @@ func validateSinkConfig(pipelineName string, index int, s *SinkConfig, allPorts
|
||||
}
|
||||
}
|
||||
|
||||
// validateFormatterConfig validates formatter configuration
|
||||
func validateFormatterConfig(p *PipelineConfig) error {
|
||||
if p.Format == nil {
|
||||
p.Format = &FormatConfig{
|
||||
Type: "raw",
|
||||
}
|
||||
} else if p.Format.Type == "" {
|
||||
p.Format.Type = "raw" // Default
|
||||
}
|
||||
|
||||
switch p.Format.Type {
|
||||
|
||||
case "raw":
|
||||
if p.Format.RawFormatOptions == nil {
|
||||
p.Format.RawFormatOptions = &RawFormatterOptions{}
|
||||
}
|
||||
|
||||
case "txt":
|
||||
if p.Format.TxtFormatOptions == nil {
|
||||
p.Format.TxtFormatOptions = &TxtFormatterOptions{}
|
||||
}
|
||||
|
||||
// Default template format
|
||||
templateStr := "[{{.Timestamp | FmtTime}}] [{{.Level | ToUpper}}] {{.Source}} - {{.Message}}{{ if .Fields }} {{.Fields}}{{ end }}"
|
||||
if p.Format.TxtFormatOptions.Template != "" {
|
||||
p.Format.TxtFormatOptions.Template = templateStr
|
||||
}
|
||||
|
||||
// Default timestamp format
|
||||
timestampFormat := time.RFC3339
|
||||
if p.Format.TxtFormatOptions.TimestampFormat != "" {
|
||||
p.Format.TxtFormatOptions.TimestampFormat = timestampFormat
|
||||
}
|
||||
|
||||
case "json":
|
||||
if p.Format.JSONFormatOptions == nil {
|
||||
p.Format.JSONFormatOptions = &JSONFormatterOptions{}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRateLimit validates the pipeline-level rate limit settings.
|
||||
func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.Rate < 0 {
|
||||
return fmt.Errorf("pipeline '%s': rate limit rate cannot be negative", pipelineName)
|
||||
}
|
||||
|
||||
if cfg.Burst < 0 {
|
||||
return fmt.Errorf("pipeline '%s': rate limit burst cannot be negative", pipelineName)
|
||||
}
|
||||
|
||||
if cfg.MaxEntrySizeBytes < 0 {
|
||||
return fmt.Errorf("pipeline '%s': max entry size bytes cannot be negative", pipelineName)
|
||||
}
|
||||
|
||||
// Validate policy
|
||||
switch strings.ToLower(cfg.Policy) {
|
||||
case "", "pass", "drop":
|
||||
// Valid policies
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s': invalid rate limit policy '%s' (must be 'pass' or 'drop')",
|
||||
pipelineName, cfg.Policy)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateFilter validates a single filter's configuration.
|
||||
func validateFilter(pipelineName string, filterIndex int, cfg *FilterConfig) error {
|
||||
// Validate filter type
|
||||
switch cfg.Type {
|
||||
case FilterTypeInclude, FilterTypeExclude, "":
|
||||
// Valid types
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s' filter[%d]: invalid type '%s' (must be 'include' or 'exclude')",
|
||||
pipelineName, filterIndex, cfg.Type)
|
||||
}
|
||||
|
||||
// Validate filter logic
|
||||
switch cfg.Logic {
|
||||
case FilterLogicOr, FilterLogicAnd, "":
|
||||
// Valid logic
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s' filter[%d]: invalid logic '%s' (must be 'or' or 'and')",
|
||||
pipelineName, filterIndex, cfg.Logic)
|
||||
}
|
||||
|
||||
// Empty patterns is valid - passes everything
|
||||
if len(cfg.Patterns) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate regex patterns
|
||||
for i, pattern := range cfg.Patterns {
|
||||
if _, err := regexp.Compile(pattern); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' filter[%d] pattern[%d] '%s': invalid regex: %w",
|
||||
pipelineName, filterIndex, i, pattern, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateDirectorySource validates the settings for a directory source.
|
||||
func validateDirectorySource(pipelineName string, index int, opts *DirectorySourceOptions) error {
|
||||
if err := lconfig.NonEmpty(opts.Path); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: directory requires 'path'", pipelineName, index)
|
||||
} else {
|
||||
absPath, err := filepath.Abs(opts.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid path %s: %w", opts.Path, err)
|
||||
}
|
||||
opts.Path = absPath
|
||||
}
|
||||
|
||||
// Check for directory traversal
|
||||
// TODO: traversal check only if optional security settings from cli/env set
|
||||
if strings.Contains(opts.Path, "..") {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: path contains directory traversal", pipelineName, index)
|
||||
}
|
||||
|
||||
// Validate pattern if provided
|
||||
if opts.Pattern != "" {
|
||||
if strings.Count(opts.Pattern, "*") == 0 && strings.Count(opts.Pattern, "?") == 0 {
|
||||
// If no wildcards, ensure valid filename
|
||||
if filepath.Base(opts.Pattern) != opts.Pattern {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: pattern contains path separators", pipelineName, index)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
opts.Pattern = "*"
|
||||
}
|
||||
|
||||
// Validate check interval
|
||||
if opts.CheckIntervalMS < 10 {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: check_interval_ms must be at least 10ms", pipelineName, index)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateStdinSource validates the settings for a stdin source.
|
||||
func validateStdinSource(pipelineName string, index int, opts *StdinSourceOptions) error {
|
||||
if opts.BufferSize < 0 {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: buffer_size must be positive", pipelineName, index)
|
||||
} else if opts.BufferSize == 0 {
|
||||
opts.BufferSize = 1000
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateHTTPSource validates the settings for an HTTP source.
|
||||
func validateHTTPSource(pipelineName string, index int, opts *HTTPSourceOptions) error {
|
||||
// Validate port
|
||||
if err := lconfig.Port(opts.Port); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if opts.Host == "" {
|
||||
opts.Host = "0.0.0.0"
|
||||
}
|
||||
if opts.IngestPath == "" {
|
||||
opts.IngestPath = "/ingest"
|
||||
}
|
||||
if opts.MaxRequestBodySize <= 0 {
|
||||
opts.MaxRequestBodySize = 10 * 1024 * 1024 // 10MB default
|
||||
}
|
||||
if opts.ReadTimeout <= 0 {
|
||||
opts.ReadTimeout = 5000 // 5 seconds
|
||||
}
|
||||
if opts.WriteTimeout <= 0 {
|
||||
opts.WriteTimeout = 5000 // 5 seconds
|
||||
}
|
||||
|
||||
// Validate host if specified
|
||||
if opts.Host != "" && opts.Host != "0.0.0.0" {
|
||||
if err := lconfig.IPAddress(opts.Host); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate paths
|
||||
if !strings.HasPrefix(opts.IngestPath, "/") {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: ingest_path must start with /", pipelineName, index)
|
||||
}
|
||||
|
||||
// Validate auth configuration
|
||||
validHTTPSourceAuthTypes := map[string]bool{"basic": true, "token": true, "mtls": true}
|
||||
if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" {
|
||||
if !validHTTPSourceAuthTypes[opts.Auth.Type] {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %s is not a valid auth type",
|
||||
pipelineName, index, opts.Auth.Type)
|
||||
}
|
||||
// All non-none auth types require TLS for HTTP
|
||||
if opts.TLS == nil || !opts.TLS.Enabled {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %s auth requires TLS to be enabled",
|
||||
pipelineName, index, opts.Auth.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate nested configs
|
||||
if opts.NetLimit != nil {
|
||||
if err := validateNetLimit(pipelineName, fmt.Sprintf("source[%d]", index), opts.NetLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if opts.TLS != nil {
|
||||
if err := validateTLSServer(pipelineName, fmt.Sprintf("source[%d]", index), opts.TLS); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTCPSource validates the settings for a TCP source.
|
||||
func validateTCPSource(pipelineName string, index int, opts *TCPSourceOptions) error {
|
||||
// Validate port
|
||||
if err := lconfig.Port(opts.Port); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if opts.Host == "" {
|
||||
opts.Host = "0.0.0.0"
|
||||
}
|
||||
if opts.ReadTimeout <= 0 {
|
||||
opts.ReadTimeout = 5000 // 5 seconds
|
||||
}
|
||||
if !opts.KeepAlive {
|
||||
opts.KeepAlive = true // Default enabled
|
||||
}
|
||||
if opts.KeepAlivePeriod <= 0 {
|
||||
opts.KeepAlivePeriod = 30000 // 30 seconds
|
||||
}
|
||||
|
||||
// Validate host if specified
|
||||
if opts.Host != "" && opts.Host != "0.0.0.0" {
|
||||
if err := lconfig.IPAddress(opts.Host); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: %w", pipelineName, index, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate NetLimit if present
|
||||
if opts.NetLimit != nil {
|
||||
if err := validateNetLimit(pipelineName, fmt.Sprintf("source[%d]", index), opts.NetLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateConsoleSink validates the settings for a console sink.
|
||||
func validateConsoleSink(pipelineName string, index int, opts *ConsoleSinkOptions) error {
|
||||
if opts.BufferSize < 1 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: buffer_size must be positive", pipelineName, index)
|
||||
@ -430,6 +523,7 @@ func validateConsoleSink(pipelineName string, index int, opts *ConsoleSinkOption
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateFileSink validates the settings for a file sink.
|
||||
func validateFileSink(pipelineName string, index int, opts *FileSinkOptions) error {
|
||||
if err := lconfig.NonEmpty(opts.Directory); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: file requires 'directory'", pipelineName, index)
|
||||
@ -463,6 +557,7 @@ func validateFileSink(pipelineName string, index int, opts *FileSinkOptions) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateHTTPSink validates the settings for an HTTP sink.
|
||||
func validateHTTPSink(pipelineName string, index int, opts *HTTPSinkOptions, allPorts map[int64]string) error {
|
||||
// Validate port
|
||||
if err := lconfig.Port(opts.Port); err != nil {
|
||||
@ -511,7 +606,7 @@ func validateHTTPSink(pipelineName string, index int, opts *HTTPSinkOptions, all
|
||||
}
|
||||
|
||||
if opts.TLS != nil {
|
||||
if err := validateTLS(pipelineName, fmt.Sprintf("sink[%d]", index), opts.TLS); err != nil {
|
||||
if err := validateTLSServer(pipelineName, fmt.Sprintf("sink[%d]", index), opts.TLS); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -519,6 +614,7 @@ func validateHTTPSink(pipelineName string, index int, opts *HTTPSinkOptions, all
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTCPSink validates the settings for a TCP sink.
|
||||
func validateTCPSink(pipelineName string, index int, opts *TCPSinkOptions, allPorts map[int64]string) error {
|
||||
// Validate port
|
||||
if err := lconfig.Port(opts.Port); err != nil {
|
||||
@ -560,6 +656,7 @@ func validateTCPSink(pipelineName string, index int, opts *TCPSinkOptions, allPo
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateHTTPClientSink validates the settings for an HTTP client sink.
|
||||
func validateHTTPClientSink(pipelineName string, index int, opts *HTTPClientSinkOptions) error {
|
||||
// Validate URL
|
||||
if err := lconfig.NonEmpty(opts.URL); err != nil {
|
||||
@ -575,8 +672,6 @@ func validateHTTPClientSink(pipelineName string, index int, opts *HTTPClientSink
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: URL must use http or https scheme", pipelineName, index)
|
||||
}
|
||||
|
||||
isHTTPS := parsedURL.Scheme == "https"
|
||||
|
||||
// Set defaults for unspecified fields
|
||||
if opts.BufferSize <= 0 {
|
||||
opts.BufferSize = 1000
|
||||
@ -600,57 +695,9 @@ func validateHTTPClientSink(pipelineName string, index int, opts *HTTPClientSink
|
||||
opts.RetryBackoff = 2.0
|
||||
}
|
||||
|
||||
// Validate auth configuration
|
||||
if opts.Auth != nil {
|
||||
switch opts.Auth.Type {
|
||||
case "basic":
|
||||
if opts.Auth.Username == "" || opts.Auth.Password == "" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: username and password required for basic auth",
|
||||
pipelineName, index)
|
||||
}
|
||||
if !isHTTPS && !opts.InsecureSkipVerify {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: basic auth requires HTTPS (security: credentials would be sent in plaintext)",
|
||||
pipelineName, index)
|
||||
}
|
||||
|
||||
case "token":
|
||||
if opts.Auth.Token == "" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: token required for %s auth",
|
||||
pipelineName, index, opts.Auth.Type)
|
||||
}
|
||||
if !isHTTPS && !opts.InsecureSkipVerify {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: %s auth requires HTTPS (security: token would be sent in plaintext)",
|
||||
pipelineName, index, opts.Auth.Type)
|
||||
}
|
||||
|
||||
case "mtls":
|
||||
if !isHTTPS {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: mTLS requires HTTPS",
|
||||
pipelineName, index)
|
||||
}
|
||||
// mTLS certs should be in TLS config, not auth config
|
||||
if opts.TLS == nil || opts.TLS.CertFile == "" || opts.TLS.KeyFile == "" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: cert_file and key_file required in TLS config for mTLS auth",
|
||||
pipelineName, index)
|
||||
}
|
||||
|
||||
case "none", "":
|
||||
// Clear any credentials if auth is "none" or empty
|
||||
if opts.Auth != nil {
|
||||
opts.Auth.Username = ""
|
||||
opts.Auth.Password = ""
|
||||
opts.Auth.Token = ""
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid auth type '%s' (valid: none, basic, token, mtls)",
|
||||
pipelineName, index, opts.Auth.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate TLS config if present
|
||||
if opts.TLS != nil {
|
||||
if err := validateTLS(pipelineName, fmt.Sprintf("sink[%d]", index), opts.TLS); err != nil {
|
||||
if err := validateTLSClient(pipelineName, fmt.Sprintf("sink[%d]", index), opts.TLS); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -658,6 +705,7 @@ func validateHTTPClientSink(pipelineName string, index int, opts *HTTPClientSink
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTCPClientSink validates the settings for a TCP client sink.
|
||||
func validateTCPClientSink(pipelineName string, index int, opts *TCPClientSinkOptions) error {
|
||||
// Validate host and port
|
||||
if err := lconfig.NonEmpty(opts.Host); err != nil {
|
||||
@ -694,78 +742,10 @@ func validateTCPClientSink(pipelineName string, index int, opts *TCPClientSinkOp
|
||||
opts.ReconnectBackoff = 1.5
|
||||
}
|
||||
|
||||
// Validate auth configuration
|
||||
if opts.Auth != nil {
|
||||
switch opts.Auth.Type {
|
||||
|
||||
case "scram":
|
||||
if opts.Auth.Username == "" || opts.Auth.Password == "" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: username and password required for SCRAM auth",
|
||||
pipelineName, index)
|
||||
}
|
||||
// SCRAM doesn't require TLS as it uses challenge-response
|
||||
|
||||
case "none", "":
|
||||
// Clear credentials
|
||||
if opts.Auth != nil {
|
||||
opts.Auth.Username = ""
|
||||
opts.Auth.Password = ""
|
||||
opts.Auth.Token = ""
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid auth type '%s' (valid: none, basic, token, scram, mtls)",
|
||||
pipelineName, index, opts.Auth.Type)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateFormatterConfig validates formatter configuration
|
||||
func validateFormatterConfig(p *PipelineConfig) error {
|
||||
if p.Format == nil {
|
||||
p.Format = &FormatConfig{
|
||||
Type: "raw",
|
||||
}
|
||||
} else if p.Format.Type == "" {
|
||||
p.Format.Type = "raw" // Default
|
||||
}
|
||||
|
||||
switch p.Format.Type {
|
||||
|
||||
case "raw":
|
||||
if p.Format.RawFormatOptions == nil {
|
||||
p.Format.RawFormatOptions = &RawFormatterOptions{}
|
||||
}
|
||||
|
||||
case "txt":
|
||||
if p.Format.TxtFormatOptions == nil {
|
||||
p.Format.TxtFormatOptions = &TxtFormatterOptions{}
|
||||
}
|
||||
|
||||
// Default template format
|
||||
templateStr := "[{{.Timestamp | FmtTime}}] [{{.Level | ToUpper}}] {{.Source}} - {{.Message}}{{ if .Fields }} {{.Fields}}{{ end }}"
|
||||
if p.Format.TxtFormatOptions.Template != "" {
|
||||
p.Format.TxtFormatOptions.Template = templateStr
|
||||
}
|
||||
|
||||
// Default timestamp format
|
||||
timestampFormat := time.RFC3339
|
||||
if p.Format.TxtFormatOptions.TimestampFormat != "" {
|
||||
p.Format.TxtFormatOptions.TimestampFormat = timestampFormat
|
||||
}
|
||||
|
||||
case "json":
|
||||
if p.Format.JSONFormatOptions == nil {
|
||||
p.Format.JSONFormatOptions = &JSONFormatterOptions{}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper validation functions for nested configs
|
||||
// validateNetLimit validates nested NetLimitConfig settings.
|
||||
func validateNetLimit(pipelineName, location string, nl *NetLimitConfig) error {
|
||||
if !nl.Enabled {
|
||||
return nil // Skip validation if disabled
|
||||
@ -782,21 +762,45 @@ func validateNetLimit(pipelineName, location string, nl *NetLimitConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateTLS(pipelineName, location string, tls *TLSConfig) error {
|
||||
// validateTLSServer validates the new TLSServerConfig struct.
|
||||
func validateTLSServer(pipelineName, location string, tls *TLSServerConfig) error {
|
||||
if !tls.Enabled {
|
||||
return nil // Skip validation if disabled
|
||||
}
|
||||
|
||||
// If TLS enabled, cert and key files required (unless skip verify)
|
||||
if !tls.SkipVerify {
|
||||
if tls.CertFile == "" || tls.KeyFile == "" {
|
||||
return fmt.Errorf("pipeline '%s' %s: TLS enabled requires cert_file and key_file", pipelineName, location)
|
||||
}
|
||||
// If TLS is enabled for a server, cert and key files are mandatory.
|
||||
if tls.CertFile == "" || tls.KeyFile == "" {
|
||||
return fmt.Errorf("pipeline '%s' %s: TLS enabled requires both cert_file and key_file", pipelineName, location)
|
||||
}
|
||||
|
||||
// If mTLS (ClientAuth) is enabled, a client CA file is mandatory.
|
||||
if tls.ClientAuth && tls.ClientCAFile == "" {
|
||||
return fmt.Errorf("pipeline '%s' %s: client_auth is enabled, which requires a client_ca_file", pipelineName, location)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTLSClient validates the new TLSClientConfig struct.
|
||||
func validateTLSClient(pipelineName, location string, tls *TLSClientConfig) error {
|
||||
if !tls.Enabled {
|
||||
return nil // Skip validation if disabled
|
||||
}
|
||||
|
||||
// If verification is not skipped, a server CA file must be provided.
|
||||
if !tls.InsecureSkipVerify && tls.ServerCAFile == "" {
|
||||
return fmt.Errorf("pipeline '%s' %s: TLS verification is enabled (insecure_skip_verify=false) but server_ca_file is not provided", pipelineName, location)
|
||||
}
|
||||
|
||||
// For client mTLS, both the cert and key must be provided together.
|
||||
if (tls.ClientCertFile != "" && tls.ClientKeyFile == "") || (tls.ClientCertFile == "" && tls.ClientKeyFile != "") {
|
||||
return fmt.Errorf("pipeline '%s' %s: for client mTLS, both client_cert_file and client_key_file must be provided", pipelineName, location)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateHeartbeat validates nested HeartbeatConfig settings.
|
||||
func validateHeartbeat(pipelineName, location string, hb *HeartbeatConfig) error {
|
||||
if !hb.Enabled {
|
||||
return nil // Skip validation if disabled
|
||||
@ -808,138 +812,3 @@ func validateHeartbeat(pipelineName, location string, hb *HeartbeatConfig) error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateServerAuth(pipelineName string, auth *ServerAuthConfig) error {
|
||||
if auth.Type == "" || auth.Type == "none" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count populated auth configs
|
||||
populated := 0
|
||||
var populatedType string
|
||||
|
||||
if auth.Basic != nil {
|
||||
populated++
|
||||
populatedType = "basic"
|
||||
}
|
||||
if auth.Token != nil {
|
||||
populated++
|
||||
populatedType = "token"
|
||||
}
|
||||
if auth.Scram != nil {
|
||||
populated++
|
||||
populatedType = "scram"
|
||||
}
|
||||
|
||||
if populated == 0 {
|
||||
return fmt.Errorf("pipeline '%s': auth type '%s' specified but config missing", pipelineName, auth.Type)
|
||||
}
|
||||
if populated > 1 {
|
||||
return fmt.Errorf("pipeline '%s': multiple auth configurations provided", pipelineName)
|
||||
}
|
||||
if populatedType != auth.Type {
|
||||
return fmt.Errorf("pipeline '%s': auth type mismatch - type is '%s' but config is for '%s'",
|
||||
pipelineName, auth.Type, populatedType)
|
||||
}
|
||||
|
||||
// Validate specific auth type
|
||||
switch auth.Type {
|
||||
case "basic":
|
||||
if len(auth.Basic.Users) == 0 {
|
||||
return fmt.Errorf("pipeline '%s': basic auth requires at least one user", pipelineName)
|
||||
}
|
||||
for i, user := range auth.Basic.Users {
|
||||
if err := lconfig.NonEmpty(user.Username); err != nil {
|
||||
return fmt.Errorf("pipeline '%s': basic auth user[%d] missing username", pipelineName, i)
|
||||
}
|
||||
if err := lconfig.NonEmpty(user.PasswordHash); err != nil {
|
||||
return fmt.Errorf("pipeline '%s': basic auth user[%d] missing password_hash", pipelineName, i)
|
||||
}
|
||||
}
|
||||
case "token":
|
||||
if len(auth.Token.Tokens) == 0 {
|
||||
return fmt.Errorf("pipeline '%s': token auth requires at least one token", pipelineName)
|
||||
}
|
||||
case "scram":
|
||||
if len(auth.Scram.Users) == 0 {
|
||||
return fmt.Errorf("pipeline '%s': scram auth requires at least one user", pipelineName)
|
||||
}
|
||||
for i, user := range auth.Scram.Users {
|
||||
if err := lconfig.NonEmpty(user.Username); err != nil {
|
||||
return fmt.Errorf("pipeline '%s': scram auth user[%d] missing username", pipelineName, i)
|
||||
}
|
||||
// Validate required SCRAM fields
|
||||
if user.StoredKey == "" || user.ServerKey == "" || user.Salt == "" {
|
||||
return fmt.Errorf("pipeline '%s': scram auth user[%d] missing required fields", pipelineName, i)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s': unknown auth type '%s'", pipelineName, auth.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.Rate < 0 {
|
||||
return fmt.Errorf("pipeline '%s': rate limit rate cannot be negative", pipelineName)
|
||||
}
|
||||
|
||||
if cfg.Burst < 0 {
|
||||
return fmt.Errorf("pipeline '%s': rate limit burst cannot be negative", pipelineName)
|
||||
}
|
||||
|
||||
if cfg.MaxEntrySizeBytes < 0 {
|
||||
return fmt.Errorf("pipeline '%s': max entry size bytes cannot be negative", pipelineName)
|
||||
}
|
||||
|
||||
// Validate policy
|
||||
switch strings.ToLower(cfg.Policy) {
|
||||
case "", "pass", "drop":
|
||||
// Valid policies
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s': invalid rate limit policy '%s' (must be 'pass' or 'drop')",
|
||||
pipelineName, cfg.Policy)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateFilter(pipelineName string, filterIndex int, cfg *FilterConfig) error {
|
||||
// Validate filter type
|
||||
switch cfg.Type {
|
||||
case FilterTypeInclude, FilterTypeExclude, "":
|
||||
// Valid types
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s' filter[%d]: invalid type '%s' (must be 'include' or 'exclude')",
|
||||
pipelineName, filterIndex, cfg.Type)
|
||||
}
|
||||
|
||||
// Validate filter logic
|
||||
switch cfg.Logic {
|
||||
case FilterLogicOr, FilterLogicAnd, "":
|
||||
// Valid logic
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s' filter[%d]: invalid logic '%s' (must be 'or' or 'and')",
|
||||
pipelineName, filterIndex, cfg.Logic)
|
||||
}
|
||||
|
||||
// Empty patterns is valid - passes everything
|
||||
if len(cfg.Patterns) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate regex patterns
|
||||
for i, pattern := range cfg.Patterns {
|
||||
if _, err := regexp.Compile(pattern); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' filter[%d] pattern[%d] '%s': invalid regex: %w",
|
||||
pipelineName, filterIndex, i, pattern, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,13 +0,0 @@
|
||||
// FILE: logwisp/src/internal/core/const.go
|
||||
package core
|
||||
|
||||
// Argon2id parameters
|
||||
const (
|
||||
Argon2Time = 3
|
||||
Argon2Memory = 64 * 1024 // 64 MB
|
||||
Argon2Threads = 4
|
||||
Argon2SaltLen = 16
|
||||
Argon2KeyLen = 32
|
||||
)
|
||||
|
||||
const DefaultTokenLength = 32
|
||||
@ -1,4 +1,4 @@
|
||||
// FILE: logwisp/src/internal/core/types.go
|
||||
// FILE: logwisp/src/internal/core/data.go
|
||||
package core
|
||||
|
||||
import (
|
||||
@ -6,6 +6,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const MaxSessionTime = time.Minute * 30
|
||||
|
||||
// Represents a single log record flowing through the pipeline
|
||||
type LogEntry struct {
|
||||
Time time.Time `json:"time"`
|
||||
@ -11,7 +11,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Manages multiple filters in sequence
|
||||
// Chain manages a sequence of filters, applying them in order.
|
||||
type Chain struct {
|
||||
filters []*Filter
|
||||
logger *log.Logger
|
||||
@ -21,7 +21,7 @@ type Chain struct {
|
||||
totalPassed atomic.Uint64
|
||||
}
|
||||
|
||||
// Creates a new filter chain from configurations
|
||||
// NewChain creates a new filter chain from a slice of filter configurations.
|
||||
func NewChain(configs []config.FilterConfig, logger *log.Logger) (*Chain, error) {
|
||||
chain := &Chain{
|
||||
filters: make([]*Filter, 0, len(configs)),
|
||||
@ -42,7 +42,7 @@ func NewChain(configs []config.FilterConfig, logger *log.Logger) (*Chain, error)
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
// Runs all filters in sequence, returns true if the entry passes all filters
|
||||
// Apply runs a log entry through all filters in the chain.
|
||||
func (c *Chain) Apply(entry core.LogEntry) bool {
|
||||
c.totalProcessed.Add(1)
|
||||
|
||||
@ -67,7 +67,7 @@ func (c *Chain) Apply(entry core.LogEntry) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Returns chain statistics
|
||||
// GetStats returns aggregated statistics for the entire chain.
|
||||
func (c *Chain) GetStats() map[string]any {
|
||||
filterStats := make([]map[string]any, len(c.filters))
|
||||
for i, filter := range c.filters {
|
||||
|
||||
@ -13,7 +13,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Applies regex-based filtering to log entries
|
||||
// Filter applies regex-based filtering to log entries.
|
||||
type Filter struct {
|
||||
config config.FilterConfig
|
||||
patterns []*regexp.Regexp
|
||||
@ -26,7 +26,7 @@ type Filter struct {
|
||||
totalDropped atomic.Uint64
|
||||
}
|
||||
|
||||
// Creates a new filter from configuration
|
||||
// NewFilter creates a new filter from a configuration.
|
||||
func NewFilter(cfg config.FilterConfig, logger *log.Logger) (*Filter, error) {
|
||||
// Set defaults
|
||||
if cfg.Type == "" {
|
||||
@ -60,7 +60,7 @@ func NewFilter(cfg config.FilterConfig, logger *log.Logger) (*Filter, error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Checks if a log entry should be passed through
|
||||
// Apply determines if a log entry should be passed through based on the filter's rules.
|
||||
func (f *Filter) Apply(entry core.LogEntry) bool {
|
||||
f.totalProcessed.Add(1)
|
||||
|
||||
@ -130,7 +130,44 @@ func (f *Filter) Apply(entry core.LogEntry) bool {
|
||||
return shouldPass
|
||||
}
|
||||
|
||||
// Checks if text matches the patterns according to the logic
|
||||
// GetStats returns the filter's current statistics.
|
||||
func (f *Filter) GetStats() map[string]any {
|
||||
return map[string]any{
|
||||
"type": f.config.Type,
|
||||
"logic": f.config.Logic,
|
||||
"pattern_count": len(f.patterns),
|
||||
"total_processed": f.totalProcessed.Load(),
|
||||
"total_matched": f.totalMatched.Load(),
|
||||
"total_dropped": f.totalDropped.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
// UpdatePatterns allows for dynamic, thread-safe updates to the filter's regex patterns.
|
||||
func (f *Filter) UpdatePatterns(patterns []string) error {
|
||||
compiled := make([]*regexp.Regexp, 0, len(patterns))
|
||||
|
||||
// Compile all patterns first
|
||||
for i, pattern := range patterns {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid regex pattern[%d] '%s': %w", i, pattern, err)
|
||||
}
|
||||
compiled = append(compiled, re)
|
||||
}
|
||||
|
||||
// Update atomically
|
||||
f.mu.Lock()
|
||||
f.patterns = compiled
|
||||
f.config.Patterns = patterns
|
||||
f.mu.Unlock()
|
||||
|
||||
f.logger.Info("msg", "Filter patterns updated",
|
||||
"component", "filter",
|
||||
"pattern_count", len(patterns))
|
||||
return nil
|
||||
}
|
||||
|
||||
// matches checks if the given text matches the filter's patterns according to its logic.
|
||||
func (f *Filter) matches(text string) bool {
|
||||
switch f.config.Logic {
|
||||
case config.FilterLogicOr:
|
||||
@ -159,40 +196,3 @@ func (f *Filter) matches(text string) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Returns filter statistics
|
||||
func (f *Filter) GetStats() map[string]any {
|
||||
return map[string]any{
|
||||
"type": f.config.Type,
|
||||
"logic": f.config.Logic,
|
||||
"pattern_count": len(f.patterns),
|
||||
"total_processed": f.totalProcessed.Load(),
|
||||
"total_matched": f.totalMatched.Load(),
|
||||
"total_dropped": f.totalDropped.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allows dynamic pattern updates
|
||||
func (f *Filter) UpdatePatterns(patterns []string) error {
|
||||
compiled := make([]*regexp.Regexp, 0, len(patterns))
|
||||
|
||||
// Compile all patterns first
|
||||
for i, pattern := range patterns {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid regex pattern[%d] '%s': %w", i, pattern, err)
|
||||
}
|
||||
compiled = append(compiled, re)
|
||||
}
|
||||
|
||||
// Update atomically
|
||||
f.mu.Lock()
|
||||
f.patterns = compiled
|
||||
f.config.Patterns = patterns
|
||||
f.mu.Unlock()
|
||||
|
||||
f.logger.Info("msg", "Filter patterns updated",
|
||||
"component", "filter",
|
||||
"pattern_count", len(patterns))
|
||||
return nil
|
||||
}
|
||||
@ -10,16 +10,16 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Defines the interface for transforming a LogEntry into a byte slice.
|
||||
// Formatter defines the interface for transforming a LogEntry into a byte slice.
|
||||
type Formatter interface {
|
||||
// Format takes a LogEntry and returns the formatted log as a byte slice.
|
||||
Format(entry core.LogEntry) ([]byte, error)
|
||||
|
||||
// Name returns the formatter type name
|
||||
// Name returns the formatter's type name (e.g., "json", "raw").
|
||||
Name() string
|
||||
}
|
||||
|
||||
// Creates a new Formatter based on the provided configuration.
|
||||
// NewFormatter is a factory function that creates a Formatter based on the provided configuration.
|
||||
func NewFormatter(cfg *config.FormatConfig, logger *log.Logger) (Formatter, error) {
|
||||
switch cfg.Type {
|
||||
case "json":
|
||||
|
||||
@ -12,13 +12,13 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Produces structured JSON logs
|
||||
// JSONFormatter produces structured JSON logs from LogEntry objects.
|
||||
type JSONFormatter struct {
|
||||
config *config.JSONFormatterOptions
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// Creates a new JSON formatter
|
||||
// NewJSONFormatter creates a new JSON formatter from configuration options.
|
||||
func NewJSONFormatter(opts *config.JSONFormatterOptions, logger *log.Logger) (*JSONFormatter, error) {
|
||||
f := &JSONFormatter{
|
||||
config: opts,
|
||||
@ -28,7 +28,7 @@ func NewJSONFormatter(opts *config.JSONFormatterOptions, logger *log.Logger) (*J
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Formats the log entry as JSON
|
||||
// Format transforms a single LogEntry into a JSON byte slice.
|
||||
func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) {
|
||||
// Start with a clean map
|
||||
output := make(map[string]any)
|
||||
@ -92,13 +92,12 @@ func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) {
|
||||
return append(result, '\n'), nil
|
||||
}
|
||||
|
||||
// Returns the formatter name
|
||||
// Name returns the formatter's type name.
|
||||
func (f *JSONFormatter) Name() string {
|
||||
return "json"
|
||||
}
|
||||
|
||||
// Formats multiple entries as a JSON array
|
||||
// This is a special method for sinks that need to batch entries
|
||||
// FormatBatch transforms a slice of LogEntry objects into a single JSON array byte slice.
|
||||
func (f *JSONFormatter) FormatBatch(entries []core.LogEntry) ([]byte, error) {
|
||||
// For batching, we need to create an array of formatted objects
|
||||
batch := make([]json.RawMessage, 0, len(entries))
|
||||
|
||||
@ -8,13 +8,13 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Outputs the log message as-is with a newline
|
||||
// RawFormatter outputs the raw log message, optionally with a newline.
|
||||
type RawFormatter struct {
|
||||
config *config.RawFormatterOptions
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// Creates a new raw formatter
|
||||
// NewRawFormatter creates a new raw pass-through formatter.
|
||||
func NewRawFormatter(cfg *config.RawFormatterOptions, logger *log.Logger) (*RawFormatter, error) {
|
||||
return &RawFormatter{
|
||||
config: cfg,
|
||||
@ -22,7 +22,7 @@ func NewRawFormatter(cfg *config.RawFormatterOptions, logger *log.Logger) (*RawF
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Returns the message with a newline appended
|
||||
// Format returns the raw message from the LogEntry as a byte slice.
|
||||
func (f *RawFormatter) Format(entry core.LogEntry) ([]byte, error) {
|
||||
// TODO: Standardize not to add "\n" when processing raw, check lixenwraith/log for consistency
|
||||
if f.config.AddNewLine {
|
||||
@ -32,7 +32,7 @@ func (f *RawFormatter) Format(entry core.LogEntry) ([]byte, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the formatter name
|
||||
// Name returns the formatter's type name.
|
||||
func (f *RawFormatter) Name() string {
|
||||
return "raw"
|
||||
}
|
||||
@ -14,14 +14,14 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Produces human-readable text logs using templates
|
||||
// TxtFormatter produces human-readable, template-based text logs.
|
||||
type TxtFormatter struct {
|
||||
config *config.TxtFormatterOptions
|
||||
template *template.Template
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// Creates a new text formatter
|
||||
// NewTxtFormatter creates a new text formatter from a template configuration.
|
||||
func NewTxtFormatter(opts *config.TxtFormatterOptions, logger *log.Logger) (*TxtFormatter, error) {
|
||||
f := &TxtFormatter{
|
||||
config: opts,
|
||||
@ -47,7 +47,7 @@ func NewTxtFormatter(opts *config.TxtFormatterOptions, logger *log.Logger) (*Txt
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Formats the log entry using the template
|
||||
// Format transforms a LogEntry into a text byte slice using the configured template.
|
||||
func (f *TxtFormatter) Format(entry core.LogEntry) ([]byte, error) {
|
||||
// Prepare data for template
|
||||
data := map[string]any{
|
||||
@ -91,7 +91,7 @@ func (f *TxtFormatter) Format(entry core.LogEntry) ([]byte, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Returns the formatter name
|
||||
// Name returns the formatter's type name.
|
||||
func (f *TxtFormatter) Name() string {
|
||||
return "txt"
|
||||
}
|
||||
@ -14,7 +14,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// DenialReason indicates why a request was denied
|
||||
// DenialReason indicates why a network request was denied.
|
||||
type DenialReason string
|
||||
|
||||
// ** THIS PROGRAM IS IPV4 ONLY !!**
|
||||
@ -32,7 +32,7 @@ const (
|
||||
ReasonInvalidIP DenialReason = "Invalid IP address"
|
||||
)
|
||||
|
||||
// NetLimiter manages net limiting for a transport
|
||||
// NetLimiter manages network-level limiting including ACLs, rate limits, and connection counts.
|
||||
type NetLimiter struct {
|
||||
config *config.NetLimitConfig
|
||||
logger *log.Logger
|
||||
@ -75,20 +75,21 @@ type NetLimiter struct {
|
||||
cleanupDone chan struct{}
|
||||
}
|
||||
|
||||
// ipLimiter holds the rate limiting and activity state for a single IP address.
|
||||
type ipLimiter struct {
|
||||
bucket *TokenBucket
|
||||
lastSeen time.Time
|
||||
connections atomic.Int64
|
||||
}
|
||||
|
||||
// Connection tracking with activity timestamp
|
||||
// connTracker tracks active connections and their last activity.
|
||||
type connTracker struct {
|
||||
connections atomic.Int64
|
||||
lastSeen time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Creates a new net limiter
|
||||
// NewNetLimiter creates a new network limiter from configuration.
|
||||
func NewNetLimiter(cfg *config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
@ -145,120 +146,7 @@ func NewNetLimiter(cfg *config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
||||
return l
|
||||
}
|
||||
|
||||
// parseIPLists parses and validates IP whitelist/blacklist
|
||||
func (l *NetLimiter) parseIPLists() {
|
||||
// Parse whitelist
|
||||
for _, entry := range l.config.IPWhitelist {
|
||||
if ipNet := l.parseIPEntry(entry, "whitelist"); ipNet != nil {
|
||||
l.ipWhitelist = append(l.ipWhitelist, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse blacklist
|
||||
for _, entry := range l.config.IPBlacklist {
|
||||
if ipNet := l.parseIPEntry(entry, "blacklist"); ipNet != nil {
|
||||
l.ipBlacklist = append(l.ipBlacklist, ipNet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseIPEntry parses a single IP or CIDR entry
|
||||
func (l *NetLimiter) parseIPEntry(entry, listType string) *net.IPNet {
|
||||
// Handle single IP
|
||||
if !strings.Contains(entry, "/") {
|
||||
ip := net.ParseIP(entry)
|
||||
if ip == nil {
|
||||
l.logger.Warn("msg", "Invalid IP entry",
|
||||
"component", "netlimit",
|
||||
"list", listType,
|
||||
"entry", entry)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reject IPv6
|
||||
if ip.To4() == nil {
|
||||
l.logger.Warn("msg", "IPv6 address rejected",
|
||||
"component", "netlimit",
|
||||
"list", listType,
|
||||
"entry", entry,
|
||||
"reason", IPv4Only)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &net.IPNet{IP: ip.To4(), Mask: net.CIDRMask(32, 32)}
|
||||
}
|
||||
|
||||
// Parse CIDR
|
||||
ipAddr, ipNet, err := net.ParseCIDR(entry)
|
||||
if err != nil {
|
||||
l.logger.Warn("msg", "Invalid CIDR entry",
|
||||
"component", "netlimit",
|
||||
"list", listType,
|
||||
"entry", entry,
|
||||
"error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reject IPv6 CIDR
|
||||
if ipAddr.To4() == nil {
|
||||
l.logger.Warn("msg", "IPv6 CIDR rejected",
|
||||
"component", "netlimit",
|
||||
"list", listType,
|
||||
"entry", entry,
|
||||
"reason", IPv4Only)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure mask is IPv4
|
||||
_, bits := ipNet.Mask.Size()
|
||||
if bits != 32 {
|
||||
l.logger.Warn("msg", "Non-IPv4 CIDR mask rejected",
|
||||
"component", "netlimit",
|
||||
"list", listType,
|
||||
"entry", entry,
|
||||
"mask_bits", bits,
|
||||
"reason", IPv4Only)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &net.IPNet{IP: ipAddr.To4(), Mask: ipNet.Mask}
|
||||
}
|
||||
|
||||
// checkIPAccess checks if an IP is allowed by ACLs
|
||||
func (l *NetLimiter) checkIPAccess(ip net.IP) DenialReason {
|
||||
// 1. Check blacklist first (deny takes precedence)
|
||||
for _, ipNet := range l.ipBlacklist {
|
||||
if ipNet.Contains(ip) {
|
||||
l.blockedByBlacklist.Add(1)
|
||||
l.logger.Debug("msg", "IP denied by blacklist",
|
||||
"component", "netlimit",
|
||||
"ip", ip.String(),
|
||||
"rule", ipNet.String())
|
||||
return ReasonBlacklisted
|
||||
}
|
||||
}
|
||||
|
||||
// 2. If whitelist is configured, IP must be in it
|
||||
if len(l.ipWhitelist) > 0 {
|
||||
for _, ipNet := range l.ipWhitelist {
|
||||
if ipNet.Contains(ip) {
|
||||
l.logger.Debug("msg", "IP allowed by whitelist",
|
||||
"component", "netlimit",
|
||||
"ip", ip.String(),
|
||||
"rule", ipNet.String())
|
||||
return ReasonAllowed
|
||||
}
|
||||
}
|
||||
l.blockedByWhitelist.Add(1)
|
||||
l.logger.Debug("msg", "IP not in whitelist",
|
||||
"component", "netlimit",
|
||||
"ip", ip.String())
|
||||
return ReasonNotWhitelisted
|
||||
}
|
||||
|
||||
return ReasonAllowed
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the net limiter's background cleanup processes.
|
||||
func (l *NetLimiter) Shutdown() {
|
||||
if l == nil {
|
||||
return
|
||||
@ -278,7 +166,7 @@ func (l *NetLimiter) Shutdown() {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks if an HTTP request should be allowed: IP access control + connection limits (IP only) + calls
|
||||
// CheckHTTP checks if an incoming HTTP request is allowed based on all configured limits.
|
||||
func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int64, message string) {
|
||||
if l == nil {
|
||||
return true, 0, ""
|
||||
@ -361,20 +249,7 @@ func (l *NetLimiter) CheckHTTP(remoteAddr string) (allowed bool, statusCode int6
|
||||
return true, 0, ""
|
||||
}
|
||||
|
||||
// Update connection activity
|
||||
func (l *NetLimiter) updateConnectionActivity(ip string) {
|
||||
l.connMu.RLock()
|
||||
tracker, exists := l.ipConnections[ip]
|
||||
l.connMu.RUnlock()
|
||||
|
||||
if exists {
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Checks if a TCP connection should be allowed: IP access control + calls checkIPLimit()
|
||||
// CheckTCP checks if an incoming TCP connection is allowed based on ACLs and rate limits.
|
||||
func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool {
|
||||
if l == nil {
|
||||
return true
|
||||
@ -422,11 +297,7 @@ func (l *NetLimiter) CheckTCP(remoteAddr net.Addr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func isIPv4(ip net.IP) bool {
|
||||
return ip.To4() != nil
|
||||
}
|
||||
|
||||
// Tracks a new connection for an IP
|
||||
// AddConnection tracks a new connection from a specific remote address (for HTTP).
|
||||
func (l *NetLimiter) AddConnection(remoteAddr string) {
|
||||
if l == nil {
|
||||
return
|
||||
@ -477,7 +348,7 @@ func (l *NetLimiter) AddConnection(remoteAddr string) {
|
||||
"connections", newCount)
|
||||
}
|
||||
|
||||
// Removes a connection for an IP
|
||||
// RemoveConnection removes a tracked connection (for HTTP).
|
||||
func (l *NetLimiter) RemoveConnection(remoteAddr string) {
|
||||
if l == nil {
|
||||
return
|
||||
@ -527,7 +398,113 @@ func (l *NetLimiter) RemoveConnection(remoteAddr string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns net limiter statistics
|
||||
// TrackConnection checks connection limits and tracks a new connection (for TCP).
|
||||
func (l *NetLimiter) TrackConnection(ip string, user string, token string) bool {
|
||||
if l == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
l.connMu.Lock()
|
||||
defer l.connMu.Unlock()
|
||||
|
||||
// Check total connections limit (0 = disabled)
|
||||
if l.config.MaxConnectionsTotal > 0 {
|
||||
currentTotal := l.totalConnections.Load()
|
||||
if currentTotal >= l.config.MaxConnectionsTotal {
|
||||
l.blockedByConnLimit.Add(1)
|
||||
l.logger.Debug("msg", "TCP connection blocked by total limit",
|
||||
"component", "netlimit",
|
||||
"current_total", currentTotal,
|
||||
"max_connections_total", l.config.MaxConnectionsTotal)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check per-IP connection limit (0 = disabled)
|
||||
if l.config.MaxConnectionsPerIP > 0 && ip != "" {
|
||||
tracker, exists := l.ipConnections[ip]
|
||||
if !exists {
|
||||
tracker = &connTracker{lastSeen: time.Now()}
|
||||
l.ipConnections[ip] = tracker
|
||||
}
|
||||
if tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
|
||||
l.blockedByConnLimit.Add(1)
|
||||
l.logger.Debug("msg", "TCP connection blocked by IP limit",
|
||||
"component", "netlimit",
|
||||
"ip", ip,
|
||||
"current", tracker.connections.Load(),
|
||||
"max", l.config.MaxConnectionsPerIP)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// All checks passed, increment counters
|
||||
l.totalConnections.Add(1)
|
||||
|
||||
if ip != "" && l.config.MaxConnectionsPerIP > 0 {
|
||||
if tracker, exists := l.ipConnections[ip]; exists {
|
||||
tracker.connections.Add(1)
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ReleaseConnection decrements connection counters when a connection is closed (for TCP).
|
||||
func (l *NetLimiter) ReleaseConnection(ip string, user string, token string) {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.connMu.Lock()
|
||||
defer l.connMu.Unlock()
|
||||
|
||||
// Decrement total
|
||||
if l.totalConnections.Load() > 0 {
|
||||
l.totalConnections.Add(-1)
|
||||
}
|
||||
|
||||
// Decrement IP counter
|
||||
if ip != "" {
|
||||
if tracker, exists := l.ipConnections[ip]; exists {
|
||||
if tracker.connections.Load() > 0 {
|
||||
tracker.connections.Add(-1)
|
||||
}
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement user counter
|
||||
if user != "" {
|
||||
if tracker, exists := l.userConnections[user]; exists {
|
||||
if tracker.connections.Load() > 0 {
|
||||
tracker.connections.Add(-1)
|
||||
}
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement token counter
|
||||
if token != "" {
|
||||
if tracker, exists := l.tokenConnections[token]; exists {
|
||||
if tracker.connections.Load() > 0 {
|
||||
tracker.connections.Add(-1)
|
||||
}
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats returns a map of the net limiter's current statistics.
|
||||
func (l *NetLimiter) GetStats() map[string]any {
|
||||
if l == nil {
|
||||
return map[string]any{"enabled": false}
|
||||
@ -613,53 +590,26 @@ func (l *NetLimiter) GetStats() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
// Performs IP net limit check (req/sec)
|
||||
func (l *NetLimiter) checkIPLimit(ip string) bool {
|
||||
// Validate IP format
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil || !isIPv4(parsedIP) {
|
||||
l.logger.Warn("msg", "Invalid or non-IPv4 address in rate limiter",
|
||||
"component", "netlimit",
|
||||
"ip", ip)
|
||||
return false
|
||||
}
|
||||
// cleanupLoop runs a periodic cleanup of stale limiter and tracker entries.
|
||||
func (l *NetLimiter) cleanupLoop() {
|
||||
defer close(l.cleanupDone)
|
||||
|
||||
// Maybe run cleanup
|
||||
l.maybeCleanup()
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
// IP limit
|
||||
l.ipMu.Lock()
|
||||
lim, exists := l.ipLimiters[ip]
|
||||
if !exists {
|
||||
// Create new limiter for this IP
|
||||
lim = &ipLimiter{
|
||||
bucket: NewTokenBucket(
|
||||
float64(l.config.BurstSize),
|
||||
l.config.RequestsPerSecond,
|
||||
),
|
||||
lastSeen: time.Now(),
|
||||
for {
|
||||
select {
|
||||
case <-l.ctx.Done():
|
||||
// Exit when context is cancelled
|
||||
l.logger.Debug("msg", "Cleanup loop stopping", "component", "netlimit")
|
||||
return
|
||||
case <-ticker.C:
|
||||
l.cleanup()
|
||||
}
|
||||
l.ipLimiters[ip] = lim
|
||||
l.uniqueIPs.Add(1)
|
||||
|
||||
l.logger.Debug("msg", "Created new IP limiter",
|
||||
"ip", ip,
|
||||
"total_ips", l.uniqueIPs.Load())
|
||||
} else {
|
||||
lim.lastSeen = time.Now()
|
||||
}
|
||||
l.ipMu.Unlock()
|
||||
|
||||
// Rate limit check
|
||||
allowed := lim.bucket.Allow()
|
||||
if !allowed {
|
||||
l.blockedByRateLimit.Add(1)
|
||||
}
|
||||
|
||||
return allowed
|
||||
}
|
||||
|
||||
// Runs cleanup if enough time has passed
|
||||
// maybeCleanup triggers an asynchronous cleanup if enough time has passed since the last one.
|
||||
func (l *NetLimiter) maybeCleanup() {
|
||||
l.cleanupMu.Lock()
|
||||
|
||||
@ -685,7 +635,7 @@ func (l *NetLimiter) maybeCleanup() {
|
||||
}()
|
||||
}
|
||||
|
||||
// Removes stale IP limiters
|
||||
// cleanup removes stale IP limiters and connection trackers from memory.
|
||||
func (l *NetLimiter) cleanup() {
|
||||
staleTimeout := 5 * time.Minute
|
||||
now := time.Now()
|
||||
@ -767,127 +717,180 @@ func (l *NetLimiter) cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// Runs periodic cleanup
|
||||
func (l *NetLimiter) cleanupLoop() {
|
||||
defer close(l.cleanupDone)
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.ctx.Done():
|
||||
// Exit when context is cancelled
|
||||
l.logger.Debug("msg", "Cleanup loop stopping", "component", "netlimit")
|
||||
return
|
||||
case <-ticker.C:
|
||||
l.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tracks a new connection with optional user/token info: Connection limits (IP/user/token/total) for TCP only
|
||||
func (l *NetLimiter) TrackConnection(ip string, user string, token string) bool {
|
||||
if l == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
l.connMu.Lock()
|
||||
defer l.connMu.Unlock()
|
||||
|
||||
// Check total connections limit (0 = disabled)
|
||||
if l.config.MaxConnectionsTotal > 0 {
|
||||
currentTotal := l.totalConnections.Load()
|
||||
if currentTotal >= l.config.MaxConnectionsTotal {
|
||||
l.blockedByConnLimit.Add(1)
|
||||
l.logger.Debug("msg", "TCP connection blocked by total limit",
|
||||
// checkIPAccess verifies if an IP address is permitted by the configured ACLs.
|
||||
func (l *NetLimiter) checkIPAccess(ip net.IP) DenialReason {
|
||||
// 1. Check blacklist first (deny takes precedence)
|
||||
for _, ipNet := range l.ipBlacklist {
|
||||
if ipNet.Contains(ip) {
|
||||
l.blockedByBlacklist.Add(1)
|
||||
l.logger.Debug("msg", "IP denied by blacklist",
|
||||
"component", "netlimit",
|
||||
"current_total", currentTotal,
|
||||
"max_connections_total", l.config.MaxConnectionsTotal)
|
||||
return false
|
||||
"ip", ip.String(),
|
||||
"rule", ipNet.String())
|
||||
return ReasonBlacklisted
|
||||
}
|
||||
}
|
||||
|
||||
// Check per-IP connection limit (0 = disabled)
|
||||
if l.config.MaxConnectionsPerIP > 0 && ip != "" {
|
||||
tracker, exists := l.ipConnections[ip]
|
||||
if !exists {
|
||||
tracker = &connTracker{lastSeen: time.Now()}
|
||||
l.ipConnections[ip] = tracker
|
||||
// 2. If whitelist is configured, IP must be in it
|
||||
if len(l.ipWhitelist) > 0 {
|
||||
for _, ipNet := range l.ipWhitelist {
|
||||
if ipNet.Contains(ip) {
|
||||
l.logger.Debug("msg", "IP allowed by whitelist",
|
||||
"component", "netlimit",
|
||||
"ip", ip.String(),
|
||||
"rule", ipNet.String())
|
||||
return ReasonAllowed
|
||||
}
|
||||
}
|
||||
if tracker.connections.Load() >= l.config.MaxConnectionsPerIP {
|
||||
l.blockedByConnLimit.Add(1)
|
||||
l.logger.Debug("msg", "TCP connection blocked by IP limit",
|
||||
l.blockedByWhitelist.Add(1)
|
||||
l.logger.Debug("msg", "IP not in whitelist",
|
||||
"component", "netlimit",
|
||||
"ip", ip.String())
|
||||
return ReasonNotWhitelisted
|
||||
}
|
||||
|
||||
return ReasonAllowed
|
||||
}
|
||||
|
||||
// checkIPLimit enforces the requests-per-second limit for a given IP address.
|
||||
func (l *NetLimiter) checkIPLimit(ip string) bool {
|
||||
// Validate IP format
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil || !isIPv4(parsedIP) {
|
||||
l.logger.Warn("msg", "Invalid or non-IPv4 address in rate limiter",
|
||||
"component", "netlimit",
|
||||
"ip", ip)
|
||||
return false
|
||||
}
|
||||
|
||||
// Maybe run cleanup
|
||||
l.maybeCleanup()
|
||||
|
||||
// IP limit
|
||||
l.ipMu.Lock()
|
||||
lim, exists := l.ipLimiters[ip]
|
||||
if !exists {
|
||||
// Create new limiter for this IP
|
||||
lim = &ipLimiter{
|
||||
bucket: NewTokenBucket(
|
||||
float64(l.config.BurstSize),
|
||||
l.config.RequestsPerSecond,
|
||||
),
|
||||
lastSeen: time.Now(),
|
||||
}
|
||||
l.ipLimiters[ip] = lim
|
||||
l.uniqueIPs.Add(1)
|
||||
|
||||
l.logger.Debug("msg", "Created new IP limiter",
|
||||
"ip", ip,
|
||||
"total_ips", l.uniqueIPs.Load())
|
||||
} else {
|
||||
lim.lastSeen = time.Now()
|
||||
}
|
||||
l.ipMu.Unlock()
|
||||
|
||||
// Rate limit check
|
||||
allowed := lim.bucket.Allow()
|
||||
if !allowed {
|
||||
l.blockedByRateLimit.Add(1)
|
||||
}
|
||||
|
||||
return allowed
|
||||
}
|
||||
|
||||
// parseIPLists converts the string-based IP rules from the config into parsed net.IPNet objects.
|
||||
func (l *NetLimiter) parseIPLists() {
|
||||
// Parse whitelist
|
||||
for _, entry := range l.config.IPWhitelist {
|
||||
if ipNet := l.parseIPEntry(entry, "whitelist"); ipNet != nil {
|
||||
l.ipWhitelist = append(l.ipWhitelist, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse blacklist
|
||||
for _, entry := range l.config.IPBlacklist {
|
||||
if ipNet := l.parseIPEntry(entry, "blacklist"); ipNet != nil {
|
||||
l.ipBlacklist = append(l.ipBlacklist, ipNet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseIPEntry parses a single IP address or CIDR notation string into a net.IPNet object.
|
||||
func (l *NetLimiter) parseIPEntry(entry, listType string) *net.IPNet {
|
||||
// Handle single IP
|
||||
if !strings.Contains(entry, "/") {
|
||||
ip := net.ParseIP(entry)
|
||||
if ip == nil {
|
||||
l.logger.Warn("msg", "Invalid IP entry",
|
||||
"component", "netlimit",
|
||||
"ip", ip,
|
||||
"current", tracker.connections.Load(),
|
||||
"max", l.config.MaxConnectionsPerIP)
|
||||
return false
|
||||
"list", listType,
|
||||
"entry", entry)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reject IPv6
|
||||
if ip.To4() == nil {
|
||||
l.logger.Warn("msg", "IPv6 address rejected",
|
||||
"component", "netlimit",
|
||||
"list", listType,
|
||||
"entry", entry,
|
||||
"reason", IPv4Only)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &net.IPNet{IP: ip.To4(), Mask: net.CIDRMask(32, 32)}
|
||||
}
|
||||
|
||||
// All checks passed, increment counters
|
||||
l.totalConnections.Add(1)
|
||||
|
||||
if ip != "" && l.config.MaxConnectionsPerIP > 0 {
|
||||
if tracker, exists := l.ipConnections[ip]; exists {
|
||||
tracker.connections.Add(1)
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
// Parse CIDR
|
||||
ipAddr, ipNet, err := net.ParseCIDR(entry)
|
||||
if err != nil {
|
||||
l.logger.Warn("msg", "Invalid CIDR entry",
|
||||
"component", "netlimit",
|
||||
"list", listType,
|
||||
"entry", entry,
|
||||
"error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return true
|
||||
// Reject IPv6 CIDR
|
||||
if ipAddr.To4() == nil {
|
||||
l.logger.Warn("msg", "IPv6 CIDR rejected",
|
||||
"component", "netlimit",
|
||||
"list", listType,
|
||||
"entry", entry,
|
||||
"reason", IPv4Only)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure mask is IPv4
|
||||
_, bits := ipNet.Mask.Size()
|
||||
if bits != 32 {
|
||||
l.logger.Warn("msg", "Non-IPv4 CIDR mask rejected",
|
||||
"component", "netlimit",
|
||||
"list", listType,
|
||||
"entry", entry,
|
||||
"mask_bits", bits,
|
||||
"reason", IPv4Only)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &net.IPNet{IP: ipAddr.To4(), Mask: ipNet.Mask}
|
||||
}
|
||||
|
||||
// Releases a tracked connection
|
||||
func (l *NetLimiter) ReleaseConnection(ip string, user string, token string) {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
// updateConnectionActivity updates the last seen timestamp for a connection tracker.
|
||||
func (l *NetLimiter) updateConnectionActivity(ip string) {
|
||||
l.connMu.RLock()
|
||||
tracker, exists := l.ipConnections[ip]
|
||||
l.connMu.RUnlock()
|
||||
|
||||
l.connMu.Lock()
|
||||
defer l.connMu.Unlock()
|
||||
|
||||
// Decrement total
|
||||
if l.totalConnections.Load() > 0 {
|
||||
l.totalConnections.Add(-1)
|
||||
}
|
||||
|
||||
// Decrement IP counter
|
||||
if ip != "" {
|
||||
if tracker, exists := l.ipConnections[ip]; exists {
|
||||
if tracker.connections.Load() > 0 {
|
||||
tracker.connections.Add(-1)
|
||||
}
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement user counter
|
||||
if user != "" {
|
||||
if tracker, exists := l.userConnections[user]; exists {
|
||||
if tracker.connections.Load() > 0 {
|
||||
tracker.connections.Add(-1)
|
||||
}
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement token counter
|
||||
if token != "" {
|
||||
if tracker, exists := l.tokenConnections[token]; exists {
|
||||
if tracker.connections.Load() > 0 {
|
||||
tracker.connections.Add(-1)
|
||||
}
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
if exists {
|
||||
tracker.mu.Lock()
|
||||
tracker.lastSeen = time.Now()
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// isIPv4 is a helper function to check if a net.IP is an IPv4 address.
|
||||
func isIPv4(ip net.IP) bool {
|
||||
return ip.To4() != nil
|
||||
}
|
||||
@ -11,7 +11,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Enforces rate limits on log entries flowing through a pipeline.
|
||||
// RateLimiter enforces rate limits on log entries flowing through a pipeline.
|
||||
type RateLimiter struct {
|
||||
bucket *TokenBucket
|
||||
policy config.RateLimitPolicy
|
||||
@ -23,7 +23,7 @@ type RateLimiter struct {
|
||||
droppedCount atomic.Uint64
|
||||
}
|
||||
|
||||
// Creates a new rate limiter. If cfg.Rate is 0, it returns nil.
|
||||
// NewRateLimiter creates a new pipeline-level rate limiter from configuration.
|
||||
func NewRateLimiter(cfg config.RateLimitConfig, logger *log.Logger) (*RateLimiter, error) {
|
||||
if cfg.Rate <= 0 {
|
||||
return nil, nil // No rate limit
|
||||
@ -56,8 +56,7 @@ func NewRateLimiter(cfg config.RateLimitConfig, logger *log.Logger) (*RateLimite
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// Checks if a log entry is allowed to pass based on the rate limit.
|
||||
// It returns true if the entry should pass, false if it should be dropped.
|
||||
// Allow checks if a log entry is permitted to pass based on the rate limit.
|
||||
func (l *RateLimiter) Allow(entry core.LogEntry) bool {
|
||||
if l == nil || l.policy == config.PolicyPass {
|
||||
return true
|
||||
@ -83,7 +82,7 @@ func (l *RateLimiter) Allow(entry core.LogEntry) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// GetStats returns the statistics for the limiter.
|
||||
// GetStats returns statistics for the rate limiter.
|
||||
func (l *RateLimiter) GetStats() map[string]any {
|
||||
if l == nil {
|
||||
return map[string]any{
|
||||
@ -106,7 +105,7 @@ func (l *RateLimiter) GetStats() map[string]any {
|
||||
return stats
|
||||
}
|
||||
|
||||
// policyString returns the string representation of the policy.
|
||||
// policyString returns the string representation of a rate limit policy.
|
||||
func policyString(p config.RateLimitPolicy) string {
|
||||
switch p {
|
||||
case config.PolicyDrop:
|
||||
|
||||
@ -6,8 +6,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenBucket implements a token bucket rate limiter
|
||||
// Safe for concurrent use.
|
||||
// TokenBucket implements a thread-safe token bucket rate limiter.
|
||||
type TokenBucket struct {
|
||||
capacity float64
|
||||
tokens float64
|
||||
@ -16,7 +15,7 @@ type TokenBucket struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Creates a new token bucket with given capacity and refill rate
|
||||
// NewTokenBucket creates a new token bucket with a given capacity and refill rate.
|
||||
func NewTokenBucket(capacity float64, refillRate float64) *TokenBucket {
|
||||
return &TokenBucket{
|
||||
capacity: capacity,
|
||||
@ -26,12 +25,12 @@ func NewTokenBucket(capacity float64, refillRate float64) *TokenBucket {
|
||||
}
|
||||
}
|
||||
|
||||
// Attempts to consume one token, returns true if allowed
|
||||
// Allow attempts to consume one token, returning true if successful.
|
||||
func (tb *TokenBucket) Allow() bool {
|
||||
return tb.AllowN(1)
|
||||
}
|
||||
|
||||
// Attempts to consume n tokens, returns true if allowed
|
||||
// AllowN attempts to consume n tokens, returning true if successful.
|
||||
func (tb *TokenBucket) AllowN(n float64) bool {
|
||||
tb.mu.Lock()
|
||||
defer tb.mu.Unlock()
|
||||
@ -45,7 +44,7 @@ func (tb *TokenBucket) AllowN(n float64) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Returns the current number of available tokens
|
||||
// Tokens returns the current number of available tokens in the bucket.
|
||||
func (tb *TokenBucket) Tokens() float64 {
|
||||
tb.mu.Lock()
|
||||
defer tb.mu.Unlock()
|
||||
@ -54,8 +53,7 @@ func (tb *TokenBucket) Tokens() float64 {
|
||||
return tb.tokens
|
||||
}
|
||||
|
||||
// Adds tokens based on time elapsed since last refill
|
||||
// MUST be called with mutex held
|
||||
// refill adds new tokens to the bucket based on the elapsed time.
|
||||
func (tb *TokenBucket) refill() {
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(tb.lastRefill).Seconds()
|
||||
|
||||
@ -18,7 +18,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Manages the flow of data from sources through filters to sinks
|
||||
// Pipeline manages the flow of data from sources, through filters, to sinks.
|
||||
type Pipeline struct {
|
||||
Config *config.PipelineConfig
|
||||
Sources []source.Source
|
||||
@ -33,7 +33,7 @@ type Pipeline struct {
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// Contains statistics for a pipeline
|
||||
// PipelineStats contains runtime statistics for a pipeline.
|
||||
type PipelineStats struct {
|
||||
StartTime time.Time
|
||||
TotalEntriesProcessed atomic.Uint64
|
||||
@ -44,7 +44,7 @@ type PipelineStats struct {
|
||||
FilterStats map[string]any
|
||||
}
|
||||
|
||||
// Creates and starts a new pipeline
|
||||
// NewPipeline creates, configures, and starts a new pipeline within the service.
|
||||
func (s *Service) NewPipeline(cfg *config.PipelineConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@ -149,7 +149,7 @@ func (s *Service) NewPipeline(cfg *config.PipelineConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Gracefully stops the pipeline
|
||||
// Shutdown gracefully stops the pipeline and all its components.
|
||||
func (p *Pipeline) Shutdown() {
|
||||
p.logger.Info("msg", "Shutting down pipeline",
|
||||
"component", "pipeline",
|
||||
@ -187,7 +187,7 @@ func (p *Pipeline) Shutdown() {
|
||||
"pipeline", p.Config.Name)
|
||||
}
|
||||
|
||||
// Returns pipeline statistics
|
||||
// GetStats returns a map of the pipeline's current statistics.
|
||||
func (p *Pipeline) GetStats() map[string]any {
|
||||
// Recovery to handle concurrent access during shutdown
|
||||
// When service is shutting down, sources/sinks might be nil or partially stopped
|
||||
@ -263,7 +263,8 @@ func (p *Pipeline) GetStats() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
// Runs periodic stats updates
|
||||
// TODO: incomplete implementation
|
||||
// startStatsUpdater runs a periodic stats updater.
|
||||
func (p *Pipeline) startStatsUpdater(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
|
||||
@ -15,7 +15,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Service manages multiple pipelines
|
||||
// Service manages a collection of log processing pipelines.
|
||||
type Service struct {
|
||||
pipelines map[string]*Pipeline
|
||||
mu sync.RWMutex
|
||||
@ -25,7 +25,7 @@ type Service struct {
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// Creates a new service
|
||||
// NewService creates a new, empty service.
|
||||
func NewService(ctx context.Context, logger *log.Logger) *Service {
|
||||
serviceCtx, cancel := context.WithCancel(ctx)
|
||||
return &Service{
|
||||
@ -36,7 +36,97 @@ func NewService(ctx context.Context, logger *log.Logger) *Service {
|
||||
}
|
||||
}
|
||||
|
||||
// Connects sources to sinks through filters
|
||||
// GetPipeline returns a pipeline by its name.
|
||||
func (s *Service) GetPipeline(name string) (*Pipeline, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
pipeline, exists := s.pipelines[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("pipeline '%s' not found", name)
|
||||
}
|
||||
return pipeline, nil
|
||||
}
|
||||
|
||||
// ListPipelines returns the names of all currently managed pipelines.
|
||||
func (s *Service) ListPipelines() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(s.pipelines))
|
||||
for name := range s.pipelines {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// RemovePipeline stops and removes a pipeline from the service.
|
||||
func (s *Service) RemovePipeline(name string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
pipeline, exists := s.pipelines[name]
|
||||
if !exists {
|
||||
err := fmt.Errorf("pipeline '%s' not found", name)
|
||||
s.logger.Warn("msg", "Cannot remove non-existent pipeline",
|
||||
"component", "service",
|
||||
"pipeline", name,
|
||||
"error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("msg", "Removing pipeline", "pipeline", name)
|
||||
pipeline.Shutdown()
|
||||
delete(s.pipelines, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops all pipelines managed by the service.
|
||||
func (s *Service) Shutdown() {
|
||||
s.logger.Info("msg", "Service shutdown initiated")
|
||||
|
||||
s.mu.Lock()
|
||||
pipelines := make([]*Pipeline, 0, len(s.pipelines))
|
||||
for _, pipeline := range s.pipelines {
|
||||
pipelines = append(pipelines, pipeline)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Stop all pipelines concurrently
|
||||
var wg sync.WaitGroup
|
||||
for _, pipeline := range pipelines {
|
||||
wg.Add(1)
|
||||
go func(p *Pipeline) {
|
||||
defer wg.Done()
|
||||
p.Shutdown()
|
||||
}(pipeline)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
|
||||
s.logger.Info("msg", "Service shutdown complete")
|
||||
}
|
||||
|
||||
// GetGlobalStats returns statistics for all pipelines.
|
||||
func (s *Service) GetGlobalStats() map[string]any {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
stats := map[string]any{
|
||||
"pipelines": make(map[string]any),
|
||||
"total_pipelines": len(s.pipelines),
|
||||
}
|
||||
|
||||
for name, pipeline := range s.pipelines {
|
||||
stats["pipelines"].(map[string]any)[name] = pipeline.GetStats()
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// wirePipeline connects a pipeline's sources to its sinks through its filter chain.
|
||||
func (s *Service) wirePipeline(p *Pipeline) {
|
||||
// For each source, subscribe and process entries
|
||||
for _, src := range p.Sources {
|
||||
@ -113,7 +203,7 @@ func (s *Service) wirePipeline(p *Pipeline) {
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a source instance based on configuration
|
||||
// createSource is a factory function for creating a source instance from configuration.
|
||||
func (s *Service) createSource(cfg *config.SourceConfig) (source.Source, error) {
|
||||
switch cfg.Type {
|
||||
case "directory":
|
||||
@ -129,7 +219,7 @@ func (s *Service) createSource(cfg *config.SourceConfig) (source.Source, error)
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a sink instance based on configuration
|
||||
// createSink is a factory function for creating a sink instance from configuration.
|
||||
func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter) (sink.Sink, error) {
|
||||
|
||||
switch cfg.Type {
|
||||
@ -157,93 +247,3 @@ func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter)
|
||||
return nil, fmt.Errorf("unknown sink type: %s", cfg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a pipeline by name
|
||||
func (s *Service) GetPipeline(name string) (*Pipeline, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
pipeline, exists := s.pipelines[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("pipeline '%s' not found", name)
|
||||
}
|
||||
return pipeline, nil
|
||||
}
|
||||
|
||||
// Returns all pipeline names
|
||||
func (s *Service) ListPipelines() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(s.pipelines))
|
||||
for name := range s.pipelines {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// Stops and removes a pipeline
|
||||
func (s *Service) RemovePipeline(name string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
pipeline, exists := s.pipelines[name]
|
||||
if !exists {
|
||||
err := fmt.Errorf("pipeline '%s' not found", name)
|
||||
s.logger.Warn("msg", "Cannot remove non-existent pipeline",
|
||||
"component", "service",
|
||||
"pipeline", name,
|
||||
"error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("msg", "Removing pipeline", "pipeline", name)
|
||||
pipeline.Shutdown()
|
||||
delete(s.pipelines, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stops all pipelines
|
||||
func (s *Service) Shutdown() {
|
||||
s.logger.Info("msg", "Service shutdown initiated")
|
||||
|
||||
s.mu.Lock()
|
||||
pipelines := make([]*Pipeline, 0, len(s.pipelines))
|
||||
for _, pipeline := range s.pipelines {
|
||||
pipelines = append(pipelines, pipeline)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Stop all pipelines concurrently
|
||||
var wg sync.WaitGroup
|
||||
for _, pipeline := range pipelines {
|
||||
wg.Add(1)
|
||||
go func(p *Pipeline) {
|
||||
defer wg.Done()
|
||||
p.Shutdown()
|
||||
}(pipeline)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
|
||||
s.logger.Info("msg", "Service shutdown complete")
|
||||
}
|
||||
|
||||
// Returns statistics for all pipelines
|
||||
func (s *Service) GetGlobalStats() map[string]any {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
stats := map[string]any{
|
||||
"pipelines": make(map[string]any),
|
||||
"total_pipelines": len(s.pipelines),
|
||||
}
|
||||
|
||||
for name, pipeline := range s.pipelines {
|
||||
stats["pipelines"].(map[string]any)[name] = pipeline.GetStats()
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
290
src/internal/session/session.go
Normal file
290
src/internal/session/session.go
Normal file
@ -0,0 +1,290 @@
|
||||
// FILE: src/internal/session/session.go
|
||||
package session
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Session represents a single tracked connection session.
// The Manager updates LastActivity and uses it to expire idle sessions.
type Session struct {
	ID           string         // Unique session identifier (see generateSessionID)
	RemoteAddr   string         // Client address
	CreatedAt    time.Time      // Session creation time
	LastActivity time.Time      // Last activity timestamp
	Metadata     map[string]any // Optional metadata (e.g., TLS info)

	// Connection context
	Source string // Source type: "tcp_source", "http_source", "tcp_sink", etc.
}
|
||||
|
||||
// Manager handles the lifecycle of sessions: creation, lookup, activity
// tracking, and periodic cleanup of sessions idle longer than maxIdleTime.
type Manager struct {
	sessions map[string]*Session
	mu       sync.RWMutex // guards sessions

	// Cleanup configuration
	maxIdleTime   time.Duration  // sessions idle longer than this are expired
	cleanupTicker *time.Ticker   // drives periodic cleanup (see startCleanup)
	done          chan struct{}  // closed by Stop to end the cleanup goroutine

	// Expiry callbacks by source type, invoked when a session expires
	expiryCallbacks map[string]func(sessionID, remoteAddr string)
	callbacksMu     sync.RWMutex // guards expiryCallbacks
}
||||
|
||||
// NewManager creates a new session manager with a specified idle timeout.
|
||||
func NewManager(maxIdleTime time.Duration) *Manager {
|
||||
if maxIdleTime == 0 {
|
||||
maxIdleTime = 30 * time.Minute // Default idle timeout
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
sessions: make(map[string]*Session),
|
||||
maxIdleTime: maxIdleTime,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
m.startCleanup()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// CreateSession creates and stores a new session for a connection.
|
||||
func (m *Manager) CreateSession(remoteAddr string, source string, metadata map[string]any) *Session {
|
||||
session := &Session{
|
||||
ID: generateSessionID(),
|
||||
RemoteAddr: remoteAddr,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivity: time.Now(),
|
||||
Source: source,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
if metadata == nil {
|
||||
session.Metadata = make(map[string]any)
|
||||
}
|
||||
|
||||
m.StoreSession(session)
|
||||
return session
|
||||
}
|
||||
|
||||
// StoreSession adds a session to the manager.
|
||||
func (m *Manager) StoreSession(session *Session) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.sessions[session.ID] = session
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by its unique ID.
|
||||
func (m *Manager) GetSession(sessionID string) (*Session, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
session, exists := m.sessions[sessionID]
|
||||
return session, exists
|
||||
}
|
||||
|
||||
// RemoveSession removes a session from the manager.
|
||||
func (m *Manager) RemoveSession(sessionID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.sessions, sessionID)
|
||||
}
|
||||
|
||||
// UpdateActivity updates the last activity timestamp for a session.
|
||||
func (m *Manager) UpdateActivity(sessionID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if session, exists := m.sessions[sessionID]; exists {
|
||||
session.LastActivity = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// IsSessionActive checks if a session exists and has not been idle for too long.
|
||||
func (m *Manager) IsSessionActive(sessionID string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if session, exists := m.sessions[sessionID]; exists {
|
||||
// Session exists and hasn't exceeded idle timeout
|
||||
return time.Since(session.LastActivity) < m.maxIdleTime
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetActiveSessions returns a snapshot of all currently active sessions.
|
||||
func (m *Manager) GetActiveSessions() []*Session {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
sessions := make([]*Session, 0, len(m.sessions))
|
||||
for _, session := range m.sessions {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
// GetSessionCount returns the number of active sessions.
|
||||
func (m *Manager) GetSessionCount() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.sessions)
|
||||
}
|
||||
|
||||
// GetSessionsBySource returns all sessions matching a specific source type.
|
||||
func (m *Manager) GetSessionsBySource(source string) []*Session {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var sessions []*Session
|
||||
for _, session := range m.sessions {
|
||||
if session.Source == source {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
// GetActiveSessionsBySource returns all active sessions for a given source.
|
||||
func (m *Manager) GetActiveSessionsBySource(source string) []*Session {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var sessions []*Session
|
||||
now := time.Now()
|
||||
|
||||
for _, session := range m.sessions {
|
||||
if session.Source == source && now.Sub(session.LastActivity) < m.maxIdleTime {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
// GetStats returns statistics about the session manager.
|
||||
func (m *Manager) GetStats() map[string]any {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
sourceCounts := make(map[string]int)
|
||||
var totalSessions int
|
||||
var oldestSession time.Time
|
||||
var newestSession time.Time
|
||||
|
||||
for _, session := range m.sessions {
|
||||
totalSessions++
|
||||
sourceCounts[session.Source]++
|
||||
|
||||
if oldestSession.IsZero() || session.CreatedAt.Before(oldestSession) {
|
||||
oldestSession = session.CreatedAt
|
||||
}
|
||||
if newestSession.IsZero() || session.CreatedAt.After(newestSession) {
|
||||
newestSession = session.CreatedAt
|
||||
}
|
||||
}
|
||||
|
||||
stats := map[string]any{
|
||||
"total_sessions": totalSessions,
|
||||
"sessions_by_type": sourceCounts,
|
||||
"max_idle_time": m.maxIdleTime.String(),
|
||||
}
|
||||
|
||||
if !oldestSession.IsZero() {
|
||||
stats["oldest_session_age"] = time.Since(oldestSession).String()
|
||||
}
|
||||
if !newestSession.IsZero() {
|
||||
stats["newest_session_age"] = time.Since(newestSession).String()
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Stop gracefully stops the session manager and its cleanup goroutine.
|
||||
func (m *Manager) Stop() {
|
||||
close(m.done)
|
||||
if m.cleanupTicker != nil {
|
||||
m.cleanupTicker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterExpiryCallback registers a callback function to be executed when a session expires.
|
||||
func (m *Manager) RegisterExpiryCallback(source string, callback func(sessionID, remoteAddr string)) {
|
||||
m.callbacksMu.Lock()
|
||||
defer m.callbacksMu.Unlock()
|
||||
|
||||
if m.expiryCallbacks == nil {
|
||||
m.expiryCallbacks = make(map[string]func(sessionID, remoteAddr string))
|
||||
}
|
||||
m.expiryCallbacks[source] = callback
|
||||
}
|
||||
|
||||
// UnregisterExpiryCallback removes an expiry callback for a given source type.
|
||||
func (m *Manager) UnregisterExpiryCallback(source string) {
|
||||
m.callbacksMu.Lock()
|
||||
defer m.callbacksMu.Unlock()
|
||||
|
||||
delete(m.expiryCallbacks, source)
|
||||
}
|
||||
|
||||
// startCleanup initializes the periodic cleanup of idle sessions.
|
||||
func (m *Manager) startCleanup() {
|
||||
m.cleanupTicker = time.NewTicker(5 * time.Minute)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-m.cleanupTicker.C:
|
||||
m.cleanupIdleSessions()
|
||||
case <-m.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanupIdleSessions removes sessions that have exceeded the maximum idle time.
|
||||
func (m *Manager) cleanupIdleSessions() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
expiredSessions := make([]*Session, 0)
|
||||
|
||||
for id, session := range m.sessions {
|
||||
idleTime := now.Sub(session.LastActivity)
|
||||
|
||||
if idleTime > m.maxIdleTime {
|
||||
expiredSessions = append(expiredSessions, session)
|
||||
delete(m.sessions, id)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Call callbacks outside of lock
|
||||
if len(expiredSessions) > 0 {
|
||||
m.callbacksMu.RLock()
|
||||
defer m.callbacksMu.RUnlock()
|
||||
|
||||
for _, session := range expiredSessions {
|
||||
if callback, exists := m.expiryCallbacks[session.Source]; exists {
|
||||
// Call callback to notify owner
|
||||
go callback(session.ID, session.RemoteAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generateSessionID creates a unique, random session identifier.
|
||||
func generateSessionID() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to timestamp-based ID
|
||||
return fmt.Sprintf("session_%d", time.Now().UnixNano())
|
||||
}
|
||||
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b)
|
||||
}
|
||||
@ -16,7 +16,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// ConsoleSink writes log entries to the console (stdout/stderr) using an dedicated logger instance
|
||||
// ConsoleSink writes log entries to the console (stdout/stderr) using an dedicated logger instance.
|
||||
type ConsoleSink struct {
|
||||
config *config.ConsoleSinkOptions
|
||||
input chan core.LogEntry
|
||||
@ -31,7 +31,7 @@ type ConsoleSink struct {
|
||||
lastProcessed atomic.Value // time.Time
|
||||
}
|
||||
|
||||
// Creates a new console sink
|
||||
// NewConsoleSink creates a new console sink.
|
||||
func NewConsoleSink(opts *config.ConsoleSinkOptions, appLogger *log.Logger, formatter format.Formatter) (*ConsoleSink, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("console sink options cannot be nil")
|
||||
@ -73,10 +73,12 @@ func NewConsoleSink(opts *config.ConsoleSinkOptions, appLogger *log.Logger, form
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Input returns the channel for sending log entries.
|
||||
func (s *ConsoleSink) Input() chan<- core.LogEntry {
|
||||
return s.input
|
||||
}
|
||||
|
||||
// Start begins the processing loop for the sink.
|
||||
func (s *ConsoleSink) Start(ctx context.Context) error {
|
||||
// Start the internal writer's processing goroutine.
|
||||
if err := s.writer.Start(); err != nil {
|
||||
@ -89,6 +91,7 @@ func (s *ConsoleSink) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the sink.
|
||||
func (s *ConsoleSink) Stop() {
|
||||
target := s.writer.GetConfig().ConsoleTarget
|
||||
s.logger.Info("msg", "Stopping console sink", "target", target)
|
||||
@ -103,6 +106,7 @@ func (s *ConsoleSink) Stop() {
|
||||
s.logger.Info("msg", "Console sink stopped", "target", target)
|
||||
}
|
||||
|
||||
// GetStats returns the sink's statistics.
|
||||
func (s *ConsoleSink) GetStats() SinkStats {
|
||||
lastProc, _ := s.lastProcessed.Load().(time.Time)
|
||||
|
||||
@ -117,7 +121,7 @@ func (s *ConsoleSink) GetStats() SinkStats {
|
||||
}
|
||||
}
|
||||
|
||||
// processLoop reads entries, formats them, and passes them to the internal writer.
|
||||
// processLoop reads entries, formats them, and writes to the console.
|
||||
func (s *ConsoleSink) processLoop(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
|
||||
@ -15,7 +15,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Writes log entries to files with rotation
|
||||
// FileSink writes log entries to files with rotation.
|
||||
type FileSink struct {
|
||||
config *config.FileSinkOptions
|
||||
input chan core.LogEntry
|
||||
@ -30,7 +30,7 @@ type FileSink struct {
|
||||
lastProcessed atomic.Value // time.Time
|
||||
}
|
||||
|
||||
// Creates a new file sink
|
||||
// NewFileSink creates a new file sink.
|
||||
func NewFileSink(opts *config.FileSinkOptions, logger *log.Logger, formatter format.Formatter) (*FileSink, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("file sink options cannot be nil")
|
||||
@ -63,10 +63,12 @@ func NewFileSink(opts *config.FileSinkOptions, logger *log.Logger, formatter for
|
||||
return fs, nil
|
||||
}
|
||||
|
||||
// Input returns the channel for sending log entries.
|
||||
func (fs *FileSink) Input() chan<- core.LogEntry {
|
||||
return fs.input
|
||||
}
|
||||
|
||||
// Start begins the processing loop for the sink.
|
||||
func (fs *FileSink) Start(ctx context.Context) error {
|
||||
// Start the internal file writer
|
||||
if err := fs.writer.Start(); err != nil {
|
||||
@ -78,6 +80,7 @@ func (fs *FileSink) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the sink.
|
||||
func (fs *FileSink) Stop() {
|
||||
fs.logger.Info("msg", "Stopping file sink")
|
||||
close(fs.done)
|
||||
@ -92,6 +95,7 @@ func (fs *FileSink) Stop() {
|
||||
fs.logger.Info("msg", "File sink stopped")
|
||||
}
|
||||
|
||||
// GetStats returns the sink's statistics.
|
||||
func (fs *FileSink) GetStats() SinkStats {
|
||||
lastProc, _ := fs.lastProcessed.Load().(time.Time)
|
||||
|
||||
@ -104,6 +108,7 @@ func (fs *FileSink) GetStats() SinkStats {
|
||||
}
|
||||
}
|
||||
|
||||
// processLoop reads entries, formats them, and writes to a file.
|
||||
func (fs *FileSink) processLoop(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
|
||||
@ -5,18 +5,19 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/tls"
|
||||
"logwisp/src/internal/session"
|
||||
ltls "logwisp/src/internal/tls"
|
||||
"logwisp/src/internal/version"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
@ -24,7 +25,7 @@ import (
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Streams log entries via Server-Sent Events
|
||||
// HTTPSink streams log entries via Server-Sent Events (SSE).
|
||||
type HTTPSink struct {
|
||||
// Configuration reference (NOT a copy)
|
||||
config *config.HTTPSinkOptions
|
||||
@ -46,10 +47,11 @@ type HTTPSink struct {
|
||||
unregister chan uint64
|
||||
nextClientID atomic.Uint64
|
||||
|
||||
// Security components
|
||||
authenticator *auth.Authenticator
|
||||
tlsManager *tls.Manager
|
||||
authConfig *config.ServerAuthConfig
|
||||
// Session and security
|
||||
sessionManager *session.Manager
|
||||
clientSessions map[uint64]string // clientID -> sessionID
|
||||
sessionsMu sync.RWMutex
|
||||
tlsManager *ltls.ServerManager
|
||||
|
||||
// Net limiting
|
||||
netLimiter *limit.NetLimiter
|
||||
@ -57,31 +59,32 @@ type HTTPSink struct {
|
||||
// Statistics
|
||||
totalProcessed atomic.Uint64
|
||||
lastProcessed atomic.Value // time.Time
|
||||
authFailures atomic.Uint64
|
||||
authSuccesses atomic.Uint64
|
||||
}
|
||||
|
||||
// Creates a new HTTP streaming sink
|
||||
// NewHTTPSink creates a new HTTP streaming sink.
|
||||
func NewHTTPSink(opts *config.HTTPSinkOptions, logger *log.Logger, formatter format.Formatter) (*HTTPSink, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("HTTP sink options cannot be nil")
|
||||
}
|
||||
|
||||
h := &HTTPSink{
|
||||
config: opts, // Direct reference to config struct
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
startTime: time.Now(),
|
||||
done: make(chan struct{}),
|
||||
logger: logger,
|
||||
formatter: formatter,
|
||||
clients: make(map[uint64]chan core.LogEntry),
|
||||
config: opts,
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
startTime: time.Now(),
|
||||
done: make(chan struct{}),
|
||||
logger: logger,
|
||||
formatter: formatter,
|
||||
clients: make(map[uint64]chan core.LogEntry),
|
||||
unregister: make(chan uint64),
|
||||
sessionManager: session.NewManager(30 * time.Minute),
|
||||
clientSessions: make(map[uint64]string),
|
||||
}
|
||||
|
||||
h.lastProcessed.Store(time.Time{})
|
||||
|
||||
// Initialize TLS manager if configured
|
||||
if opts.TLS != nil && opts.TLS.Enabled {
|
||||
tlsManager, err := tls.NewManager(opts.TLS, logger)
|
||||
tlsManager, err := ltls.NewServerManager(opts.TLS, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
|
||||
}
|
||||
@ -97,32 +100,21 @@ func NewHTTPSink(opts *config.HTTPSinkOptions, logger *log.Logger, formatter for
|
||||
h.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger)
|
||||
}
|
||||
|
||||
// Initialize authenticator if auth is not "none"
|
||||
if opts.Auth != nil && opts.Auth.Type != "none" {
|
||||
// Only "basic" and "token" are valid for HTTP sink
|
||||
if opts.Auth.Type != "basic" && opts.Auth.Type != "token" {
|
||||
return nil, fmt.Errorf("invalid auth type '%s' for HTTP sink (valid: none, basic, token)", opts.Auth.Type)
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(opts.Auth, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create authenticator: %w", err)
|
||||
}
|
||||
h.authenticator = authenticator
|
||||
h.authConfig = opts.Auth
|
||||
logger.Info("msg", "Authentication enabled",
|
||||
"component", "http_sink",
|
||||
"type", opts.Auth.Type)
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Input returns the channel for sending log entries.
|
||||
func (h *HTTPSink) Input() chan<- core.LogEntry {
|
||||
return h.input
|
||||
}
|
||||
|
||||
// Start initializes the HTTP server and begins the broker loop.
|
||||
func (h *HTTPSink) Start(ctx context.Context) error {
|
||||
// Register expiry callback
|
||||
h.sessionManager.RegisterExpiryCallback("http_sink", func(sessionID, remoteAddr string) {
|
||||
h.handleSessionExpiry(sessionID, remoteAddr)
|
||||
})
|
||||
|
||||
// Start central broker goroutine
|
||||
h.wg.Add(1)
|
||||
go h.brokerLoop(ctx)
|
||||
@ -144,6 +136,16 @@ func (h *HTTPSink) Start(ctx context.Context) error {
|
||||
// Configure TLS if enabled
|
||||
if h.tlsManager != nil {
|
||||
h.server.TLSConfig = h.tlsManager.GetHTTPConfig()
|
||||
|
||||
// Enforce mTLS configuration
|
||||
if h.config.TLS.ClientAuth {
|
||||
if h.config.TLS.VerifyClientCert {
|
||||
h.server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
} else {
|
||||
h.server.TLSConfig.ClientAuth = tls.RequireAnyClientCert
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("msg", "TLS enabled for HTTP sink",
|
||||
"component", "http_sink",
|
||||
"port", h.config.Port)
|
||||
@ -183,7 +185,7 @@ func (h *HTTPSink) Start(ctx context.Context) error {
|
||||
if h.server != nil {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
h.server.ShutdownWithContext(shutdownCtx)
|
||||
_ = h.server.ShutdownWithContext(shutdownCtx)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -197,7 +199,105 @@ func (h *HTTPSink) Start(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcasts only to active clients
|
||||
// Stop gracefully shuts down the HTTP server and all client connections.
|
||||
func (h *HTTPSink) Stop() {
|
||||
h.logger.Info("msg", "Stopping HTTP sink")
|
||||
|
||||
// Unregister callback
|
||||
h.sessionManager.UnregisterExpiryCallback("http_sink")
|
||||
|
||||
// Signal all client handlers to stop
|
||||
close(h.done)
|
||||
|
||||
// Shutdown HTTP server
|
||||
if h.server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = h.server.ShutdownWithContext(ctx)
|
||||
}
|
||||
|
||||
// Wait for all active client handlers to finish
|
||||
h.wg.Wait()
|
||||
|
||||
// Close unregister channel after all clients have finished
|
||||
close(h.unregister)
|
||||
|
||||
// Close all client channels
|
||||
h.clientsMu.Lock()
|
||||
for _, ch := range h.clients {
|
||||
close(ch)
|
||||
}
|
||||
h.clients = make(map[uint64]chan core.LogEntry)
|
||||
h.clientsMu.Unlock()
|
||||
|
||||
// Stop session manager
|
||||
if h.sessionManager != nil {
|
||||
h.sessionManager.Stop()
|
||||
}
|
||||
|
||||
h.logger.Info("msg", "HTTP sink stopped")
|
||||
}
|
||||
|
||||
// GetStats returns the sink's statistics.
|
||||
func (h *HTTPSink) GetStats() SinkStats {
|
||||
lastProc, _ := h.lastProcessed.Load().(time.Time)
|
||||
|
||||
var netLimitStats map[string]any
|
||||
if h.netLimiter != nil {
|
||||
netLimitStats = h.netLimiter.GetStats()
|
||||
}
|
||||
|
||||
var sessionStats map[string]any
|
||||
if h.sessionManager != nil {
|
||||
sessionStats = h.sessionManager.GetStats()
|
||||
}
|
||||
|
||||
var tlsStats map[string]any
|
||||
if h.tlsManager != nil {
|
||||
tlsStats = h.tlsManager.GetStats()
|
||||
}
|
||||
|
||||
return SinkStats{
|
||||
Type: "http",
|
||||
TotalProcessed: h.totalProcessed.Load(),
|
||||
ActiveConnections: h.activeClients.Load(),
|
||||
StartTime: h.startTime,
|
||||
LastProcessed: lastProc,
|
||||
Details: map[string]any{
|
||||
"port": h.config.Port,
|
||||
"buffer_size": h.config.BufferSize,
|
||||
"endpoints": map[string]string{
|
||||
"stream": h.config.StreamPath,
|
||||
"status": h.config.StatusPath,
|
||||
},
|
||||
"net_limit": netLimitStats,
|
||||
"sessions": sessionStats,
|
||||
"tls": tlsStats,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetActiveConnections returns the current number of active clients.
|
||||
func (h *HTTPSink) GetActiveConnections() int64 {
|
||||
return h.activeClients.Load()
|
||||
}
|
||||
|
||||
// GetStreamPath returns the configured transport endpoint path.
|
||||
func (h *HTTPSink) GetStreamPath() string {
|
||||
return h.config.StreamPath
|
||||
}
|
||||
|
||||
// GetStatusPath returns the configured status endpoint path.
|
||||
func (h *HTTPSink) GetStatusPath() string {
|
||||
return h.config.StatusPath
|
||||
}
|
||||
|
||||
// GetHost returns the configured host.
|
||||
func (h *HTTPSink) GetHost() string {
|
||||
return h.config.Host
|
||||
}
|
||||
|
||||
// brokerLoop manages client connections and broadcasts log entries.
|
||||
func (h *HTTPSink) brokerLoop(ctx context.Context) {
|
||||
defer h.wg.Done()
|
||||
|
||||
@ -233,6 +333,11 @@ func (h *HTTPSink) brokerLoop(ctx context.Context) {
|
||||
}
|
||||
h.clientsMu.Unlock()
|
||||
|
||||
// Clean up session tracking
|
||||
h.sessionsMu.Lock()
|
||||
delete(h.clientSessions, clientID)
|
||||
h.sessionsMu.Unlock()
|
||||
|
||||
case entry, ok := <-h.input:
|
||||
if !ok {
|
||||
h.logger.Debug("msg", "Input channel closed, broker stopping",
|
||||
@ -248,23 +353,50 @@ func (h *HTTPSink) brokerLoop(ctx context.Context) {
|
||||
clientCount := len(h.clients)
|
||||
if clientCount > 0 {
|
||||
slowClients := 0
|
||||
var staleClients []uint64
|
||||
|
||||
for id, ch := range h.clients {
|
||||
select {
|
||||
case ch <- entry:
|
||||
// Successfully sent
|
||||
default:
|
||||
// Client buffer full
|
||||
slowClients++
|
||||
if slowClients == 1 { // Log only once per broadcast
|
||||
h.logger.Debug("msg", "Dropped entry for slow client(s)",
|
||||
"component", "http_sink",
|
||||
"client_id", id,
|
||||
"slow_clients", slowClients,
|
||||
"total_clients", clientCount)
|
||||
h.sessionsMu.RLock()
|
||||
sessionID, hasSession := h.clientSessions[id]
|
||||
h.sessionsMu.RUnlock()
|
||||
|
||||
if hasSession {
|
||||
if !h.sessionManager.IsSessionActive(sessionID) {
|
||||
staleClients = append(staleClients, id)
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case ch <- entry:
|
||||
h.sessionManager.UpdateActivity(sessionID)
|
||||
default:
|
||||
slowClients++
|
||||
if slowClients == 1 {
|
||||
h.logger.Debug("msg", "Dropped entry for slow client(s)",
|
||||
"component", "http_sink",
|
||||
"client_id", id,
|
||||
"slow_clients", slowClients,
|
||||
"total_clients", clientCount)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
delete(h.clients, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up stale clients after broadcast
|
||||
if len(staleClients) > 0 {
|
||||
go func() {
|
||||
for _, clientID := range staleClients {
|
||||
select {
|
||||
case h.unregister <- clientID:
|
||||
case <-h.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// If no clients connected, entry is discarded (no buffering)
|
||||
h.clientsMu.RUnlock()
|
||||
|
||||
@ -275,91 +407,29 @@ func (h *HTTPSink) brokerLoop(ctx context.Context) {
|
||||
|
||||
h.clientsMu.RLock()
|
||||
for id, ch := range h.clients {
|
||||
select {
|
||||
case ch <- heartbeatEntry:
|
||||
default:
|
||||
// Client buffer full, skip heartbeat
|
||||
h.logger.Debug("msg", "Skipped heartbeat for slow client",
|
||||
"component", "http_sink",
|
||||
"client_id", id)
|
||||
h.sessionsMu.RLock()
|
||||
sessionID, hasSession := h.clientSessions[id]
|
||||
h.sessionsMu.RUnlock()
|
||||
|
||||
if hasSession {
|
||||
select {
|
||||
case ch <- heartbeatEntry:
|
||||
// Update session activity on heartbeat
|
||||
h.sessionManager.UpdateActivity(sessionID)
|
||||
default:
|
||||
// Client buffer full, skip heartbeat
|
||||
h.logger.Debug("msg", "Skipped heartbeat for slow client",
|
||||
"component", "http_sink",
|
||||
"client_id", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
h.clientsMu.RUnlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HTTPSink) Stop() {
|
||||
h.logger.Info("msg", "Stopping HTTP sink")
|
||||
|
||||
// Signal all client handlers to stop
|
||||
close(h.done)
|
||||
|
||||
// Shutdown HTTP server
|
||||
if h.server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
h.server.ShutdownWithContext(ctx)
|
||||
}
|
||||
|
||||
// Wait for all active client handlers to finish
|
||||
h.wg.Wait()
|
||||
|
||||
// Close unregister channel after all clients have finished
|
||||
close(h.unregister)
|
||||
|
||||
// Close all client channels
|
||||
h.clientsMu.Lock()
|
||||
for _, ch := range h.clients {
|
||||
close(ch)
|
||||
}
|
||||
h.clients = make(map[uint64]chan core.LogEntry)
|
||||
h.clientsMu.Unlock()
|
||||
|
||||
h.logger.Info("msg", "HTTP sink stopped")
|
||||
}
|
||||
|
||||
func (h *HTTPSink) GetStats() SinkStats {
|
||||
lastProc, _ := h.lastProcessed.Load().(time.Time)
|
||||
|
||||
var netLimitStats map[string]any
|
||||
if h.netLimiter != nil {
|
||||
netLimitStats = h.netLimiter.GetStats()
|
||||
}
|
||||
|
||||
var authStats map[string]any
|
||||
if h.authenticator != nil {
|
||||
authStats = h.authenticator.GetStats()
|
||||
authStats["failures"] = h.authFailures.Load()
|
||||
authStats["successes"] = h.authSuccesses.Load()
|
||||
}
|
||||
|
||||
var tlsStats map[string]any
|
||||
if h.tlsManager != nil {
|
||||
tlsStats = h.tlsManager.GetStats()
|
||||
}
|
||||
|
||||
return SinkStats{
|
||||
Type: "http",
|
||||
TotalProcessed: h.totalProcessed.Load(),
|
||||
ActiveConnections: h.activeClients.Load(),
|
||||
StartTime: h.startTime,
|
||||
LastProcessed: lastProc,
|
||||
Details: map[string]any{
|
||||
"port": h.config.Port,
|
||||
"buffer_size": h.config.BufferSize,
|
||||
"endpoints": map[string]string{
|
||||
"stream": h.config.StreamPath,
|
||||
"status": h.config.StatusPath,
|
||||
},
|
||||
"net_limit": netLimitStats,
|
||||
"auth": authStats,
|
||||
"tls": tlsStats,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// requestHandler is the main entry point for all incoming HTTP requests.
|
||||
func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
remoteAddr := ctx.RemoteAddr().String()
|
||||
|
||||
@ -380,21 +450,6 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce TLS for authentication
|
||||
if h.authenticator != nil && h.authConfig.Type != "none" {
|
||||
isTLS := ctx.IsTLS() || h.tlsManager != nil
|
||||
|
||||
if !isTLS {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "TLS required for authentication",
|
||||
"hint": "Use HTTPS for authenticated connections",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
path := string(ctx.Path())
|
||||
|
||||
// Status endpoint doesn't require auth
|
||||
@ -403,52 +458,14 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate request
|
||||
var session *auth.Session
|
||||
if h.authenticator != nil {
|
||||
authHeader := string(ctx.Request.Header.Peek("Authorization"))
|
||||
var err error
|
||||
session, err = h.authenticator.AuthenticateHTTP(authHeader, remoteAddr)
|
||||
if err != nil {
|
||||
h.authFailures.Add(1)
|
||||
h.logger.Warn("msg", "Authentication failed",
|
||||
"component", "http_sink",
|
||||
"remote_addr", remoteAddr,
|
||||
"error", err)
|
||||
|
||||
// Return 401 with WWW-Authenticate header
|
||||
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
|
||||
if h.authConfig.Type == "basic" && h.authConfig.Basic != nil {
|
||||
realm := h.authConfig.Basic.Realm
|
||||
if realm == "" {
|
||||
realm = "Restricted"
|
||||
}
|
||||
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm))
|
||||
} else if h.authConfig.Type == "token" {
|
||||
ctx.Response.Header.Set("WWW-Authenticate", "Token")
|
||||
}
|
||||
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
h.authSuccesses.Add(1)
|
||||
} else {
|
||||
// Create anonymous session for unauthenticated connections
|
||||
session = &auth.Session{
|
||||
ID: fmt.Sprintf("anon-%d", time.Now().UnixNano()),
|
||||
Username: "anonymous",
|
||||
Method: "none",
|
||||
RemoteAddr: remoteAddr,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
// Create anonymous session for all connections
|
||||
sess := h.sessionManager.CreateSession(remoteAddr, "http_sink", map[string]any{
|
||||
"tls": ctx.IsTLS() || h.tlsManager != nil,
|
||||
})
|
||||
|
||||
switch path {
|
||||
case h.config.StreamPath:
|
||||
h.handleStream(ctx, session)
|
||||
h.handleStream(ctx, sess)
|
||||
default:
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetContentType("application/json")
|
||||
@ -456,18 +473,11 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
"error": "Not Found",
|
||||
})
|
||||
}
|
||||
// Handle stream endpoint
|
||||
// if path == h.config.StreamPath {
|
||||
// h.handleStream(ctx, session)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// // Unknown path
|
||||
// ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
// ctx.SetBody([]byte("Not Found"))
|
||||
|
||||
}
|
||||
|
||||
func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) {
|
||||
// handleStream manages a client's Server-Sent Events (SSE) stream.
|
||||
func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, sess *session.Session) {
|
||||
// Track connection for net limiting
|
||||
remoteAddr := ctx.RemoteAddr().String()
|
||||
if h.netLimiter != nil {
|
||||
@ -490,14 +500,18 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session)
|
||||
h.clients[clientID] = clientChan
|
||||
h.clientsMu.Unlock()
|
||||
|
||||
// Register session mapping
|
||||
h.sessionsMu.Lock()
|
||||
h.clientSessions[clientID] = sess.ID
|
||||
h.sessionsMu.Unlock()
|
||||
|
||||
// Define the stream writer function
|
||||
streamFunc := func(w *bufio.Writer) {
|
||||
connectCount := h.activeClients.Add(1)
|
||||
h.logger.Debug("msg", "HTTP client connected",
|
||||
"component", "http_sink",
|
||||
"remote_addr", remoteAddr,
|
||||
"username", session.Username,
|
||||
"auth_method", session.Method,
|
||||
"session_id", sess.ID,
|
||||
"client_id", clientID,
|
||||
"active_clients", connectCount)
|
||||
|
||||
@ -510,7 +524,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session)
|
||||
h.logger.Debug("msg", "HTTP client disconnected",
|
||||
"component", "http_sink",
|
||||
"remote_addr", remoteAddr,
|
||||
"username", session.Username,
|
||||
"session_id", sess.ID,
|
||||
"client_id", clientID,
|
||||
"active_clients", disconnectCount)
|
||||
|
||||
@ -521,14 +535,16 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session)
|
||||
// Shutting down, don't block
|
||||
}
|
||||
|
||||
// Remove session
|
||||
h.sessionManager.RemoveSession(sess.ID)
|
||||
|
||||
h.wg.Done()
|
||||
}()
|
||||
|
||||
// Send initial connected event with metadata
|
||||
connectionInfo := map[string]any{
|
||||
"client_id": fmt.Sprintf("%d", clientID),
|
||||
"username": session.Username,
|
||||
"auth_method": session.Method,
|
||||
"session_id": sess.ID,
|
||||
"stream_path": h.config.StreamPath,
|
||||
"status_path": h.config.StatusPath,
|
||||
"buffer_size": h.config.BufferSize,
|
||||
@ -573,20 +589,15 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session)
|
||||
return
|
||||
}
|
||||
|
||||
case <-tickerChan:
|
||||
// Validate session is still active
|
||||
if h.authenticator != nil && session != nil && !h.authenticator.ValidateSession(session.ID) {
|
||||
fmt.Fprintf(w, "event: disconnect\ndata: {\"reason\":\"session_expired\"}\n\n")
|
||||
w.Flush()
|
||||
return
|
||||
}
|
||||
// Update session activity
|
||||
h.sessionManager.UpdateActivity(sess.ID)
|
||||
|
||||
// Heartbeat is sent from broker, additional client-specific heartbeat is sent here
|
||||
// This provides per-client heartbeat validation with session check
|
||||
case <-tickerChan:
|
||||
// Client-specific heartbeat
|
||||
sessionHB := map[string]any{
|
||||
"type": "session_heartbeat",
|
||||
"client_id": fmt.Sprintf("%d", clientID),
|
||||
"session_valid": true,
|
||||
"type": "heartbeat",
|
||||
"client_id": fmt.Sprintf("%d", clientID),
|
||||
"session_id": sess.ID,
|
||||
}
|
||||
hbData, _ := json.Marshal(sessionHB)
|
||||
fmt.Fprintf(w, "event: heartbeat\ndata: %s\n\n", hbData)
|
||||
@ -607,49 +618,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session)
|
||||
ctx.SetBodyStreamWriter(streamFunc)
|
||||
}
|
||||
|
||||
func (h *HTTPSink) formatEntryForSSE(w *bufio.Writer, entry core.LogEntry) error {
|
||||
formatted, err := h.formatter.Format(entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove trailing newline if present (SSE adds its own)
|
||||
formatted = bytes.TrimSuffix(formatted, []byte{'\n'})
|
||||
|
||||
// Multi-line content handler
|
||||
lines := bytes.Split(formatted, []byte{'\n'})
|
||||
for _, line := range lines {
|
||||
// SSE needs "data: " prefix for each line based on W3C spec
|
||||
fmt.Fprintf(w, "data: %s\n", line)
|
||||
}
|
||||
fmt.Fprintf(w, "\n") // Empty line to terminate event
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HTTPSink) createHeartbeatEntry() core.LogEntry {
|
||||
message := "heartbeat"
|
||||
|
||||
// Build fields for heartbeat metadata
|
||||
fields := make(map[string]any)
|
||||
fields["type"] = "heartbeat"
|
||||
|
||||
if h.config.Heartbeat.Enabled {
|
||||
fields["active_clients"] = h.activeClients.Load()
|
||||
fields["uptime_seconds"] = int(time.Since(h.startTime).Seconds())
|
||||
}
|
||||
|
||||
fieldsJSON, _ := json.Marshal(fields)
|
||||
|
||||
return core.LogEntry{
|
||||
Time: time.Now(),
|
||||
Source: "logwisp-http",
|
||||
Level: "INFO",
|
||||
Message: message,
|
||||
Fields: fieldsJSON,
|
||||
}
|
||||
}
|
||||
|
||||
// handleStatus provides a JSON status report of the sink.
|
||||
func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
@ -662,17 +631,6 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
var authStats any
|
||||
if h.authenticator != nil {
|
||||
authStats = h.authenticator.GetStats()
|
||||
authStats.(map[string]any)["failures"] = h.authFailures.Load()
|
||||
authStats.(map[string]any)["successes"] = h.authSuccesses.Load()
|
||||
} else {
|
||||
authStats = map[string]any{
|
||||
"enabled": false,
|
||||
}
|
||||
}
|
||||
|
||||
var tlsStats any
|
||||
if h.tlsManager != nil {
|
||||
tlsStats = h.tlsManager.GetStats()
|
||||
@ -682,6 +640,11 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
var sessionStats any
|
||||
if h.sessionManager != nil {
|
||||
sessionStats = h.sessionManager.GetStats()
|
||||
}
|
||||
|
||||
status := map[string]any{
|
||||
"service": "LogWisp",
|
||||
"version": version.Short(),
|
||||
@ -703,13 +666,11 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
|
||||
"format": h.config.Heartbeat.Format,
|
||||
},
|
||||
"tls": tlsStats,
|
||||
"auth": authStats,
|
||||
"sessions": sessionStats,
|
||||
"net_limit": netLimitStats,
|
||||
},
|
||||
"statistics": map[string]any{
|
||||
"total_processed": h.totalProcessed.Load(),
|
||||
"auth_failures": h.authFailures.Load(),
|
||||
"auth_successes": h.authSuccesses.Load(),
|
||||
},
|
||||
}
|
||||
|
||||
@ -717,22 +678,71 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetBody(data)
|
||||
}
|
||||
|
||||
// Returns the current number of active clients
|
||||
func (h *HTTPSink) GetActiveConnections() int64 {
|
||||
return h.activeClients.Load()
|
||||
// handleSessionExpiry is the callback for cleaning up expired sessions.
|
||||
func (h *HTTPSink) handleSessionExpiry(sessionID, remoteAddr string) {
|
||||
h.sessionsMu.RLock()
|
||||
defer h.sessionsMu.RUnlock()
|
||||
|
||||
// Find client by session ID
|
||||
for clientID, sessID := range h.clientSessions {
|
||||
if sessID == sessionID {
|
||||
h.logger.Info("msg", "Closing expired session client",
|
||||
"component", "http_sink",
|
||||
"session_id", sessionID,
|
||||
"client_id", clientID,
|
||||
"remote_addr", remoteAddr)
|
||||
|
||||
// Signal broker to unregister
|
||||
select {
|
||||
case h.unregister <- clientID:
|
||||
case <-h.done:
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the configured transport endpoint path
|
||||
func (h *HTTPSink) GetStreamPath() string {
|
||||
return h.config.StreamPath
|
||||
// createHeartbeatEntry generates a new heartbeat log entry.
|
||||
func (h *HTTPSink) createHeartbeatEntry() core.LogEntry {
|
||||
message := "heartbeat"
|
||||
|
||||
// Build fields for heartbeat metadata
|
||||
fields := make(map[string]any)
|
||||
fields["type"] = "heartbeat"
|
||||
|
||||
if h.config.Heartbeat.Enabled {
|
||||
fields["active_clients"] = h.activeClients.Load()
|
||||
fields["uptime_seconds"] = int(time.Since(h.startTime).Seconds())
|
||||
}
|
||||
|
||||
fieldsJSON, _ := json.Marshal(fields)
|
||||
|
||||
return core.LogEntry{
|
||||
Time: time.Now(),
|
||||
Source: "logwisp-http",
|
||||
Level: "INFO",
|
||||
Message: message,
|
||||
Fields: fieldsJSON,
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the configured status endpoint path
|
||||
func (h *HTTPSink) GetStatusPath() string {
|
||||
return h.config.StatusPath
|
||||
}
|
||||
// formatEntryForSSE formats a log entry into the SSE 'data:' format.
|
||||
func (h *HTTPSink) formatEntryForSSE(w *bufio.Writer, entry core.LogEntry) error {
|
||||
formatted, err := h.formatter.Format(entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Returns the configured host
|
||||
func (h *HTTPSink) GetHost() string {
|
||||
return h.config.Host
|
||||
// Remove trailing newline if present (SSE adds its own)
|
||||
formatted = bytes.TrimSuffix(formatted, []byte{'\n'})
|
||||
|
||||
// Multi-line content handler
|
||||
lines := bytes.Split(formatted, []byte{'\n'})
|
||||
for _, line := range lines {
|
||||
// SSE needs "data: " prefix for each line based on W3C spec
|
||||
fmt.Fprintf(w, "data: %s\n", line)
|
||||
}
|
||||
fmt.Fprintf(w, "\n") // Empty line to terminate event
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -5,39 +5,39 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
"logwisp/src/internal/session"
|
||||
ltls "logwisp/src/internal/tls"
|
||||
"logwisp/src/internal/version"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TODO: implement heartbeat for HTTP Client Sink, similar to HTTP Sink
|
||||
// Forwards log entries to a remote HTTP endpoint
|
||||
// TODO: add heartbeat
|
||||
// HTTPClientSink forwards log entries to a remote HTTP endpoint.
|
||||
type HTTPClientSink struct {
|
||||
input chan core.LogEntry
|
||||
config *config.HTTPClientSinkOptions
|
||||
client *fasthttp.Client
|
||||
batch []core.LogEntry
|
||||
batchMu sync.Mutex
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
startTime time.Time
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
authenticator *auth.Authenticator
|
||||
input chan core.LogEntry
|
||||
config *config.HTTPClientSinkOptions
|
||||
client *fasthttp.Client
|
||||
batch []core.LogEntry
|
||||
batchMu sync.Mutex
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
startTime time.Time
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
sessionID string
|
||||
sessionManager *session.Manager
|
||||
tlsManager *ltls.ClientManager
|
||||
|
||||
// Statistics
|
||||
totalProcessed atomic.Uint64
|
||||
@ -48,20 +48,21 @@ type HTTPClientSink struct {
|
||||
activeConnections atomic.Int64
|
||||
}
|
||||
|
||||
// Creates a new HTTP client sink
|
||||
// NewHTTPClientSink creates a new HTTP client sink.
|
||||
func NewHTTPClientSink(opts *config.HTTPClientSinkOptions, logger *log.Logger, formatter format.Formatter) (*HTTPClientSink, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("HTTP client sink options cannot be nil")
|
||||
}
|
||||
|
||||
h := &HTTPClientSink{
|
||||
config: opts,
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
batch: make([]core.LogEntry, 0, opts.BatchSize),
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
formatter: formatter,
|
||||
config: opts,
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
batch: make([]core.LogEntry, 0, opts.BatchSize),
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
formatter: formatter,
|
||||
sessionManager: session.NewManager(30 * time.Minute),
|
||||
}
|
||||
h.lastProcessed.Store(time.Time{})
|
||||
h.lastBatchSent.Store(time.Time{})
|
||||
@ -75,54 +76,48 @@ func NewHTTPClientSink(opts *config.HTTPClientSinkOptions, logger *log.Logger, f
|
||||
DisableHeaderNamesNormalizing: true,
|
||||
}
|
||||
|
||||
// Configure TLS if using HTTPS
|
||||
// Configure TLS for HTTPS
|
||||
if strings.HasPrefix(opts.URL, "https://") {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: opts.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
// Use TLS config if provided
|
||||
if opts.TLS != nil {
|
||||
// Load custom CA for server verification
|
||||
if opts.TLS.CAFile != "" {
|
||||
caCert, err := os.ReadFile(opts.TLS.CAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA file '%s': %w", opts.TLS.CAFile, err)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate from '%s'", opts.TLS.CAFile)
|
||||
}
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
logger.Debug("msg", "Custom CA loaded for server verification",
|
||||
"component", "http_client_sink",
|
||||
"ca_file", opts.TLS.CAFile)
|
||||
if opts.TLS != nil && opts.TLS.Enabled {
|
||||
// Use the new ClientManager with the clear client-specific config
|
||||
tlsManager, err := ltls.NewClientManager(opts.TLS, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TLS client manager: %w", err)
|
||||
}
|
||||
h.tlsManager = tlsManager
|
||||
// Get the generated config
|
||||
h.client.TLSConfig = tlsManager.GetConfig()
|
||||
|
||||
// Load client certificate for mTLS if provided
|
||||
if opts.TLS.CertFile != "" && opts.TLS.KeyFile != "" {
|
||||
cert, err := tls.LoadX509KeyPair(opts.TLS.CertFile, opts.TLS.KeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load client certificate: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
logger.Info("msg", "Client certificate loaded for mTLS",
|
||||
"component", "http_client_sink",
|
||||
"cert_file", opts.TLS.CertFile)
|
||||
logger.Info("msg", "Client TLS configured",
|
||||
"component", "http_client_sink",
|
||||
"has_client_cert", opts.TLS.ClientCertFile != "", // Clearer check
|
||||
"has_server_ca", opts.TLS.ServerCAFile != "", // Clearer check
|
||||
"min_version", opts.TLS.MinVersion)
|
||||
} else if opts.InsecureSkipVerify { // Use the new clear field
|
||||
// TODO: document this behavior
|
||||
h.client.TLSConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
}
|
||||
|
||||
h.client.TLSConfig = tlsConfig
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Input returns the channel for sending log entries.
|
||||
func (h *HTTPClientSink) Input() chan<- core.LogEntry {
|
||||
return h.input
|
||||
}
|
||||
|
||||
// Start begins the processing and batching loops.
|
||||
func (h *HTTPClientSink) Start(ctx context.Context) error {
|
||||
// Create session for HTTP client sink lifetime
|
||||
sess := h.sessionManager.CreateSession(h.config.URL, "http_client_sink", map[string]any{
|
||||
"batch_size": h.config.BatchSize,
|
||||
"timeout": h.config.Timeout,
|
||||
})
|
||||
h.sessionID = sess.ID
|
||||
|
||||
h.wg.Add(2)
|
||||
go h.processLoop(ctx)
|
||||
go h.batchTimer(ctx)
|
||||
@ -131,10 +126,12 @@ func (h *HTTPClientSink) Start(ctx context.Context) error {
|
||||
"component", "http_client_sink",
|
||||
"url", h.config.URL,
|
||||
"batch_size", h.config.BatchSize,
|
||||
"batch_delay_ms", h.config.BatchDelayMS)
|
||||
"batch_delay_ms", h.config.BatchDelayMS,
|
||||
"session_id", h.sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the sink, sending any remaining batched entries.
|
||||
func (h *HTTPClientSink) Stop() {
|
||||
h.logger.Info("msg", "Stopping HTTP client sink")
|
||||
close(h.done)
|
||||
@ -151,12 +148,21 @@ func (h *HTTPClientSink) Stop() {
|
||||
h.batchMu.Unlock()
|
||||
}
|
||||
|
||||
// Remove session and stop manager
|
||||
if h.sessionID != "" {
|
||||
h.sessionManager.RemoveSession(h.sessionID)
|
||||
}
|
||||
if h.sessionManager != nil {
|
||||
h.sessionManager.Stop()
|
||||
}
|
||||
|
||||
h.logger.Info("msg", "HTTP client sink stopped",
|
||||
"total_processed", h.totalProcessed.Load(),
|
||||
"total_batches", h.totalBatches.Load(),
|
||||
"failed_batches", h.failedBatches.Load())
|
||||
}
|
||||
|
||||
// GetStats returns the sink's statistics.
|
||||
func (h *HTTPClientSink) GetStats() SinkStats {
|
||||
lastProc, _ := h.lastProcessed.Load().(time.Time)
|
||||
lastBatch, _ := h.lastBatchSent.Load().(time.Time)
|
||||
@ -165,6 +171,23 @@ func (h *HTTPClientSink) GetStats() SinkStats {
|
||||
pendingEntries := len(h.batch)
|
||||
h.batchMu.Unlock()
|
||||
|
||||
// Get session information
|
||||
var sessionInfo map[string]any
|
||||
if h.sessionID != "" {
|
||||
if sess, exists := h.sessionManager.GetSession(h.sessionID); exists {
|
||||
sessionInfo = map[string]any{
|
||||
"session_id": sess.ID,
|
||||
"created_at": sess.CreatedAt,
|
||||
"last_activity": sess.LastActivity,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var tlsStats map[string]any
|
||||
if h.tlsManager != nil {
|
||||
tlsStats = h.tlsManager.GetStats()
|
||||
}
|
||||
|
||||
return SinkStats{
|
||||
Type: "http_client",
|
||||
TotalProcessed: h.totalProcessed.Load(),
|
||||
@ -178,10 +201,13 @@ func (h *HTTPClientSink) GetStats() SinkStats {
|
||||
"total_batches": h.totalBatches.Load(),
|
||||
"failed_batches": h.failedBatches.Load(),
|
||||
"last_batch_sent": lastBatch,
|
||||
"session": sessionInfo,
|
||||
"tls": tlsStats,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// processLoop collects incoming log entries into a batch.
|
||||
func (h *HTTPClientSink) processLoop(ctx context.Context) {
|
||||
defer h.wg.Done()
|
||||
|
||||
@ -219,6 +245,7 @@ func (h *HTTPClientSink) processLoop(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// batchTimer periodically triggers sending of the current batch.
|
||||
func (h *HTTPClientSink) batchTimer(ctx context.Context) {
|
||||
defer h.wg.Done()
|
||||
|
||||
@ -248,6 +275,7 @@ func (h *HTTPClientSink) batchTimer(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// sendBatch sends a batch of log entries to the remote endpoint with retry logic.
|
||||
func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
h.activeConnections.Add(1)
|
||||
defer h.activeConnections.Add(-1)
|
||||
@ -293,7 +321,6 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
var lastErr error
|
||||
retryDelay := time.Duration(h.config.RetryDelayMS) * time.Millisecond
|
||||
|
||||
// TODO: verify retry loop placement is correct or should it be after acquiring resources (req :=....)
|
||||
for attempt := int64(0); attempt <= h.config.MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Wait before retry
|
||||
@ -323,24 +350,6 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("LogWisp/%s", version.Short()))
|
||||
|
||||
// Add authentication based on auth type
|
||||
switch h.config.Auth.Type {
|
||||
case "basic":
|
||||
creds := h.config.Auth.Username + ":" + h.config.Auth.Password
|
||||
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
|
||||
req.Header.Set("Authorization", "Basic "+encodedCreds)
|
||||
|
||||
case "token":
|
||||
req.Header.Set("Authorization", "Token "+h.config.Auth.Token)
|
||||
|
||||
case "mtls":
|
||||
// mTLS auth is handled at TLS layer via client certificates
|
||||
// No Authorization header needed
|
||||
|
||||
case "none":
|
||||
// No authentication
|
||||
}
|
||||
|
||||
// Send request
|
||||
err := h.client.DoTimeout(req, resp, time.Duration(h.config.Timeout)*time.Second)
|
||||
|
||||
@ -370,6 +379,12 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
// Check response status
|
||||
if statusCode >= 200 && statusCode < 300 {
|
||||
// Success
|
||||
|
||||
// Update session activity on successful batch send
|
||||
if h.sessionID != "" {
|
||||
h.sessionManager.UpdateActivity(h.sessionID)
|
||||
}
|
||||
|
||||
h.logger.Debug("msg", "Batch sent successfully",
|
||||
"component", "http_client_sink",
|
||||
"batch_size", len(batch),
|
||||
|
||||
@ -8,22 +8,22 @@ import (
|
||||
"logwisp/src/internal/core"
|
||||
)
|
||||
|
||||
// Represents an output data stream
|
||||
// Sink represents an output data stream.
|
||||
type Sink interface {
|
||||
// Returns the channel for sending log entries to this sink
|
||||
// Input returns the channel for sending log entries to this sink.
|
||||
Input() chan<- core.LogEntry
|
||||
|
||||
// Begins processing log entries
|
||||
// Start begins processing log entries.
|
||||
Start(ctx context.Context) error
|
||||
|
||||
// Gracefully shuts down the sink
|
||||
// Stop gracefully shuts down the sink.
|
||||
Stop()
|
||||
|
||||
// Returns sink statistics
|
||||
// GetStats returns sink statistics.
|
||||
GetStats() SinkStats
|
||||
}
|
||||
|
||||
// Contains statistics about a sink
|
||||
// SinkStats contains statistics about a sink.
|
||||
type SinkStats struct {
|
||||
Type string
|
||||
TotalProcessed uint64
|
||||
|
||||
@ -11,31 +11,32 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/session"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/lixenwraith/log/compat"
|
||||
"github.com/panjf2000/gnet/v2"
|
||||
)
|
||||
|
||||
// Streams log entries via TCP
|
||||
// TCPSink streams log entries to connected TCP clients.
|
||||
type TCPSink struct {
|
||||
input chan core.LogEntry
|
||||
config *config.TCPSinkOptions
|
||||
server *tcpServer
|
||||
done chan struct{}
|
||||
activeConns atomic.Int64
|
||||
startTime time.Time
|
||||
engine *gnet.Engine
|
||||
engineMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
netLimiter *limit.NetLimiter
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
input chan core.LogEntry
|
||||
config *config.TCPSinkOptions
|
||||
server *tcpServer
|
||||
done chan struct{}
|
||||
activeConns atomic.Int64
|
||||
startTime time.Time
|
||||
engine *gnet.Engine
|
||||
engineMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
netLimiter *limit.NetLimiter
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
sessionManager *session.Manager
|
||||
|
||||
// Statistics
|
||||
totalProcessed atomic.Uint64
|
||||
@ -47,7 +48,7 @@ type TCPSink struct {
|
||||
errorMu sync.Mutex
|
||||
}
|
||||
|
||||
// Holds TCP sink configuration
|
||||
// TCPConfig holds configuration for the TCPSink.
|
||||
type TCPConfig struct {
|
||||
Host string
|
||||
Port int64
|
||||
@ -56,19 +57,21 @@ type TCPConfig struct {
|
||||
NetLimit *config.NetLimitConfig
|
||||
}
|
||||
|
||||
// Creates a new TCP streaming sink
|
||||
// NewTCPSink creates a new TCP streaming sink.
|
||||
func NewTCPSink(opts *config.TCPSinkOptions, logger *log.Logger, formatter format.Formatter) (*TCPSink, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("TCP sink options cannot be nil")
|
||||
}
|
||||
|
||||
t := &TCPSink{
|
||||
config: opts, // Direct reference to config
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
formatter: formatter,
|
||||
config: opts,
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
formatter: formatter,
|
||||
consecutiveWriteErrors: make(map[gnet.Conn]int),
|
||||
sessionManager: session.NewManager(30 * time.Minute),
|
||||
}
|
||||
t.lastProcessed.Store(time.Time{})
|
||||
|
||||
@ -82,16 +85,23 @@ func NewTCPSink(opts *config.TCPSinkOptions, logger *log.Logger, formatter forma
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Input returns the channel for sending log entries.
|
||||
func (t *TCPSink) Input() chan<- core.LogEntry {
|
||||
return t.input
|
||||
}
|
||||
|
||||
// Start initializes the TCP server and begins the broadcast loop.
|
||||
func (t *TCPSink) Start(ctx context.Context) error {
|
||||
t.server = &tcpServer{
|
||||
sink: t,
|
||||
clients: make(map[gnet.Conn]*tcpClient),
|
||||
}
|
||||
|
||||
// Register expiry callback
|
||||
t.sessionManager.RegisterExpiryCallback("tcp_sink", func(sessionID, remoteAddr string) {
|
||||
t.handleSessionExpiry(sessionID, remoteAddr)
|
||||
})
|
||||
|
||||
// Start log broadcast loop
|
||||
t.wg.Add(1)
|
||||
go func() {
|
||||
@ -155,8 +165,13 @@ func (t *TCPSink) Start(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the TCP server.
|
||||
func (t *TCPSink) Stop() {
|
||||
t.logger.Info("msg", "Stopping TCP sink")
|
||||
|
||||
// Unregister callback
|
||||
t.sessionManager.UnregisterExpiryCallback("tcp_sink")
|
||||
|
||||
// Signal broadcast loop to stop
|
||||
close(t.done)
|
||||
|
||||
@ -174,9 +189,15 @@ func (t *TCPSink) Stop() {
|
||||
// Wait for broadcast loop to finish
|
||||
t.wg.Wait()
|
||||
|
||||
// Stop session manager
|
||||
if t.sessionManager != nil {
|
||||
t.sessionManager.Stop()
|
||||
}
|
||||
|
||||
t.logger.Info("msg", "TCP sink stopped")
|
||||
}
|
||||
|
||||
// GetStats returns the sink's statistics.
|
||||
func (t *TCPSink) GetStats() SinkStats {
|
||||
lastProc, _ := t.lastProcessed.Load().(time.Time)
|
||||
|
||||
@ -185,6 +206,11 @@ func (t *TCPSink) GetStats() SinkStats {
|
||||
netLimitStats = t.netLimiter.GetStats()
|
||||
}
|
||||
|
||||
var sessionStats map[string]any
|
||||
if t.sessionManager != nil {
|
||||
sessionStats = t.sessionManager.GetStats()
|
||||
}
|
||||
|
||||
return SinkStats{
|
||||
Type: "tcp",
|
||||
TotalProcessed: t.totalProcessed.Load(),
|
||||
@ -195,11 +221,32 @@ func (t *TCPSink) GetStats() SinkStats {
|
||||
"port": t.config.Port,
|
||||
"buffer_size": t.config.BufferSize,
|
||||
"net_limit": netLimitStats,
|
||||
"auth": map[string]any{"enabled": false},
|
||||
"sessions": sessionStats,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetActiveConnections returns the current number of active connections.
|
||||
func (t *TCPSink) GetActiveConnections() int64 {
|
||||
return t.activeConns.Load()
|
||||
}
|
||||
|
||||
// tcpServer implements the gnet.EventHandler interface for the TCP sink.
|
||||
type tcpServer struct {
|
||||
gnet.BuiltinEventEngine
|
||||
sink *TCPSink
|
||||
clients map[gnet.Conn]*tcpClient
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// tcpClient represents a connected TCP client.
|
||||
type tcpClient struct {
|
||||
conn gnet.Conn
|
||||
buffer bytes.Buffer
|
||||
sessionID string
|
||||
}
|
||||
|
||||
// broadcastLoop manages the central broadcasting of log entries to all clients.
|
||||
func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
var ticker *time.Ticker
|
||||
var tickerChan <-chan time.Time
|
||||
@ -248,101 +295,7 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TCPSink) broadcastData(data []byte) {
|
||||
t.server.mu.RLock()
|
||||
defer t.server.mu.RUnlock()
|
||||
|
||||
for conn, _ := range t.server.clients {
|
||||
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||
if err != nil {
|
||||
t.writeErrors.Add(1)
|
||||
t.handleWriteError(c, err)
|
||||
} else {
|
||||
// Reset consecutive error count on success
|
||||
t.errorMu.Lock()
|
||||
delete(t.consecutiveWriteErrors, c)
|
||||
t.errorMu.Unlock()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle write errors with threshold-based connection termination
|
||||
func (t *TCPSink) handleWriteError(c gnet.Conn, err error) {
|
||||
t.errorMu.Lock()
|
||||
defer t.errorMu.Unlock()
|
||||
|
||||
// Track consecutive errors per connection
|
||||
if t.consecutiveWriteErrors == nil {
|
||||
t.consecutiveWriteErrors = make(map[gnet.Conn]int)
|
||||
}
|
||||
|
||||
t.consecutiveWriteErrors[c]++
|
||||
errorCount := t.consecutiveWriteErrors[c]
|
||||
|
||||
t.logger.Debug("msg", "AsyncWrite error",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr(),
|
||||
"error", err,
|
||||
"consecutive_errors", errorCount)
|
||||
|
||||
// Close connection after 3 consecutive write errors
|
||||
if errorCount >= 3 {
|
||||
t.logger.Warn("msg", "Closing connection due to repeated write errors",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr(),
|
||||
"error_count", errorCount)
|
||||
delete(t.consecutiveWriteErrors, c)
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Create heartbeat as a proper LogEntry
|
||||
func (t *TCPSink) createHeartbeatEntry() core.LogEntry {
|
||||
message := "heartbeat"
|
||||
|
||||
// Build fields for heartbeat metadata
|
||||
fields := make(map[string]any)
|
||||
fields["type"] = "heartbeat"
|
||||
|
||||
if t.config.Heartbeat.IncludeStats {
|
||||
fields["active_connections"] = t.activeConns.Load()
|
||||
fields["uptime_seconds"] = int64(time.Since(t.startTime).Seconds())
|
||||
}
|
||||
|
||||
fieldsJSON, _ := json.Marshal(fields)
|
||||
|
||||
return core.LogEntry{
|
||||
Time: time.Now(),
|
||||
Source: "logwisp-tcp",
|
||||
Level: "INFO",
|
||||
Message: message,
|
||||
Fields: fieldsJSON,
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the current number of connections
|
||||
func (t *TCPSink) GetActiveConnections() int64 {
|
||||
return t.activeConns.Load()
|
||||
}
|
||||
|
||||
// Represents a connected TCP client with auth state
|
||||
type tcpClient struct {
|
||||
conn gnet.Conn
|
||||
buffer bytes.Buffer
|
||||
authTimeout time.Time
|
||||
session *auth.Session
|
||||
}
|
||||
|
||||
// Handles gnet events with authentication
|
||||
type tcpServer struct {
|
||||
gnet.BuiltinEventEngine
|
||||
sink *TCPSink
|
||||
clients map[gnet.Conn]*tcpClient
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// OnBoot is called when the server starts.
|
||||
func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action {
|
||||
// Store engine reference for shutdown
|
||||
s.sink.engineMu.Lock()
|
||||
@ -355,6 +308,7 @@ func (s *tcpServer) OnBoot(eng gnet.Engine) gnet.Action {
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// OnOpen is called when a new connection is established.
|
||||
func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
remoteAddr := c.RemoteAddr()
|
||||
s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr)
|
||||
@ -387,10 +341,14 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
s.sink.netLimiter.AddConnection(remoteStr)
|
||||
}
|
||||
|
||||
// Create session for tracking
|
||||
sess := s.sink.sessionManager.CreateSession(c.RemoteAddr().String(), "tcp_sink", nil)
|
||||
|
||||
// TCP Sink accepts all connections without authentication
|
||||
client := &tcpClient{
|
||||
conn: c,
|
||||
buffer: bytes.Buffer{},
|
||||
conn: c,
|
||||
buffer: bytes.Buffer{},
|
||||
sessionID: sess.ID,
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
@ -400,14 +358,30 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
newCount := s.sink.activeConns.Add(1)
|
||||
s.sink.logger.Debug("msg", "TCP connection opened",
|
||||
"remote_addr", remoteAddr,
|
||||
"session_id", sess.ID,
|
||||
"active_connections", newCount)
|
||||
|
||||
return nil, gnet.None
|
||||
}
|
||||
|
||||
// OnClose is called when a connection is closed.
|
||||
func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
remoteAddr := c.RemoteAddr().String()
|
||||
|
||||
// Get client to retrieve session ID
|
||||
s.mu.RLock()
|
||||
client, exists := s.clients[c]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if exists && client.sessionID != "" {
|
||||
// Remove session
|
||||
s.sink.sessionManager.RemoveSession(client.sessionID)
|
||||
s.sink.logger.Debug("msg", "Session removed",
|
||||
"component", "tcp_sink",
|
||||
"session_id", client.sessionID,
|
||||
"remote_addr", remoteAddr)
|
||||
}
|
||||
|
||||
// Remove client state
|
||||
s.mu.Lock()
|
||||
delete(s.clients, c)
|
||||
@ -431,8 +405,141 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// OnTraffic is called when data is received from a connection.
|
||||
func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
s.mu.RLock()
|
||||
client, exists := s.clients[c]
|
||||
s.mu.RUnlock()
|
||||
|
||||
// Update session activity when client sends data
|
||||
if exists && client.sessionID != "" {
|
||||
s.sink.sessionManager.UpdateActivity(client.sessionID)
|
||||
}
|
||||
|
||||
// TCP Sink doesn't expect any data from clients, discard all
|
||||
c.Discard(-1)
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// handleSessionExpiry is the callback for cleaning up expired sessions.
|
||||
func (t *TCPSink) handleSessionExpiry(sessionID, remoteAddr string) {
|
||||
t.server.mu.RLock()
|
||||
defer t.server.mu.RUnlock()
|
||||
|
||||
// Find connection by session ID
|
||||
for conn, client := range t.server.clients {
|
||||
if client.sessionID == sessionID {
|
||||
t.logger.Info("msg", "Closing expired session connection",
|
||||
"component", "tcp_sink",
|
||||
"session_id", sessionID,
|
||||
"remote_addr", remoteAddr)
|
||||
|
||||
// Close connection
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// broadcastData sends a formatted byte slice to all connected clients.
|
||||
func (t *TCPSink) broadcastData(data []byte) {
|
||||
t.server.mu.RLock()
|
||||
defer t.server.mu.RUnlock()
|
||||
|
||||
// Track clients to remove after iteration
|
||||
var staleClients []gnet.Conn
|
||||
|
||||
for conn, client := range t.server.clients {
|
||||
// Update session activity before sending data
|
||||
if client.sessionID != "" {
|
||||
if !t.sessionManager.IsSessionActive(client.sessionID) {
|
||||
// Session expired, mark for cleanup
|
||||
staleClients = append(staleClients, conn)
|
||||
continue
|
||||
}
|
||||
t.sessionManager.UpdateActivity(client.sessionID)
|
||||
}
|
||||
|
||||
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||
if err != nil {
|
||||
t.writeErrors.Add(1)
|
||||
t.handleWriteError(c, err)
|
||||
} else {
|
||||
// Reset consecutive error count on success
|
||||
t.errorMu.Lock()
|
||||
delete(t.consecutiveWriteErrors, c)
|
||||
t.errorMu.Unlock()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Clean up stale connections outside the read lock
|
||||
if len(staleClients) > 0 {
|
||||
go t.cleanupStaleConnections(staleClients)
|
||||
}
|
||||
}
|
||||
|
||||
// handleWriteError manages errors during async writes, closing faulty connections.
|
||||
func (t *TCPSink) handleWriteError(c gnet.Conn, err error) {
|
||||
t.errorMu.Lock()
|
||||
defer t.errorMu.Unlock()
|
||||
|
||||
// Track consecutive errors per connection
|
||||
if t.consecutiveWriteErrors == nil {
|
||||
t.consecutiveWriteErrors = make(map[gnet.Conn]int)
|
||||
}
|
||||
|
||||
t.consecutiveWriteErrors[c]++
|
||||
errorCount := t.consecutiveWriteErrors[c]
|
||||
|
||||
t.logger.Debug("msg", "AsyncWrite error",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr(),
|
||||
"error", err,
|
||||
"consecutive_errors", errorCount)
|
||||
|
||||
// Close connection after 3 consecutive write errors
|
||||
if errorCount >= 3 {
|
||||
t.logger.Warn("msg", "Closing connection due to repeated write errors",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr(),
|
||||
"error_count", errorCount)
|
||||
delete(t.consecutiveWriteErrors, c)
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// createHeartbeatEntry generates a new heartbeat log entry.
|
||||
func (t *TCPSink) createHeartbeatEntry() core.LogEntry {
|
||||
message := "heartbeat"
|
||||
|
||||
// Build fields for heartbeat metadata
|
||||
fields := make(map[string]any)
|
||||
fields["type"] = "heartbeat"
|
||||
|
||||
if t.config.Heartbeat.IncludeStats {
|
||||
fields["active_connections"] = t.activeConns.Load()
|
||||
fields["uptime_seconds"] = int64(time.Since(t.startTime).Seconds())
|
||||
}
|
||||
|
||||
fieldsJSON, _ := json.Marshal(fields)
|
||||
|
||||
return core.LogEntry{
|
||||
Time: time.Now(),
|
||||
Source: "logwisp-tcp",
|
||||
Level: "INFO",
|
||||
Message: message,
|
||||
Fields: fieldsJSON,
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleConnections closes connections associated with expired sessions.
|
||||
func (t *TCPSink) cleanupStaleConnections(staleConns []gnet.Conn) {
|
||||
for _, conn := range staleConns {
|
||||
t.logger.Info("msg", "Closing stale connection",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
@ -2,28 +2,25 @@
|
||||
package sink
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
"logwisp/src/internal/session"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// TODO: implement heartbeat for TCP Client Sink, similar to TCP Sink
|
||||
// Forwards log entries to a remote TCP endpoint
|
||||
// TODO: add heartbeat
|
||||
// TCPClientSink forwards log entries to a remote TCP endpoint.
|
||||
type TCPClientSink struct {
|
||||
input chan core.LogEntry
|
||||
config *config.TCPClientSinkOptions
|
||||
@ -36,7 +33,9 @@ type TCPClientSink struct {
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
|
||||
// Reconnection state
|
||||
// Connection
|
||||
sessionID string
|
||||
sessionManager *session.Manager
|
||||
reconnecting atomic.Bool
|
||||
lastConnectErr error
|
||||
connectTime time.Time
|
||||
@ -49,7 +48,7 @@ type TCPClientSink struct {
|
||||
connectionUptime atomic.Value // time.Duration
|
||||
}
|
||||
|
||||
// Creates a new TCP client sink
|
||||
// NewTCPClientSink creates a new TCP client sink.
|
||||
func NewTCPClientSink(opts *config.TCPClientSinkOptions, logger *log.Logger, formatter format.Formatter) (*TCPClientSink, error) {
|
||||
// Validation and defaults are handled in config package
|
||||
if opts == nil {
|
||||
@ -57,13 +56,14 @@ func NewTCPClientSink(opts *config.TCPClientSinkOptions, logger *log.Logger, for
|
||||
}
|
||||
|
||||
t := &TCPClientSink{
|
||||
config: opts,
|
||||
address: opts.Host + ":" + strconv.Itoa(int(opts.Port)),
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
formatter: formatter,
|
||||
config: opts,
|
||||
address: opts.Host + ":" + strconv.Itoa(int(opts.Port)),
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
formatter: formatter,
|
||||
sessionManager: session.NewManager(30 * time.Minute),
|
||||
}
|
||||
t.lastProcessed.Store(time.Time{})
|
||||
t.connectionUptime.Store(time.Duration(0))
|
||||
@ -71,10 +71,12 @@ func NewTCPClientSink(opts *config.TCPClientSinkOptions, logger *log.Logger, for
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Input returns the channel for sending log entries.
|
||||
func (t *TCPClientSink) Input() chan<- core.LogEntry {
|
||||
return t.input
|
||||
}
|
||||
|
||||
// Start begins the connection and processing loops.
|
||||
func (t *TCPClientSink) Start(ctx context.Context) error {
|
||||
// Start connection manager
|
||||
t.wg.Add(1)
|
||||
@ -91,6 +93,7 @@ func (t *TCPClientSink) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the sink and its connection.
|
||||
func (t *TCPClientSink) Stop() {
|
||||
t.logger.Info("msg", "Stopping TCP client sink")
|
||||
close(t.done)
|
||||
@ -103,12 +106,21 @@ func (t *TCPClientSink) Stop() {
|
||||
}
|
||||
t.connMu.Unlock()
|
||||
|
||||
// Remove session and stop manager
|
||||
if t.sessionID != "" {
|
||||
t.sessionManager.RemoveSession(t.sessionID)
|
||||
}
|
||||
if t.sessionManager != nil {
|
||||
t.sessionManager.Stop()
|
||||
}
|
||||
|
||||
t.logger.Info("msg", "TCP client sink stopped",
|
||||
"total_processed", t.totalProcessed.Load(),
|
||||
"total_failed", t.totalFailed.Load(),
|
||||
"total_reconnects", t.totalReconnects.Load())
|
||||
}
|
||||
|
||||
// GetStats returns the sink's statistics.
|
||||
func (t *TCPClientSink) GetStats() SinkStats {
|
||||
lastProc, _ := t.lastProcessed.Load().(time.Time)
|
||||
uptime, _ := t.connectionUptime.Load().(time.Duration)
|
||||
@ -122,6 +134,19 @@ func (t *TCPClientSink) GetStats() SinkStats {
|
||||
activeConns = 1
|
||||
}
|
||||
|
||||
// Get session stats
|
||||
var sessionInfo map[string]any
|
||||
if t.sessionID != "" {
|
||||
if sess, exists := t.sessionManager.GetSession(t.sessionID); exists {
|
||||
sessionInfo = map[string]any{
|
||||
"session_id": sess.ID,
|
||||
"created_at": sess.CreatedAt,
|
||||
"last_activity": sess.LastActivity,
|
||||
"remote_addr": sess.RemoteAddr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return SinkStats{
|
||||
Type: "tcp_client",
|
||||
TotalProcessed: t.totalProcessed.Load(),
|
||||
@ -136,10 +161,12 @@ func (t *TCPClientSink) GetStats() SinkStats {
|
||||
"total_reconnects": t.totalReconnects.Load(),
|
||||
"connection_uptime": uptime.Seconds(),
|
||||
"last_error": fmt.Sprintf("%v", t.lastConnectErr),
|
||||
"session": sessionInfo,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// connectionManager handles the lifecycle of the TCP connection, including reconnections.
|
||||
func (t *TCPClientSink) connectionManager(ctx context.Context) {
|
||||
defer t.wg.Done()
|
||||
|
||||
@ -154,6 +181,11 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
|
||||
default:
|
||||
}
|
||||
|
||||
if t.sessionID != "" {
|
||||
t.sessionManager.RemoveSession(t.sessionID)
|
||||
t.sessionID = ""
|
||||
}
|
||||
|
||||
// Attempt to connect
|
||||
t.reconnecting.Store(true)
|
||||
conn, err := t.connect()
|
||||
@ -190,6 +222,13 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
|
||||
t.connectTime = time.Now()
|
||||
t.totalReconnects.Add(1)
|
||||
|
||||
// Create session for the connection
|
||||
sess := t.sessionManager.CreateSession(t.address, "tcp_client_sink", map[string]any{
|
||||
"local_addr": conn.LocalAddr().String(),
|
||||
"sink_type": "tcp_client",
|
||||
})
|
||||
t.sessionID = sess.ID
|
||||
|
||||
t.connMu.Lock()
|
||||
t.conn = conn
|
||||
t.connMu.Unlock()
|
||||
@ -197,7 +236,8 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
|
||||
t.logger.Info("msg", "Connected to TCP server",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.address,
|
||||
"local_addr", conn.LocalAddr())
|
||||
"local_addr", conn.LocalAddr(),
|
||||
"session_id", t.sessionID)
|
||||
|
||||
// Monitor connection
|
||||
t.monitorConnection(conn)
|
||||
@ -214,10 +254,57 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
|
||||
t.logger.Warn("msg", "Lost connection to TCP server",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.address,
|
||||
"uptime", uptime)
|
||||
"uptime", uptime,
|
||||
"session_id", t.sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
// processLoop reads entries from the input channel and sends them.
|
||||
func (t *TCPClientSink) processLoop(ctx context.Context) {
|
||||
defer t.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case entry, ok := <-t.input:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
t.totalProcessed.Add(1)
|
||||
t.lastProcessed.Store(time.Now())
|
||||
|
||||
// Send entry
|
||||
if err := t.sendEntry(entry); err != nil {
|
||||
t.totalFailed.Add(1)
|
||||
t.logger.Debug("msg", "Failed to send log entry",
|
||||
"component", "tcp_client_sink",
|
||||
"error", err)
|
||||
} else {
|
||||
// Update session activity on successful send
|
||||
if t.sessionID != "" {
|
||||
t.sessionManager.UpdateActivity(t.sessionID)
|
||||
} else {
|
||||
// Close invalid connection without session
|
||||
t.logger.Warn("msg", "Connection without session detected, forcing reconnection",
|
||||
"component", "tcp_client_sink")
|
||||
t.connMu.Lock()
|
||||
if t.conn != nil {
|
||||
_ = t.conn.Close()
|
||||
t.conn = nil
|
||||
}
|
||||
t.connMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// connect attempts to establish a connection to the remote server.
|
||||
func (t *TCPClientSink) connect() (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Duration(t.config.DialTimeout) * time.Second,
|
||||
@ -235,129 +322,10 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
|
||||
tcpConn.SetKeepAlivePeriod(time.Duration(t.config.KeepAlive) * time.Second)
|
||||
}
|
||||
|
||||
// SCRAM authentication if credentials configured
|
||||
if t.config.Auth != nil && t.config.Auth.Type == "scram" {
|
||||
if err := t.performSCRAMAuth(conn); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("SCRAM authentication failed: %w", err)
|
||||
}
|
||||
t.logger.Debug("msg", "SCRAM authentication completed",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.address)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
|
||||
reader := bufio.NewReader(conn)
|
||||
|
||||
// Create SCRAM client
|
||||
scramClient := auth.NewScramClient(t.config.Auth.Username, t.config.Auth.Password)
|
||||
|
||||
// Wait for AUTH_REQUIRED from server
|
||||
authPrompt, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read auth prompt: %w", err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(authPrompt) != "AUTH_REQUIRED" {
|
||||
return fmt.Errorf("unexpected server greeting: %s", authPrompt)
|
||||
}
|
||||
|
||||
// Step 1: Send ClientFirst
|
||||
clientFirst, err := scramClient.StartAuthentication()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start SCRAM: %w", err)
|
||||
}
|
||||
|
||||
msg, err := auth.FormatSCRAMRequest("SCRAM-FIRST", clientFirst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte(msg)); err != nil {
|
||||
return fmt.Errorf("failed to send SCRAM-FIRST: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Receive ServerFirst challenge
|
||||
response, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read SCRAM challenge: %w", err)
|
||||
}
|
||||
|
||||
command, data, err := auth.ParseSCRAMResponse(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if command != "SCRAM-CHALLENGE" {
|
||||
return fmt.Errorf("unexpected server response: %s", command)
|
||||
}
|
||||
|
||||
var serverFirst auth.ServerFirst
|
||||
if err := json.Unmarshal([]byte(data), &serverFirst); err != nil {
|
||||
return fmt.Errorf("failed to parse server challenge: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Process challenge and send proof
|
||||
clientFinal, err := scramClient.ProcessServerFirst(&serverFirst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process challenge: %w", err)
|
||||
}
|
||||
|
||||
msg, err = auth.FormatSCRAMRequest("SCRAM-PROOF", clientFinal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte(msg)); err != nil {
|
||||
return fmt.Errorf("failed to send SCRAM-PROOF: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Receive ServerFinal
|
||||
response, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read SCRAM result: %w", err)
|
||||
}
|
||||
|
||||
command, data, err = auth.ParseSCRAMResponse(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch command {
|
||||
case "SCRAM-OK":
|
||||
var serverFinal auth.ServerFinal
|
||||
if err := json.Unmarshal([]byte(data), &serverFinal); err != nil {
|
||||
return fmt.Errorf("failed to parse server signature: %w", err)
|
||||
}
|
||||
|
||||
// Verify server signature
|
||||
if err := scramClient.VerifyServerFinal(&serverFinal); err != nil {
|
||||
return fmt.Errorf("server signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
t.logger.Info("msg", "SCRAM authentication successful",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.address,
|
||||
"username", t.config.Auth.Username,
|
||||
"session_id", serverFinal.SessionID)
|
||||
|
||||
return nil
|
||||
|
||||
case "SCRAM-FAIL":
|
||||
reason := data
|
||||
if reason == "" {
|
||||
reason = "unknown"
|
||||
}
|
||||
return fmt.Errorf("authentication failed: %s", reason)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unexpected response: %s", command)
|
||||
}
|
||||
}
|
||||
|
||||
// monitorConnection checks the health of the connection.
|
||||
func (t *TCPClientSink) monitorConnection(conn net.Conn) {
|
||||
// Simple connection monitoring by periodic zero-byte reads
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
@ -390,35 +358,7 @@ func (t *TCPClientSink) monitorConnection(conn net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TCPClientSink) processLoop(ctx context.Context) {
|
||||
defer t.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case entry, ok := <-t.input:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
t.totalProcessed.Add(1)
|
||||
t.lastProcessed.Store(time.Now())
|
||||
|
||||
// Send entry
|
||||
if err := t.sendEntry(entry); err != nil {
|
||||
t.totalFailed.Add(1)
|
||||
t.logger.Debug("msg", "Failed to send log entry",
|
||||
"component", "tcp_client_sink",
|
||||
"error", err)
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendEntry formats and sends a single log entry over the connection.
|
||||
func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
|
||||
// Get current connection
|
||||
t.connMu.RLock()
|
||||
|
||||
@ -19,7 +19,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Monitors a directory for log files
|
||||
// DirectorySource monitors a directory for log files and tails them.
|
||||
type DirectorySource struct {
|
||||
config *config.DirectorySourceOptions
|
||||
subscribers []chan core.LogEntry
|
||||
@ -35,7 +35,7 @@ type DirectorySource struct {
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// Creates a new directory monitoring source
|
||||
// NewDirectorySource creates a new directory monitoring source.
|
||||
func NewDirectorySource(opts *config.DirectorySourceOptions, logger *log.Logger) (*DirectorySource, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("directory source options cannot be nil")
|
||||
@ -52,6 +52,7 @@ func NewDirectorySource(opts *config.DirectorySourceOptions, logger *log.Logger)
|
||||
return ds, nil
|
||||
}
|
||||
|
||||
// Subscribe returns a channel for receiving log entries.
|
||||
func (ds *DirectorySource) Subscribe() <-chan core.LogEntry {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
@ -61,6 +62,7 @@ func (ds *DirectorySource) Subscribe() <-chan core.LogEntry {
|
||||
return ch
|
||||
}
|
||||
|
||||
// Start begins the directory monitoring loop.
|
||||
func (ds *DirectorySource) Start() error {
|
||||
ds.ctx, ds.cancel = context.WithCancel(context.Background())
|
||||
ds.wg.Add(1)
|
||||
@ -74,6 +76,7 @@ func (ds *DirectorySource) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the directory source and all file watchers.
|
||||
func (ds *DirectorySource) Stop() {
|
||||
if ds.cancel != nil {
|
||||
ds.cancel()
|
||||
@ -82,7 +85,7 @@ func (ds *DirectorySource) Stop() {
|
||||
|
||||
ds.mu.Lock()
|
||||
for _, w := range ds.watchers {
|
||||
w.close()
|
||||
w.stop()
|
||||
}
|
||||
for _, ch := range ds.subscribers {
|
||||
close(ch)
|
||||
@ -94,6 +97,7 @@ func (ds *DirectorySource) Stop() {
|
||||
"path", ds.config.Path)
|
||||
}
|
||||
|
||||
// GetStats returns the source's statistics, including active watchers.
|
||||
func (ds *DirectorySource) GetStats() SourceStats {
|
||||
lastEntry, _ := ds.lastEntryTime.Load().(time.Time)
|
||||
|
||||
@ -128,24 +132,7 @@ func (ds *DirectorySource) GetStats() SourceStats {
|
||||
}
|
||||
}
|
||||
|
||||
func (ds *DirectorySource) publish(entry core.LogEntry) {
|
||||
ds.mu.RLock()
|
||||
defer ds.mu.RUnlock()
|
||||
|
||||
ds.totalEntries.Add(1)
|
||||
ds.lastEntryTime.Store(entry.Time)
|
||||
|
||||
for _, ch := range ds.subscribers {
|
||||
select {
|
||||
case ch <- entry:
|
||||
default:
|
||||
ds.droppedEntries.Add(1)
|
||||
ds.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
|
||||
"component", "directory_source")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// monitorLoop periodically scans the directory for new or changed files.
|
||||
func (ds *DirectorySource) monitorLoop() {
|
||||
defer ds.wg.Done()
|
||||
|
||||
@ -164,6 +151,7 @@ func (ds *DirectorySource) monitorLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// checkTargets finds matching files and ensures watchers are running for them.
|
||||
func (ds *DirectorySource) checkTargets() {
|
||||
files, err := ds.scanDirectory()
|
||||
if err != nil {
|
||||
@ -182,34 +170,7 @@ func (ds *DirectorySource) checkTargets() {
|
||||
ds.cleanupWatchers()
|
||||
}
|
||||
|
||||
func (ds *DirectorySource) scanDirectory() ([]string, error) {
|
||||
entries, err := os.ReadDir(ds.config.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert glob pattern to regex
|
||||
regexPattern := globToRegex(ds.config.Pattern)
|
||||
re, err := regexp.Compile(regexPattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid pattern regex: %w", err)
|
||||
}
|
||||
|
||||
var files []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
if re.MatchString(name) {
|
||||
files = append(files, filepath.Join(ds.config.Path, name))
|
||||
}
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// ensureWatcher creates and starts a new file watcher if one doesn't exist for the given path.
|
||||
func (ds *DirectorySource) ensureWatcher(path string) {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
@ -247,6 +208,7 @@ func (ds *DirectorySource) ensureWatcher(path string) {
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanupWatchers stops and removes watchers for files that no longer exist.
|
||||
func (ds *DirectorySource) cleanupWatchers() {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
@ -262,6 +224,55 @@ func (ds *DirectorySource) cleanupWatchers() {
|
||||
}
|
||||
}
|
||||
|
||||
// publish sends a log entry to all subscribers.
|
||||
func (ds *DirectorySource) publish(entry core.LogEntry) {
|
||||
ds.mu.RLock()
|
||||
defer ds.mu.RUnlock()
|
||||
|
||||
ds.totalEntries.Add(1)
|
||||
ds.lastEntryTime.Store(entry.Time)
|
||||
|
||||
for _, ch := range ds.subscribers {
|
||||
select {
|
||||
case ch <- entry:
|
||||
default:
|
||||
ds.droppedEntries.Add(1)
|
||||
ds.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
|
||||
"component", "directory_source")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// scanDirectory finds all files in the configured path that match the pattern.
|
||||
func (ds *DirectorySource) scanDirectory() ([]string, error) {
|
||||
entries, err := os.ReadDir(ds.config.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert glob pattern to regex
|
||||
regexPattern := globToRegex(ds.config.Pattern)
|
||||
re, err := regexp.Compile(regexPattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid pattern regex: %w", err)
|
||||
}
|
||||
|
||||
var files []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
if re.MatchString(name) {
|
||||
files = append(files, filepath.Join(ds.config.Path, name))
|
||||
}
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// globToRegex converts a simple glob pattern to a regular expression.
|
||||
func globToRegex(glob string) string {
|
||||
regex := regexp.QuoteMeta(glob)
|
||||
regex = strings.ReplaceAll(regex, `\*`, `.*`)
|
||||
|
||||
@ -20,7 +20,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Contains information about a file watcher
|
||||
// WatcherInfo contains snapshot information about a file watcher's state.
|
||||
type WatcherInfo struct {
|
||||
Path string
|
||||
Size int64
|
||||
@ -31,6 +31,7 @@ type WatcherInfo struct {
|
||||
Rotations int64
|
||||
}
|
||||
|
||||
// fileWatcher tails a single file, handles rotations, and sends new lines to a callback.
|
||||
type fileWatcher struct {
|
||||
path string
|
||||
callback func(core.LogEntry)
|
||||
@ -46,6 +47,7 @@ type fileWatcher struct {
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// newFileWatcher creates a new watcher for a specific file path.
|
||||
func newFileWatcher(path string, callback func(core.LogEntry), logger *log.Logger) *fileWatcher {
|
||||
w := &fileWatcher{
|
||||
path: path,
|
||||
@ -57,6 +59,7 @@ func newFileWatcher(path string, callback func(core.LogEntry), logger *log.Logge
|
||||
return w
|
||||
}
|
||||
|
||||
// watch starts the main monitoring loop for the file.
|
||||
func (w *fileWatcher) watch(ctx context.Context) error {
|
||||
if err := w.seekToEnd(); err != nil {
|
||||
return fmt.Errorf("seekToEnd failed: %w", err)
|
||||
@ -81,49 +84,34 @@ func (w *fileWatcher) watch(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *fileWatcher) seekToEnd() error {
|
||||
file, err := os.Open(w.path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
w.mu.Lock()
|
||||
w.position = 0
|
||||
w.size = 0
|
||||
w.modTime = time.Now()
|
||||
w.inode = 0
|
||||
w.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// stop signals the watcher to terminate its loop.
|
||||
func (w *fileWatcher) stop() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
// Keep existing position (including 0)
|
||||
// First time initialization seeks to the end of the file
|
||||
if w.position == -1 {
|
||||
pos, err := file.Seek(0, io.SeekEnd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.position = pos
|
||||
}
|
||||
|
||||
w.size = info.Size()
|
||||
w.modTime = info.ModTime()
|
||||
if stat, ok := info.Sys().(*syscall.Stat_t); ok {
|
||||
w.inode = stat.Ino
|
||||
}
|
||||
|
||||
return nil
|
||||
w.stopped = true
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
// getInfo returns a snapshot of the watcher's current statistics.
|
||||
func (w *fileWatcher) getInfo() WatcherInfo {
|
||||
w.mu.Lock()
|
||||
info := WatcherInfo{
|
||||
Path: w.path,
|
||||
Size: w.size,
|
||||
Position: w.position,
|
||||
ModTime: w.modTime,
|
||||
EntriesRead: w.entriesRead.Load(),
|
||||
Rotations: w.rotationSeq,
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
if lastRead, ok := w.lastReadTime.Load().(time.Time); ok {
|
||||
info.LastReadTime = lastRead
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// checkFile examines the file for changes, rotations, or new content.
|
||||
func (w *fileWatcher) checkFile() error {
|
||||
file, err := os.Open(w.path)
|
||||
if err != nil {
|
||||
@ -310,6 +298,58 @@ func (w *fileWatcher) checkFile() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// seekToEnd sets the initial read position to the end of the file.
|
||||
func (w *fileWatcher) seekToEnd() error {
|
||||
file, err := os.Open(w.path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
w.mu.Lock()
|
||||
w.position = 0
|
||||
w.size = 0
|
||||
w.modTime = time.Now()
|
||||
w.inode = 0
|
||||
w.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
// Keep existing position (including 0)
|
||||
// First time initialization seeks to the end of the file
|
||||
if w.position == -1 {
|
||||
pos, err := file.Seek(0, io.SeekEnd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.position = pos
|
||||
}
|
||||
|
||||
w.size = info.Size()
|
||||
w.modTime = info.ModTime()
|
||||
if stat, ok := info.Sys().(*syscall.Stat_t); ok {
|
||||
w.inode = stat.Ino
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isStopped checks if the watcher has been instructed to stop.
|
||||
func (w *fileWatcher) isStopped() bool {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.stopped
|
||||
}
|
||||
|
||||
// parseLine attempts to parse a line as JSON, falling back to plain text.
|
||||
func (w *fileWatcher) parseLine(line string) core.LogEntry {
|
||||
var jsonLog struct {
|
||||
Time string `json:"time"`
|
||||
@ -343,6 +383,7 @@ func (w *fileWatcher) parseLine(line string) core.LogEntry {
|
||||
}
|
||||
}
|
||||
|
||||
// extractLogLevel heuristically determines the log level from a line of text.
|
||||
func extractLogLevel(line string) string {
|
||||
patterns := []struct {
|
||||
patterns []string
|
||||
@ -366,38 +407,3 @@ func extractLogLevel(line string) string {
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (w *fileWatcher) getInfo() WatcherInfo {
|
||||
w.mu.Lock()
|
||||
info := WatcherInfo{
|
||||
Path: w.path,
|
||||
Size: w.size,
|
||||
Position: w.position,
|
||||
ModTime: w.modTime,
|
||||
EntriesRead: w.entriesRead.Load(),
|
||||
Rotations: w.rotationSeq,
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
if lastRead, ok := w.lastReadTime.Load().(time.Time); ok {
|
||||
info.LastReadTime = lastRead
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func (w *fileWatcher) close() {
|
||||
w.stop()
|
||||
}
|
||||
|
||||
func (w *fileWatcher) stop() {
|
||||
w.mu.Lock()
|
||||
w.stopped = true
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
func (w *fileWatcher) isStopped() bool {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.stopped
|
||||
}
|
||||
@ -2,6 +2,7 @@
|
||||
package source
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -9,17 +10,17 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/tls"
|
||||
"logwisp/src/internal/session"
|
||||
ltls "logwisp/src/internal/tls"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Receives log entries via HTTP POST requests
|
||||
// HTTPSource receives log entries via HTTP POST requests.
|
||||
type HTTPSource struct {
|
||||
config *config.HTTPSourceOptions
|
||||
|
||||
@ -35,10 +36,10 @@ type HTTPSource struct {
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Security
|
||||
authenticator *auth.Authenticator
|
||||
authFailures atomic.Uint64
|
||||
authSuccesses atomic.Uint64
|
||||
tlsManager *tls.Manager
|
||||
httpSessions sync.Map
|
||||
sessionManager *session.Manager
|
||||
tlsManager *ltls.ServerManager
|
||||
tlsStates sync.Map // remoteAddr -> *tls.ConnectionState
|
||||
|
||||
// Statistics
|
||||
totalEntries atomic.Uint64
|
||||
@ -48,7 +49,7 @@ type HTTPSource struct {
|
||||
lastEntryTime atomic.Value // time.Time
|
||||
}
|
||||
|
||||
// Creates a new HTTP server source
|
||||
// NewHTTPSource creates a new HTTP server source.
|
||||
func NewHTTPSource(opts *config.HTTPSourceOptions, logger *log.Logger) (*HTTPSource, error) {
|
||||
// Validation done in config package
|
||||
if opts == nil {
|
||||
@ -56,10 +57,11 @@ func NewHTTPSource(opts *config.HTTPSourceOptions, logger *log.Logger) (*HTTPSou
|
||||
}
|
||||
|
||||
h := &HTTPSource{
|
||||
config: opts,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
config: opts,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
sessionManager: session.NewManager(core.MaxSessionTime),
|
||||
}
|
||||
h.lastEntryTime.Store(time.Time{})
|
||||
|
||||
@ -72,34 +74,17 @@ func NewHTTPSource(opts *config.HTTPSourceOptions, logger *log.Logger) (*HTTPSou
|
||||
|
||||
// Initialize TLS manager if configured
|
||||
if opts.TLS != nil && opts.TLS.Enabled {
|
||||
tlsManager, err := tls.NewManager(opts.TLS, logger)
|
||||
tlsManager, err := ltls.NewServerManager(opts.TLS, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
|
||||
}
|
||||
h.tlsManager = tlsManager
|
||||
}
|
||||
|
||||
// Initialize authenticator if configured
|
||||
if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" {
|
||||
// Verify TLS is enabled for auth (validation should have caught this)
|
||||
if h.tlsManager == nil {
|
||||
return nil, fmt.Errorf("authentication requires TLS to be enabled")
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(opts.Auth, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create authenticator: %w", err)
|
||||
}
|
||||
h.authenticator = authenticator
|
||||
|
||||
logger.Info("msg", "Authentication configured for HTTP source",
|
||||
"component", "http_source",
|
||||
"auth_type", opts.Auth.Type)
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Subscribe returns a channel for receiving log entries.
|
||||
func (h *HTTPSource) Subscribe() <-chan core.LogEntry {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
@ -109,7 +94,13 @@ func (h *HTTPSource) Subscribe() <-chan core.LogEntry {
|
||||
return ch
|
||||
}
|
||||
|
||||
// Start initializes and starts the HTTP server.
|
||||
func (h *HTTPSource) Start() error {
|
||||
// Register expiry callback
|
||||
h.sessionManager.RegisterExpiryCallback("http_source", func(sessionID, remoteAddr string) {
|
||||
h.handleSessionExpiry(sessionID, remoteAddr)
|
||||
})
|
||||
|
||||
h.server = &fasthttp.Server{
|
||||
Handler: h.requestHandler,
|
||||
DisableKeepalive: false,
|
||||
@ -120,6 +111,20 @@ func (h *HTTPSource) Start() error {
|
||||
MaxRequestBodySize: int(h.config.MaxRequestBodySize),
|
||||
}
|
||||
|
||||
// TLS and mTLS configuration
|
||||
if h.tlsManager != nil {
|
||||
h.server.TLSConfig = h.tlsManager.GetHTTPConfig()
|
||||
|
||||
// Enforce mTLS configuration from the TLSServerConfig struct.
|
||||
if h.config.TLS.ClientAuth {
|
||||
if h.config.TLS.VerifyClientCert {
|
||||
h.server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
} else {
|
||||
h.server.TLSConfig.ClientAuth = tls.RequireAnyClientCert
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use configured host and port
|
||||
addr := fmt.Sprintf("%s:%d", h.config.Host, h.config.Port)
|
||||
|
||||
@ -133,12 +138,22 @@ func (h *HTTPSource) Start() error {
|
||||
"port", h.config.Port,
|
||||
"ingest_path", h.config.IngestPath,
|
||||
"tls_enabled", h.tlsManager != nil,
|
||||
"auth_enabled", h.authenticator != nil)
|
||||
"mtls_enabled", h.config.TLS != nil && h.config.TLS.ClientAuth,
|
||||
)
|
||||
|
||||
var err error
|
||||
if h.tlsManager != nil {
|
||||
// HTTPS server
|
||||
h.server.TLSConfig = h.tlsManager.GetHTTPConfig()
|
||||
|
||||
// Add certificate verification callback
|
||||
if h.config.TLS.ClientAuth {
|
||||
h.server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
if h.config.TLS.ClientCAFile != "" {
|
||||
// ClientCAs already set by tls.Manager
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPS server
|
||||
err = h.server.ListenAndServeTLS(addr, h.config.TLS.CertFile, h.config.TLS.KeyFile)
|
||||
} else {
|
||||
// HTTP server
|
||||
@ -163,8 +178,13 @@ func (h *HTTPSource) Start() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the HTTP server.
|
||||
func (h *HTTPSource) Stop() {
|
||||
h.logger.Info("msg", "Stopping HTTP source")
|
||||
|
||||
// Unregister callback
|
||||
h.sessionManager.UnregisterExpiryCallback("http_source")
|
||||
|
||||
close(h.done)
|
||||
|
||||
if h.server != nil {
|
||||
@ -189,9 +209,15 @@ func (h *HTTPSource) Stop() {
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
// Stop session manager
|
||||
if h.sessionManager != nil {
|
||||
h.sessionManager.Stop()
|
||||
}
|
||||
|
||||
h.logger.Info("msg", "HTTP source stopped")
|
||||
}
|
||||
|
||||
// GetStats returns the source's statistics.
|
||||
func (h *HTTPSource) GetStats() SourceStats {
|
||||
lastEntry, _ := h.lastEntryTime.Load().(time.Time)
|
||||
|
||||
@ -200,14 +226,9 @@ func (h *HTTPSource) GetStats() SourceStats {
|
||||
netLimitStats = h.netLimiter.GetStats()
|
||||
}
|
||||
|
||||
var authStats map[string]any
|
||||
if h.authenticator != nil {
|
||||
authStats = map[string]any{
|
||||
"enabled": true,
|
||||
"type": h.config.Auth.Type,
|
||||
"failures": h.authFailures.Load(),
|
||||
"successes": h.authSuccesses.Load(),
|
||||
}
|
||||
var sessionStats map[string]any
|
||||
if h.sessionManager != nil {
|
||||
sessionStats = h.sessionManager.GetStats()
|
||||
}
|
||||
|
||||
var tlsStats map[string]any
|
||||
@ -227,12 +248,13 @@ func (h *HTTPSource) GetStats() SourceStats {
|
||||
"path": h.config.IngestPath,
|
||||
"invalid_entries": h.invalidEntries.Load(),
|
||||
"net_limit": netLimitStats,
|
||||
"auth": authStats,
|
||||
"sessions": sessionStats,
|
||||
"tls": tlsStats,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// requestHandler is the main entry point for all incoming HTTP requests.
|
||||
func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
remoteAddr := ctx.RemoteAddr().String()
|
||||
|
||||
@ -262,42 +284,26 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Check TLS requirement for auth
|
||||
if h.authenticator != nil {
|
||||
isTLS := ctx.IsTLS() || h.tlsManager != nil
|
||||
if !isTLS {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "TLS required for authentication",
|
||||
"hint": "Use HTTPS to submit authenticated requests",
|
||||
})
|
||||
return
|
||||
// 3. Create session for connections
|
||||
var sess *session.Session
|
||||
if savedID, exists := h.httpSessions.Load(remoteAddr); exists {
|
||||
if s, found := h.sessionManager.GetSession(savedID.(string)); found {
|
||||
sess = s
|
||||
h.sessionManager.UpdateActivity(savedID.(string))
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticate request
|
||||
authHeader := string(ctx.Request.Header.Peek("Authorization"))
|
||||
session, err := h.authenticator.AuthenticateHTTP(authHeader, remoteAddr)
|
||||
if err != nil {
|
||||
h.authFailures.Add(1)
|
||||
h.logger.Warn("msg", "Authentication failed",
|
||||
"component", "http_source",
|
||||
"remote_addr", remoteAddr,
|
||||
"error", err)
|
||||
if sess == nil {
|
||||
// New connection
|
||||
sess = h.sessionManager.CreateSession(remoteAddr, "http_source", map[string]any{
|
||||
"tls": ctx.IsTLS() || h.tlsManager != nil,
|
||||
"mtls_enabled": h.config.TLS != nil && h.config.TLS.ClientAuth,
|
||||
})
|
||||
h.httpSessions.Store(remoteAddr, sess.ID)
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
|
||||
if h.config.Auth.Type == "basic" && h.config.Auth.Basic != nil && h.config.Auth.Basic.Realm != "" {
|
||||
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, h.config.Auth.Basic.Realm))
|
||||
}
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Authentication failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.authSuccesses.Add(1)
|
||||
_ = session // Session can be used for audit logging
|
||||
// Setup connection close handler
|
||||
ctx.SetConnectionClose()
|
||||
go h.cleanupHTTPSession(remoteAddr, sess.ID)
|
||||
}
|
||||
|
||||
// 4. Path check
|
||||
@ -359,14 +365,58 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
// Publish to subscribers
|
||||
h.publish(entry)
|
||||
|
||||
// Update session activity after successful processing
|
||||
h.sessionManager.UpdateActivity(sess.ID)
|
||||
|
||||
// Success response
|
||||
ctx.SetStatusCode(fasthttp.StatusAccepted)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"status": "accepted",
|
||||
"status": "accepted",
|
||||
"session_id": sess.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// publish sends a log entry to all subscribers.
|
||||
func (h *HTTPSource) publish(entry core.LogEntry) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
h.totalEntries.Add(1)
|
||||
h.lastEntryTime.Store(entry.Time)
|
||||
|
||||
for _, ch := range h.subscribers {
|
||||
select {
|
||||
case ch <- entry:
|
||||
default:
|
||||
h.droppedEntries.Add(1)
|
||||
h.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
|
||||
"component", "http_source")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleSessionExpiry is the callback for cleaning up expired sessions.
|
||||
func (h *HTTPSource) handleSessionExpiry(sessionID, remoteAddr string) {
|
||||
h.logger.Info("msg", "Removing expired HTTP session",
|
||||
"component", "http_source",
|
||||
"session_id", sessionID,
|
||||
"remote_addr", remoteAddr)
|
||||
|
||||
// Remove from mapping
|
||||
h.httpSessions.Delete(remoteAddr)
|
||||
}
|
||||
|
||||
// cleanupHTTPSession removes a session when a client connection is closed.
|
||||
func (h *HTTPSource) cleanupHTTPSession(addr, sessionID string) {
|
||||
// Wait for connection to actually close
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
h.httpSessions.CompareAndDelete(addr, sessionID)
|
||||
h.sessionManager.RemoveSession(sessionID)
|
||||
}
|
||||
|
||||
// parseEntries attempts to parse a request body as a single JSON object, a JSON array, or newline-delimited JSON.
|
||||
func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) {
|
||||
var entries []core.LogEntry
|
||||
|
||||
@ -442,25 +492,7 @@ func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) {
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (h *HTTPSource) publish(entry core.LogEntry) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
h.totalEntries.Add(1)
|
||||
h.lastEntryTime.Store(entry.Time)
|
||||
|
||||
for _, ch := range h.subscribers {
|
||||
select {
|
||||
case ch <- entry:
|
||||
default:
|
||||
h.droppedEntries.Add(1)
|
||||
h.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
|
||||
"component", "http_source")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Splits bytes into lines, handling both \n and \r\n
|
||||
// splitLines splits a byte slice into lines, handling both \n and \r\n.
|
||||
func splitLines(data []byte) [][]byte {
|
||||
var lines [][]byte
|
||||
start := 0
|
||||
|
||||
@ -7,22 +7,22 @@ import (
|
||||
"logwisp/src/internal/core"
|
||||
)
|
||||
|
||||
// Represents an input data stream
|
||||
// Source represents an input data stream for log entries.
|
||||
type Source interface {
|
||||
// Returns a channel that receives log entries
|
||||
// Subscribe returns a channel that receives log entries from the source.
|
||||
Subscribe() <-chan core.LogEntry
|
||||
|
||||
// Begins reading from the source
|
||||
// Start begins reading from the source.
|
||||
Start() error
|
||||
|
||||
// Gracefully shuts down the source
|
||||
// Stop gracefully shuts down the source.
|
||||
Stop()
|
||||
|
||||
// Returns source statistics
|
||||
// SourceStats contains statistics about a source.
|
||||
GetStats() SourceStats
|
||||
}
|
||||
|
||||
// Contains statistics about a source
|
||||
// SourceStats contains statistics about a source.
|
||||
type SourceStats struct {
|
||||
Type string
|
||||
TotalEntries uint64
|
||||
|
||||
@ -13,7 +13,7 @@ import (
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Reads log entries from standard input
|
||||
// StdinSource reads log entries from the standard input stream.
|
||||
type StdinSource struct {
|
||||
config *config.StdinSourceOptions
|
||||
subscribers []chan core.LogEntry
|
||||
@ -25,6 +25,7 @@ type StdinSource struct {
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewStdinSource creates a new stdin source.
|
||||
func NewStdinSource(opts *config.StdinSourceOptions, logger *log.Logger) (*StdinSource, error) {
|
||||
if opts == nil {
|
||||
opts = &config.StdinSourceOptions{
|
||||
@ -43,18 +44,21 @@ func NewStdinSource(opts *config.StdinSourceOptions, logger *log.Logger) (*Stdin
|
||||
return source, nil
|
||||
}
|
||||
|
||||
// Subscribe returns a channel for receiving log entries.
|
||||
func (s *StdinSource) Subscribe() <-chan core.LogEntry {
|
||||
ch := make(chan core.LogEntry, s.config.BufferSize)
|
||||
s.subscribers = append(s.subscribers, ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
// Start begins reading from the standard input.
|
||||
func (s *StdinSource) Start() error {
|
||||
go s.readLoop()
|
||||
s.logger.Info("msg", "Stdin source started", "component", "stdin_source")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop signals the source to stop reading.
|
||||
func (s *StdinSource) Stop() {
|
||||
close(s.done)
|
||||
for _, ch := range s.subscribers {
|
||||
@ -63,6 +67,7 @@ func (s *StdinSource) Stop() {
|
||||
s.logger.Info("msg", "Stdin source stopped", "component", "stdin_source")
|
||||
}
|
||||
|
||||
// GetStats returns the source's statistics.
|
||||
func (s *StdinSource) GetStats() SourceStats {
|
||||
lastEntry, _ := s.lastEntryTime.Load().(time.Time)
|
||||
|
||||
@ -76,6 +81,7 @@ func (s *StdinSource) GetStats() SourceStats {
|
||||
}
|
||||
}
|
||||
|
||||
// readLoop continuously reads lines from stdin and publishes them.
|
||||
func (s *StdinSource) readLoop() {
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
for scanner.Scan() {
|
||||
@ -107,6 +113,7 @@ func (s *StdinSource) readLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// publish sends a log entry to all subscribers.
|
||||
func (s *StdinSource) publish(entry core.LogEntry) {
|
||||
s.totalEntries.Add(1)
|
||||
s.lastEntryTime.Store(entry.Time)
|
||||
|
||||
@ -7,15 +7,14 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/session"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/lixenwraith/log/compat"
|
||||
@ -27,21 +26,19 @@ const (
|
||||
maxLineLength = 1 * 1024 * 1024 // 1MB max per log line
|
||||
)
|
||||
|
||||
// Receives log entries via TCP connections
|
||||
// TCPSource receives log entries via TCP connections.
|
||||
type TCPSource struct {
|
||||
config *config.TCPSourceOptions
|
||||
server *tcpSourceServer
|
||||
subscribers []chan core.LogEntry
|
||||
mu sync.RWMutex
|
||||
done chan struct{}
|
||||
engine *gnet.Engine
|
||||
engineMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
authenticator *auth.Authenticator
|
||||
netLimiter *limit.NetLimiter
|
||||
logger *log.Logger
|
||||
scramManager *auth.ScramManager
|
||||
scramProtocolHandler *auth.ScramProtocolHandler
|
||||
config *config.TCPSourceOptions
|
||||
server *tcpSourceServer
|
||||
subscribers []chan core.LogEntry
|
||||
mu sync.RWMutex
|
||||
done chan struct{}
|
||||
engine *gnet.Engine
|
||||
engineMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
sessionManager *session.Manager
|
||||
netLimiter *limit.NetLimiter
|
||||
logger *log.Logger
|
||||
|
||||
// Statistics
|
||||
totalEntries atomic.Uint64
|
||||
@ -50,11 +47,9 @@ type TCPSource struct {
|
||||
activeConns atomic.Int64
|
||||
startTime time.Time
|
||||
lastEntryTime atomic.Value // time.Time
|
||||
authFailures atomic.Uint64
|
||||
authSuccesses atomic.Uint64
|
||||
}
|
||||
|
||||
// Creates a new TCP server source
|
||||
// NewTCPSource creates a new TCP server source.
|
||||
func NewTCPSource(opts *config.TCPSourceOptions, logger *log.Logger) (*TCPSource, error) {
|
||||
// Accept typed config - validation done in config package
|
||||
if opts == nil {
|
||||
@ -62,10 +57,11 @@ func NewTCPSource(opts *config.TCPSourceOptions, logger *log.Logger) (*TCPSource
|
||||
}
|
||||
|
||||
t := &TCPSource{
|
||||
config: opts,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
config: opts,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
sessionManager: session.NewManager(core.MaxSessionTime),
|
||||
}
|
||||
t.lastEntryTime.Store(time.Time{})
|
||||
|
||||
@ -76,20 +72,10 @@ func NewTCPSource(opts *config.TCPSourceOptions, logger *log.Logger) (*TCPSource
|
||||
t.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger)
|
||||
}
|
||||
|
||||
// Initialize SCRAM
|
||||
if opts.Auth != nil && opts.Auth.Type == "scram" && opts.Auth.Scram != nil {
|
||||
t.scramManager = auth.NewScramManager(opts.Auth.Scram)
|
||||
t.scramProtocolHandler = auth.NewScramProtocolHandler(t.scramManager, logger)
|
||||
logger.Info("msg", "SCRAM authentication configured for TCP source",
|
||||
"component", "tcp_source",
|
||||
"users", len(opts.Auth.Scram.Users))
|
||||
} else if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" {
|
||||
return nil, fmt.Errorf("TCP source only supports 'none' or 'scram' auth")
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Subscribe returns a channel for receiving log entries.
|
||||
func (t *TCPSource) Subscribe() <-chan core.LogEntry {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
@ -99,12 +85,18 @@ func (t *TCPSource) Subscribe() <-chan core.LogEntry {
|
||||
return ch
|
||||
}
|
||||
|
||||
// Start initializes and starts the TCP server.
|
||||
func (t *TCPSource) Start() error {
|
||||
t.server = &tcpSourceServer{
|
||||
source: t,
|
||||
clients: make(map[gnet.Conn]*tcpClient),
|
||||
}
|
||||
|
||||
// Register expiry callback
|
||||
t.sessionManager.RegisterExpiryCallback("tcp_source", func(sessionID, remoteAddr string) {
|
||||
t.handleSessionExpiry(sessionID, remoteAddr)
|
||||
})
|
||||
|
||||
// Use configured host and port
|
||||
addr := fmt.Sprintf("tcp://%s:%d", t.config.Host, t.config.Port)
|
||||
|
||||
@ -119,7 +111,7 @@ func (t *TCPSource) Start() error {
|
||||
t.logger.Info("msg", "TCP source server starting",
|
||||
"component", "tcp_source",
|
||||
"port", t.config.Port,
|
||||
"auth_enabled", t.authenticator != nil)
|
||||
)
|
||||
|
||||
err := gnet.Run(t.server, addr,
|
||||
gnet.WithLogger(gnetLogger),
|
||||
@ -150,8 +142,13 @@ func (t *TCPSource) Start() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the TCP server.
|
||||
func (t *TCPSource) Stop() {
|
||||
t.logger.Info("msg", "Stopping TCP source")
|
||||
|
||||
// Unregister callback
|
||||
t.sessionManager.UnregisterExpiryCallback("tcp_source")
|
||||
|
||||
close(t.done)
|
||||
|
||||
// Stop gnet engine if running
|
||||
@ -182,6 +179,7 @@ func (t *TCPSource) Stop() {
|
||||
t.logger.Info("msg", "TCP source stopped")
|
||||
}
|
||||
|
||||
// GetStats returns the source's statistics.
|
||||
func (t *TCPSource) GetStats() SourceStats {
|
||||
lastEntry, _ := t.lastEntryTime.Load().(time.Time)
|
||||
|
||||
@ -190,14 +188,9 @@ func (t *TCPSource) GetStats() SourceStats {
|
||||
netLimitStats = t.netLimiter.GetStats()
|
||||
}
|
||||
|
||||
var authStats map[string]any
|
||||
if t.authenticator != nil {
|
||||
authStats = map[string]any{
|
||||
"enabled": true,
|
||||
"type": t.config.Auth.Type,
|
||||
"failures": t.authFailures.Load(),
|
||||
"successes": t.authSuccesses.Load(),
|
||||
}
|
||||
var sessionStats map[string]any
|
||||
if t.sessionManager != nil {
|
||||
sessionStats = t.sessionManager.GetStats()
|
||||
}
|
||||
|
||||
return SourceStats{
|
||||
@ -211,40 +204,12 @@ func (t *TCPSource) GetStats() SourceStats {
|
||||
"active_connections": t.activeConns.Load(),
|
||||
"invalid_entries": t.invalidEntries.Load(),
|
||||
"net_limit": netLimitStats,
|
||||
"auth": authStats,
|
||||
"sessions": sessionStats,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TCPSource) publish(entry core.LogEntry) {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
t.totalEntries.Add(1)
|
||||
t.lastEntryTime.Store(entry.Time)
|
||||
|
||||
for _, ch := range t.subscribers {
|
||||
select {
|
||||
case ch <- entry:
|
||||
default:
|
||||
t.droppedEntries.Add(1)
|
||||
t.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
|
||||
"component", "tcp_source")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Represents a connected TCP client
|
||||
type tcpClient struct {
|
||||
conn gnet.Conn
|
||||
buffer *bytes.Buffer
|
||||
authenticated bool
|
||||
authTimeout time.Time
|
||||
session *auth.Session
|
||||
maxBufferSeen int
|
||||
}
|
||||
|
||||
// Handles gnet events
|
||||
// tcpSourceServer implements the gnet.EventHandler interface for the source.
|
||||
type tcpSourceServer struct {
|
||||
gnet.BuiltinEventEngine
|
||||
source *TCPSource
|
||||
@ -252,6 +217,15 @@ type tcpSourceServer struct {
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// tcpClient represents a connected TCP client and its state.
type tcpClient struct {
	conn      gnet.Conn     // underlying gnet connection
	buffer    *bytes.Buffer // accumulates partial lines between OnTraffic calls
	sessionID string        // ID of the session created for this connection in OnOpen
	// Largest buffer length observed for this client — presumably a
	// diagnostics high-water mark; confirm usage in processLogData.
	maxBufferSeen int
}
|
||||
|
||||
// OnBoot is called when the server starts.
|
||||
func (s *tcpSourceServer) OnBoot(eng gnet.Engine) gnet.Action {
|
||||
// Store engine reference for shutdown
|
||||
s.source.engineMu.Lock()
|
||||
@ -264,6 +238,7 @@ func (s *tcpSourceServer) OnBoot(eng gnet.Engine) gnet.Action {
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// OnOpen is called when a new connection is established.
|
||||
func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
remoteAddr := c.RemoteAddr().String()
|
||||
s.source.logger.Debug("msg", "TCP connection attempt",
|
||||
@ -299,7 +274,6 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
}
|
||||
|
||||
// Track connection
|
||||
// s.source.netLimiter.AddConnection(remoteAddr)
|
||||
if !s.source.netLimiter.TrackConnection(ip.String(), "", "") {
|
||||
s.source.logger.Warn("msg", "TCP connection limit exceeded",
|
||||
"component", "tcp_source",
|
||||
@ -308,21 +282,14 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create session
|
||||
sess := s.source.sessionManager.CreateSession(remoteAddr, "tcp_source", nil)
|
||||
|
||||
// Create client state
|
||||
client := &tcpClient{
|
||||
conn: c,
|
||||
buffer: bytes.NewBuffer(nil),
|
||||
authenticated: s.source.authenticator == nil, // No auth = auto authenticated
|
||||
}
|
||||
|
||||
if s.source.authenticator != nil {
|
||||
// Set auth timeout
|
||||
client.authTimeout = time.Now().Add(10 * time.Second)
|
||||
|
||||
// Send auth challenge for SCRAM
|
||||
if s.source.config.Auth.Type == "scram" {
|
||||
out = []byte("AUTH_REQUIRED\n")
|
||||
}
|
||||
conn: c,
|
||||
buffer: bytes.NewBuffer(nil),
|
||||
sessionID: sess.ID,
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
@ -333,19 +300,29 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
s.source.logger.Debug("msg", "TCP connection opened",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", remoteAddr,
|
||||
"auth_enabled", s.source.authenticator != nil)
|
||||
"session_id", sess.ID)
|
||||
|
||||
return out, gnet.None
|
||||
}
|
||||
|
||||
// OnClose is called when a connection is closed.
|
||||
func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
remoteAddr := c.RemoteAddr().String()
|
||||
|
||||
// Get client to retrieve session ID
|
||||
s.mu.RLock()
|
||||
client, exists := s.clients[c]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if exists && client.sessionID != "" {
|
||||
// Remove session
|
||||
s.source.sessionManager.RemoveSession(client.sessionID)
|
||||
}
|
||||
|
||||
// Untrack connection
|
||||
if s.source.netLimiter != nil {
|
||||
if tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr); err == nil {
|
||||
s.source.netLimiter.ReleaseConnection(tcpAddr.IP.String(), "", "")
|
||||
// s.source.netLimiter.RemoveConnection(remoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
@ -363,6 +340,7 @@ func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// OnTraffic is called when data is received from a connection.
|
||||
func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
s.mu.RLock()
|
||||
client, exists := s.clients[c]
|
||||
@ -372,6 +350,11 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Update session activity when client sends data
|
||||
if client.sessionID != "" {
|
||||
s.source.sessionManager.UpdateActivity(client.sessionID)
|
||||
}
|
||||
|
||||
// Read all available data
|
||||
data, err := c.Next(-1)
|
||||
if err != nil {
|
||||
@ -381,76 +364,10 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// SCRAM Authentication phase
|
||||
if !client.authenticated && s.source.scramManager != nil {
|
||||
// Check auth timeout
|
||||
if !client.authTimeout.IsZero() && time.Now().After(client.authTimeout) {
|
||||
s.source.logger.Warn("msg", "Authentication timeout",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
s.source.authFailures.Add(1)
|
||||
c.AsyncWrite([]byte("AUTH_TIMEOUT\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
client.buffer.Write(data)
|
||||
|
||||
// Use centralized SCRAM protocol handler
|
||||
if s.source.scramProtocolHandler == nil {
|
||||
s.source.scramProtocolHandler = auth.NewScramProtocolHandler(s.source.scramManager, s.source.logger)
|
||||
}
|
||||
|
||||
// Look for complete auth line
|
||||
for {
|
||||
idx := bytes.IndexByte(client.buffer.Bytes(), '\n')
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
|
||||
line := client.buffer.Bytes()[:idx]
|
||||
client.buffer.Next(idx + 1)
|
||||
|
||||
// Process auth message through handler
|
||||
authenticated, session, err := s.source.scramProtocolHandler.HandleAuthMessage(line, c)
|
||||
if err != nil {
|
||||
s.source.logger.Warn("msg", "SCRAM authentication failed",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"error", err)
|
||||
|
||||
if strings.Contains(err.Error(), "unknown command") {
|
||||
return gnet.Close
|
||||
}
|
||||
// Continue for other errors (might be multi-step auth)
|
||||
}
|
||||
|
||||
if authenticated && session != nil {
|
||||
// Authentication successful
|
||||
s.mu.Lock()
|
||||
client.authenticated = true
|
||||
client.session = session
|
||||
s.mu.Unlock()
|
||||
|
||||
s.source.logger.Info("msg", "Client authenticated via SCRAM",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"session_id", session.ID)
|
||||
|
||||
// Clear auth buffer
|
||||
client.buffer.Reset()
|
||||
break
|
||||
}
|
||||
}
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
return s.processLogData(c, client, data)
|
||||
}
|
||||
|
||||
// processLogData processes raw data from a client, parsing and publishing log entries.
|
||||
func (s *tcpSourceServer) processLogData(c gnet.Conn, client *tcpClient, data []byte) gnet.Action {
|
||||
// Check if appending the new data would exceed the client buffer limit.
|
||||
if client.buffer.Len()+len(data) > maxClientBufferSize {
|
||||
@ -543,3 +460,42 @@ func (s *tcpSourceServer) processLogData(c gnet.Conn, client *tcpClient, data []
|
||||
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// publish sends a log entry to all subscribers.
|
||||
func (t *TCPSource) publish(entry core.LogEntry) {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
t.totalEntries.Add(1)
|
||||
t.lastEntryTime.Store(entry.Time)
|
||||
|
||||
for _, ch := range t.subscribers {
|
||||
select {
|
||||
case ch <- entry:
|
||||
default:
|
||||
t.droppedEntries.Add(1)
|
||||
t.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
|
||||
"component", "tcp_source")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleSessionExpiry is the callback for cleaning up expired sessions.
// It is registered with the session manager in Start and invoked with the
// ID and remote address of a session that has timed out; the matching
// connection, if still present, is closed.
func (t *TCPSource) handleSessionExpiry(sessionID, remoteAddr string) {
	t.server.mu.RLock()
	defer t.server.mu.RUnlock()

	// Find connection by session ID. Linear scan over the client map;
	// there is no reverse index from session ID to connection.
	for conn, client := range t.server.clients {
		if client.sessionID == sessionID {
			t.logger.Info("msg", "Closing expired session connection",
				"component", "tcp_source",
				"session_id", sessionID,
				"remote_addr", remoteAddr)

			// Close connection.
			// NOTE(review): Close() is called while holding the server's
			// read lock; this assumes gnet performs the close (and the
			// OnClose callback, which takes the write lock) asynchronously
			// on the event loop — confirm, otherwise this can deadlock.
			conn.Close()
			return
		}
	}
}
|
||||
94
src/internal/tls/client.go
Normal file
94
src/internal/tls/client.go
Normal file
@ -0,0 +1,94 @@
|
||||
// FILE: src/internal/tls/client.go
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// ClientManager handles TLS configuration for client components.
type ClientManager struct {
	config    *config.TLSClientConfig // source configuration, retained for GetStats
	tlsConfig *tls.Config             // materialized TLS config; cloned on each GetConfig call
	logger    *log.Logger
}
|
||||
|
||||
// NewClientManager creates a TLS manager for clients (HTTP Client Sink).
|
||||
func NewClientManager(cfg *config.TLSClientConfig, logger *log.Logger) (*ClientManager, error) {
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
m := &ClientManager{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
tlsConfig: &tls.Config{
|
||||
MinVersion: parseTLSVersion(cfg.MinVersion, tls.VersionTLS12),
|
||||
MaxVersion: parseTLSVersion(cfg.MaxVersion, tls.VersionTLS13),
|
||||
},
|
||||
}
|
||||
|
||||
// Cipher suite configuration
|
||||
if cfg.CipherSuites != "" {
|
||||
m.tlsConfig.CipherSuites = parseCipherSuites(cfg.CipherSuites)
|
||||
}
|
||||
|
||||
// Load client certificate for mTLS, if provided.
|
||||
if cfg.ClientCertFile != "" && cfg.ClientKeyFile != "" {
|
||||
clientCert, err := tls.LoadX509KeyPair(cfg.ClientCertFile, cfg.ClientKeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load client cert/key: %w", err)
|
||||
}
|
||||
m.tlsConfig.Certificates = []tls.Certificate{clientCert}
|
||||
} else if cfg.ClientCertFile != "" || cfg.ClientKeyFile != "" {
|
||||
return nil, fmt.Errorf("both client_cert_file and client_key_file must be provided for mTLS")
|
||||
}
|
||||
|
||||
// Load server CA for verification.
|
||||
if cfg.ServerCAFile != "" {
|
||||
caCert, err := os.ReadFile(cfg.ServerCAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read server CA file: %w", err)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse server CA certificate")
|
||||
}
|
||||
m.tlsConfig.RootCAs = caCertPool
|
||||
}
|
||||
|
||||
m.tlsConfig.InsecureSkipVerify = cfg.InsecureSkipVerify
|
||||
m.tlsConfig.ServerName = cfg.ServerName
|
||||
|
||||
logger.Info("msg", "TLS Client Manager initialized", "component", "tls")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// GetConfig returns the client's TLS configuration.
|
||||
func (m *ClientManager) GetConfig() *tls.Config {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return m.tlsConfig.Clone()
|
||||
}
|
||||
|
||||
// GetStats returns statistics about the current client TLS configuration.
|
||||
func (m *ClientManager) GetStats() map[string]any {
|
||||
if m == nil {
|
||||
return map[string]any{"enabled": false}
|
||||
}
|
||||
return map[string]any{
|
||||
"enabled": true,
|
||||
"min_version": tlsVersionString(m.tlsConfig.MinVersion),
|
||||
"max_version": tlsVersionString(m.tlsConfig.MaxVersion),
|
||||
"has_client_cert": m.config.ClientCertFile != "",
|
||||
"has_server_ca": m.config.ServerCAFile != "",
|
||||
"insecure_skip_verify": m.config.InsecureSkipVerify,
|
||||
}
|
||||
}
|
||||
@ -1,236 +0,0 @@
|
||||
// FILE: logwisp/src/internal/tls/manager.go
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// Handles TLS configuration for servers
|
||||
type Manager struct {
|
||||
config *config.TLSConfig
|
||||
tlsConfig *tls.Config
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// Creates a TLS configuration from TLS config
|
||||
func NewManager(cfg *config.TLSConfig, logger *log.Logger) (*Manager, error) {
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Load certificate and key
|
||||
cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load cert/key: %w", err)
|
||||
}
|
||||
|
||||
// Create base TLS config
|
||||
m.tlsConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: parseTLSVersion(cfg.MinVersion, tls.VersionTLS12),
|
||||
MaxVersion: parseTLSVersion(cfg.MaxVersion, tls.VersionTLS13),
|
||||
}
|
||||
|
||||
// Configure cipher suites if specified
|
||||
if cfg.CipherSuites != "" {
|
||||
m.tlsConfig.CipherSuites = parseCipherSuites(cfg.CipherSuites)
|
||||
} else {
|
||||
// Use secure defaults
|
||||
m.tlsConfig.CipherSuites = []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
}
|
||||
}
|
||||
|
||||
// Configure client authentication (mTLS)
|
||||
if cfg.ClientAuth {
|
||||
if cfg.VerifyClientCert {
|
||||
m.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
} else {
|
||||
m.tlsConfig.ClientAuth = tls.RequireAnyClientCert
|
||||
}
|
||||
|
||||
// Load client CA if specified
|
||||
if cfg.ClientCAFile != "" {
|
||||
caCert, err := os.ReadFile(cfg.ClientCAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read client CA: %w", err)
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse client CA certificate")
|
||||
}
|
||||
m.tlsConfig.ClientCAs = caCertPool
|
||||
}
|
||||
}
|
||||
|
||||
// Set secure defaults
|
||||
m.tlsConfig.SessionTicketsDisabled = false
|
||||
m.tlsConfig.Renegotiation = tls.RenegotiateNever
|
||||
|
||||
logger.Info("msg", "TLS manager initialized",
|
||||
"component", "tls",
|
||||
"min_version", cfg.MinVersion,
|
||||
"max_version", cfg.MaxVersion,
|
||||
"client_auth", cfg.ClientAuth,
|
||||
"cipher_count", len(m.tlsConfig.CipherSuites))
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Returns the TLS configuration
|
||||
func (m *Manager) GetConfig() *tls.Config {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
// Return a clone to prevent modification
|
||||
return m.tlsConfig.Clone()
|
||||
}
|
||||
|
||||
// Returns TLS config suitable for HTTP servers
|
||||
func (m *Manager) GetHTTPConfig() *tls.Config {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cfg := m.tlsConfig.Clone()
|
||||
// Enable HTTP/2
|
||||
cfg.NextProtos = []string{"h2", "http/1.1"}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ValidateClientCert validates a client certificate chain for mTLS.
// rawCerts is the DER-encoded chain as presented by the peer (leaf first) —
// presumably wired as a tls.Config VerifyPeerCertificate callback; confirm
// at the call site. Returns nil when client auth is disabled or the
// certificate verifies against the configured client CA pool.
func (m *Manager) ValidateClientCert(rawCerts [][]byte) error {
	// Client auth disabled: nothing to check.
	if m == nil || !m.config.ClientAuth {
		return nil
	}

	if len(rawCerts) == 0 {
		return fmt.Errorf("no client certificate provided")
	}

	// The first certificate is the leaf.
	cert, err := x509.ParseCertificate(rawCerts[0])
	if err != nil {
		return fmt.Errorf("failed to parse client certificate: %w", err)
	}

	// Verify against CA if configured; without a CA pool, any parseable
	// certificate is accepted.
	if m.tlsConfig.ClientCAs != nil {
		opts := x509.VerifyOptions{
			Roots:         m.tlsConfig.ClientCAs,
			Intermediates: x509.NewCertPool(),
			KeyUsages:     []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
		}

		// Add any intermediate certs from the rest of the chain;
		// unparseable intermediates are skipped rather than fatal.
		for i := 1; i < len(rawCerts); i++ {
			intermediate, err := x509.ParseCertificate(rawCerts[i])
			if err != nil {
				continue
			}
			opts.Intermediates.AddCert(intermediate)
		}

		if _, err := cert.Verify(opts); err != nil {
			return fmt.Errorf("client certificate verification failed: %w", err)
		}
	}

	m.logger.Debug("msg", "Client certificate validated",
		"component", "tls",
		"subject", cert.Subject.String(),
		"serial", cert.SerialNumber.String())

	return nil
}
|
||||
|
||||
// Returns TLS statistics
|
||||
func (m *Manager) GetStats() map[string]any {
|
||||
if m == nil {
|
||||
return map[string]any{"enabled": false}
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"enabled": true,
|
||||
"min_version": tlsVersionString(m.tlsConfig.MinVersion),
|
||||
"max_version": tlsVersionString(m.tlsConfig.MaxVersion),
|
||||
"client_auth": m.config.ClientAuth,
|
||||
"cipher_suites": len(m.tlsConfig.CipherSuites),
|
||||
}
|
||||
}
|
||||
|
||||
func parseTLSVersion(version string, defaultVersion uint16) uint16 {
|
||||
switch strings.ToUpper(version) {
|
||||
case "TLS1.0", "TLS10":
|
||||
return tls.VersionTLS10
|
||||
case "TLS1.1", "TLS11":
|
||||
return tls.VersionTLS11
|
||||
case "TLS1.2", "TLS12":
|
||||
return tls.VersionTLS12
|
||||
case "TLS1.3", "TLS13":
|
||||
return tls.VersionTLS13
|
||||
default:
|
||||
return defaultVersion
|
||||
}
|
||||
}
|
||||
|
||||
func parseCipherSuites(suites string) []uint16 {
|
||||
var result []uint16
|
||||
|
||||
// Map of cipher suite names to IDs
|
||||
suiteMap := map[string]uint16{
|
||||
// TLS 1.2 ECDHE suites (preferred)
|
||||
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
|
||||
// RSA suites (less preferred)
|
||||
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
}
|
||||
|
||||
for _, suite := range strings.Split(suites, ",") {
|
||||
suite = strings.TrimSpace(suite)
|
||||
if id, ok := suiteMap[suite]; ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func tlsVersionString(version uint16) string {
|
||||
switch version {
|
||||
case tls.VersionTLS10:
|
||||
return "TLS1.0"
|
||||
case tls.VersionTLS11:
|
||||
return "TLS1.1"
|
||||
case tls.VersionTLS12:
|
||||
return "TLS1.2"
|
||||
case tls.VersionTLS13:
|
||||
return "TLS1.3"
|
||||
default:
|
||||
return fmt.Sprintf("0x%04x", version)
|
||||
}
|
||||
}
|
||||
69
src/internal/tls/parse.go
Normal file
69
src/internal/tls/parse.go
Normal file
@ -0,0 +1,69 @@
|
||||
// FILE: logwisp/src/internal/tls/parse.go
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// parseTLSVersion converts a string representation (e.g., "TLS1.2") into a Go crypto/tls constant.
|
||||
func parseTLSVersion(version string, defaultVersion uint16) uint16 {
|
||||
switch strings.ToUpper(version) {
|
||||
case "TLS1.0", "TLS10":
|
||||
return tls.VersionTLS10
|
||||
case "TLS1.1", "TLS11":
|
||||
return tls.VersionTLS11
|
||||
case "TLS1.2", "TLS12":
|
||||
return tls.VersionTLS12
|
||||
case "TLS1.3", "TLS13":
|
||||
return tls.VersionTLS13
|
||||
default:
|
||||
return defaultVersion
|
||||
}
|
||||
}
|
||||
|
||||
// parseCipherSuites converts a comma-separated string of cipher suite names into a slice of Go constants.
|
||||
func parseCipherSuites(suites string) []uint16 {
|
||||
var result []uint16
|
||||
|
||||
// Map of cipher suite names to IDs
|
||||
suiteMap := map[string]uint16{
|
||||
// TLS 1.2 ECDHE suites (preferred)
|
||||
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
|
||||
// RSA suites
|
||||
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
}
|
||||
|
||||
for _, suite := range strings.Split(suites, ",") {
|
||||
suite = strings.TrimSpace(suite)
|
||||
if id, ok := suiteMap[suite]; ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// tlsVersionString converts a Go crypto/tls version constant back into a string representation.
|
||||
func tlsVersionString(version uint16) string {
|
||||
switch version {
|
||||
case tls.VersionTLS10:
|
||||
return "TLS1.0"
|
||||
case tls.VersionTLS11:
|
||||
return "TLS1.1"
|
||||
case tls.VersionTLS12:
|
||||
return "TLS1.2"
|
||||
case tls.VersionTLS13:
|
||||
return "TLS1.3"
|
||||
default:
|
||||
return fmt.Sprintf("0x%04x", version)
|
||||
}
|
||||
}
|
||||
99
src/internal/tls/server.go
Normal file
99
src/internal/tls/server.go
Normal file
@ -0,0 +1,99 @@
|
||||
// FILE: src/internal/tls/server.go
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
|
||||
// ServerManager handles TLS configuration for server components.
type ServerManager struct {
	config    *config.TLSServerConfig // source configuration, retained for GetStats
	tlsConfig *tls.Config             // materialized TLS config; cloned in GetHTTPConfig
	logger    *log.Logger
}
||||
|
||||
// NewServerManager creates a TLS manager for servers (HTTP Source/Sink).
|
||||
func NewServerManager(cfg *config.TLSServerConfig, logger *log.Logger) (*ServerManager, error) {
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
m := &ServerManager{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load server cert/key: %w", err)
|
||||
}
|
||||
|
||||
// Enforce TLS 1.2 / TLS 1.3
|
||||
m.tlsConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: parseTLSVersion(cfg.MinVersion, tls.VersionTLS12),
|
||||
MaxVersion: parseTLSVersion(cfg.MaxVersion, tls.VersionTLS13),
|
||||
}
|
||||
|
||||
if cfg.CipherSuites != "" {
|
||||
m.tlsConfig.CipherSuites = parseCipherSuites(cfg.CipherSuites)
|
||||
} else {
|
||||
// Use secure defaults
|
||||
m.tlsConfig.CipherSuites = []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
}
|
||||
}
|
||||
|
||||
// Configure client authentication (mTLS)
|
||||
if cfg.ClientAuth {
|
||||
if cfg.ClientCAFile == "" {
|
||||
return nil, fmt.Errorf("client_auth is enabled but client_ca_file is not specified")
|
||||
}
|
||||
caCert, err := os.ReadFile(cfg.ClientCAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read client CA file: %w", err)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse client CA certificate")
|
||||
}
|
||||
m.tlsConfig.ClientCAs = caCertPool
|
||||
}
|
||||
|
||||
logger.Info("msg", "TLS Server Manager initialized", "component", "tls")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// GetHTTPConfig returns a TLS configuration suitable for HTTP servers.
|
||||
func (m *ServerManager) GetHTTPConfig() *tls.Config {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
cfg := m.tlsConfig.Clone()
|
||||
cfg.NextProtos = []string{"h2", "http/1.1"}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// GetStats returns statistics about the current server TLS configuration.
|
||||
func (m *ServerManager) GetStats() map[string]any {
|
||||
if m == nil {
|
||||
return map[string]any{"enabled": false}
|
||||
}
|
||||
return map[string]any{
|
||||
"enabled": true,
|
||||
"min_version": tlsVersionString(m.tlsConfig.MinVersion),
|
||||
"max_version": tlsVersionString(m.tlsConfig.MaxVersion),
|
||||
"client_auth": m.config.ClientAuth,
|
||||
"cipher_suites": len(m.tlsConfig.CipherSuites),
|
||||
}
|
||||
}
|
||||
@ -4,13 +4,15 @@ package version
|
||||
import "fmt"
|
||||
|
||||
var (
|
||||
// Version is set at compile time via -ldflags
|
||||
Version = "dev"
|
||||
// Version is the application version, set at compile time via -ldflags.
|
||||
Version = "dev"
|
||||
// GitCommit is the git commit hash, set at compile time.
|
||||
GitCommit = "unknown"
|
||||
// BuildTime is the application build time, set at compile time.
|
||||
BuildTime = "unknown"
|
||||
)
|
||||
|
||||
// Returns a formatted version string
|
||||
// String returns a detailed, formatted version string including commit and build time.
|
||||
func String() string {
|
||||
if Version == "dev" {
|
||||
return fmt.Sprintf("dev (commit: %s, built: %s)", GitCommit, BuildTime)
|
||||
@ -18,7 +20,7 @@ func String() string {
|
||||
return fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime)
|
||||
}
|
||||
|
||||
// Short returns just the version tag (e.g. "v0.8.0" or "dev"),
// without commit or build-time details.
func Short() string {
	return Version
}
|
||||
@ -1,12 +0,0 @@
|
||||
### Usage:
|
||||
|
||||
- Copy logwisp executable to the test folder (to compile, in logwisp top directory: `make build`).
|
||||
|
||||
- Run the test script for each scenario.
|
||||
|
||||
### Notes:
|
||||
|
||||
- The tests create configuration files and log files. Most tests set logging to debug level and do not clean up the temporary files they create in the current execution directory.
|
||||
|
||||
- Some tests may need to be run on different hosts (containers can be used).
|
||||
|
||||
Reference in New Issue
Block a user