v0.7.0: major configuration and sub-command restructuring; not yet tested; docs and default config are outdated

This commit is contained in:
2025-10-09 09:35:21 -04:00
parent 490fb777ab
commit 89e6a4ea05
61 changed files with 3248 additions and 4571 deletions

13
go.mod
View File

@ -3,14 +3,12 @@ module logwisp
go 1.25.1 go 1.25.1
require ( require (
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 github.com/lixenwraith/config v0.0.0-20251003140149-580459b815f6
github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2 github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2
github.com/panjf2000/gnet/v2 v2.9.4 github.com/panjf2000/gnet/v2 v2.9.4
github.com/stretchr/testify v1.10.0 github.com/valyala/fasthttp v1.67.0
github.com/valyala/fasthttp v1.66.0 golang.org/x/crypto v0.43.0
golang.org/x/crypto v0.42.0 golang.org/x/term v0.36.0
golang.org/x/term v0.35.0
golang.org/x/time v0.13.0
) )
require ( require (
@ -20,12 +18,11 @@ require (
github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/compress v1.18.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/panjf2000/ants/v2 v2.11.3 // indirect github.com/panjf2000/ants/v2 v2.11.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect go.uber.org/zap v1.27.0 // indirect
golang.org/x/sync v0.17.0 // indirect golang.org/x/sync v0.17.0 // indirect
golang.org/x/sys v0.36.0 // indirect golang.org/x/sys v0.37.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

22
go.sum
View File

@ -8,8 +8,8 @@ github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGL
github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw= github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0= github.com/lixenwraith/config v0.0.0-20251003140149-580459b815f6 h1:G9qP8biXBT6bwBOjEe1tZwjA0gPuB5DC+fLBRXDNXqo=
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0= github.com/lixenwraith/config v0.0.0-20251003140149-580459b815f6/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0=
github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2 h1:9Qf+BR83sKjok2E1Nct+3Sfzoj2dLGwC/zyQDVNmmqs= github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2 h1:9Qf+BR83sKjok2E1Nct+3Sfzoj2dLGwC/zyQDVNmmqs=
github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0= github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0=
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
@ -22,8 +22,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.66.0 h1:M87A0Z7EayeyNaV6pfO3tUTUiYO0dZfEJnRGXTVNuyU= github.com/valyala/fasthttp v1.67.0 h1:tqKlJMUP6iuNG8hGjK/s9J4kadH7HLV4ijEcPGsezac=
github.com/valyala/fasthttp v1.66.0/go.mod h1:Y4eC+zwoocmXSVCB1JmhNbYtS7tZPRI2ztPB72EVObs= github.com/valyala/fasthttp v1.67.0/go.mod h1:qYSIpqt/0XNmShgo/8Aq8E3UYWVVwNS2QYmzd8WIEPM=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
@ -32,16 +32,14 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=

View File

@ -24,7 +24,7 @@ func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service
logger.Info("msg", "Initializing pipeline", "pipeline", pipelineCfg.Name) logger.Info("msg", "Initializing pipeline", "pipeline", pipelineCfg.Name)
// Create the pipeline // Create the pipeline
if err := svc.NewPipeline(pipelineCfg); err != nil { if err := svc.NewPipeline(&pipelineCfg); err != nil {
logger.Error("msg", "Failed to create pipeline", logger.Error("msg", "Failed to create pipeline",
"pipeline", pipelineCfg.Name, "pipeline", pipelineCfg.Name,
"error", err) "error", err)

View File

@ -1,138 +0,0 @@
// FILE: src/cmd/logwisp/commands.go
package main
import (
"fmt"
"os"
"logwisp/src/internal/auth"
"logwisp/src/internal/tls"
"logwisp/src/internal/version"
)
// Handles subcommand routing before main app initialization
type CommandRouter struct {
commands map[string]CommandHandler
}
// Defines the interface for subcommands
type CommandHandler interface {
Execute(args []string) error
Description() string
}
// Creates and initializes the command router
func NewCommandRouter() *CommandRouter {
router := &CommandRouter{
commands: make(map[string]CommandHandler),
}
// Register available commands
router.commands["auth"] = &authCommand{}
router.commands["version"] = &versionCommand{}
router.commands["help"] = &helpCommand{}
router.commands["tls"] = &tlsCommand{}
return router
}
// Checks for and executes subcommands
func (r *CommandRouter) Route(args []string) error {
if len(args) < 1 {
return nil
}
// Check for help flags anywhere in args
for _, arg := range args[1:] { // Skip program name
if arg == "-h" || arg == "--help" || arg == "help" {
// Show main help and exit regardless of other flags
r.commands["help"].Execute(nil)
os.Exit(0)
}
}
// Check for commands
if len(args) > 1 {
cmdName := args[1]
if handler, exists := r.commands[cmdName]; exists {
if err := handler.Execute(args[2:]); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
os.Exit(0)
}
// Check if it looks like a mistyped command (not a flag)
if cmdName[0] != '-' {
fmt.Fprintf(os.Stderr, "Unknown command: %s\n", cmdName)
fmt.Fprintln(os.Stderr, "\nAvailable commands:")
r.ShowCommands()
os.Exit(1)
}
}
return nil
}
// Displays available subcommands
func (r *CommandRouter) ShowCommands() {
fmt.Fprintln(os.Stderr, " auth Generate authentication credentials")
fmt.Fprintln(os.Stderr, " tls Generate TLS certificates")
fmt.Fprintln(os.Stderr, " version Show version information")
fmt.Fprintln(os.Stderr, " help Display help information")
fmt.Fprintln(os.Stderr, "\nUse 'logwisp <command> --help' for command-specific help")
}
// TODO: Future: refactor with a new command interface
type helpCommand struct{}
func (c *helpCommand) Execute(args []string) error {
// Check if help is requested for a specific command
if len(args) > 0 {
// TODO: Future: show command-specific help
// For now, just show general help
}
fmt.Print(helpText)
return nil
}
func (c *helpCommand) Description() string {
return "Display help information"
}
// authCommand wrapper
type authCommand struct{}
func (c *authCommand) Execute(args []string) error {
gen := auth.NewAuthGeneratorCommand()
return gen.Execute(args)
}
func (c *authCommand) Description() string {
return "Generate authentication credentials (passwords, tokens)"
}
// versionCommand wrapper
type versionCommand struct{}
func (c *versionCommand) Execute(args []string) error {
fmt.Println(version.String())
return nil
}
func (c *versionCommand) Description() string {
return "Show version information"
}
// tlsCommand wrapper
type tlsCommand struct{}
func (c *tlsCommand) Execute(args []string) error {
gen := tls.NewCertGeneratorCommand()
return gen.Execute(args)
}
func (c *tlsCommand) Description() string {
return "Generate TLS certificates (CA, server, client)"
}

View File

@ -0,0 +1,355 @@
// FILE: src/cmd/logwisp/commands/auth.go
package commands
import (
"crypto/rand"
"encoding/base64"
"flag"
"fmt"
"io"
"os"
"strings"
"syscall"
"logwisp/src/internal/auth"
"logwisp/src/internal/core"
"golang.org/x/term"
)
// AuthCommand generates authentication credentials from the CLI.
type AuthCommand struct {
	output io.Writer // destination for generated config snippets
	errOut io.Writer // destination for prompts, warnings, and usage text
}

// NewAuthCommand returns an AuthCommand wired to the process's
// standard output and standard error streams.
func NewAuthCommand() *AuthCommand {
	cmd := &AuthCommand{}
	cmd.output = os.Stdout
	cmd.errOut = os.Stderr
	return cmd
}
// Execute parses the auth subcommand's flags and dispatches to the
// matching credential generator: bearer token, PHC-to-SCRAM migration,
// or a username/password hash in basic or SCRAM form.
//
// Every option exists in a short and a long form; after parsing, the two
// are merged via the coalesce* helpers, with the short form winning when
// both are set.
func (ac *AuthCommand) Execute(args []string) error {
	cmd := flag.NewFlagSet("auth", flag.ContinueOnError)
	cmd.SetOutput(ac.errOut)
	var (
		// User credentials
		username     = cmd.String("u", "", "Username")
		usernameLong = cmd.String("user", "", "Username")
		password     = cmd.String("p", "", "Password (will prompt if not provided)")
		passwordLong = cmd.String("password", "", "Password (will prompt if not provided)")
		// Auth type selection (multiple ways to specify)
		authType     = cmd.String("t", "", "Auth type: basic, scram, or token")
		authTypeLong = cmd.String("type", "", "Auth type: basic, scram, or token")
		useScram     = cmd.Bool("s", false, "Generate SCRAM credentials (TCP)")
		useScramLong = cmd.Bool("scram", false, "Generate SCRAM credentials (TCP)")
		useBasic     = cmd.Bool("b", false, "Generate basic auth credentials (HTTP)")
		useBasicLong = cmd.Bool("basic", false, "Generate basic auth credentials (HTTP)")
		// Token generation
		genToken     = cmd.Bool("k", false, "Generate random bearer token")
		genTokenLong = cmd.Bool("token", false, "Generate random bearer token")
		tokenLen     = cmd.Int("l", 32, "Token length in bytes")
		tokenLenLong = cmd.Int("length", 32, "Token length in bytes")
		// Migration option
		migrate     = cmd.Bool("m", false, "Convert basic auth PHC to SCRAM")
		migrateLong = cmd.Bool("migrate", false, "Convert basic auth PHC to SCRAM")
		phcHash     = cmd.String("phc", "", "PHC hash to migrate (required with --migrate)")
	)
	cmd.Usage = func() {
		fmt.Fprintln(ac.errOut, "Generate authentication credentials for LogWisp")
		fmt.Fprintln(ac.errOut, "\nUsage: logwisp auth [options]")
		fmt.Fprintln(ac.errOut, "\nExamples:")
		fmt.Fprintln(ac.errOut, " # Generate basic auth hash for HTTP sources/sinks")
		fmt.Fprintln(ac.errOut, " logwisp auth -u admin -b")
		fmt.Fprintln(ac.errOut, " logwisp auth --user=admin --basic")
		fmt.Fprintln(ac.errOut, " ")
		fmt.Fprintln(ac.errOut, " # Generate SCRAM credentials for TCP")
		fmt.Fprintln(ac.errOut, " logwisp auth -u tcpuser -s")
		fmt.Fprintln(ac.errOut, " logwisp auth --user=tcpuser --scram")
		fmt.Fprintln(ac.errOut, " ")
		fmt.Fprintln(ac.errOut, " # Generate bearer token")
		fmt.Fprintln(ac.errOut, " logwisp auth -k -l 64")
		fmt.Fprintln(ac.errOut, " logwisp auth --token --length=64")
		fmt.Fprintln(ac.errOut, "\nOptions:")
		cmd.PrintDefaults()
	}
	if err := cmd.Parse(args); err != nil {
		return err
	}
	// Check for unparsed arguments
	if cmd.NArg() > 0 {
		return fmt.Errorf("unexpected argument(s): %s", strings.Join(cmd.Args(), " "))
	}
	// Merge short and long form values
	finalUsername := coalesceString(*username, *usernameLong)
	finalPassword := coalesceString(*password, *passwordLong)
	finalAuthType := coalesceString(*authType, *authTypeLong)
	finalGenToken := coalesceBool(*genToken, *genTokenLong)
	// NOTE(review): the -l/--length flags default to the literal 32, but
	// the merge compares against core.DefaultTokenLength — confirm the two
	// agree, otherwise an explicit "-l" equal to 32 may be ignored.
	finalTokenLen := coalesceInt(*tokenLen, *tokenLenLong, core.DefaultTokenLength)
	finalUseScram := coalesceBool(*useScram, *useScramLong)
	finalUseBasic := coalesceBool(*useBasic, *useBasicLong)
	finalMigrate := coalesceBool(*migrate, *migrateLong)
	// Handle migration mode: username, password, and PHC hash are all
	// required up front (no interactive prompt in this path).
	if finalMigrate {
		if *phcHash == "" || finalUsername == "" || finalPassword == "" {
			return fmt.Errorf("--migrate requires --user, --password, and --phc flags")
		}
		return ac.migrateToScram(finalUsername, finalPassword, *phcHash)
	}
	// Determine auth type from flags; token generation needs no username.
	if finalGenToken || finalAuthType == "token" {
		return ac.generateToken(finalTokenLen)
	}
	// Determine credential type
	credType := "basic" // default
	// Check explicit type flags; an unrecognized --type value is an error.
	if finalUseScram || finalAuthType == "scram" {
		credType = "scram"
	} else if finalUseBasic || finalAuthType == "basic" {
		credType = "basic"
	} else if finalAuthType != "" {
		return fmt.Errorf("invalid auth type: %s (valid: basic, scram, token)", finalAuthType)
	}
	// Username required for password-based auth
	if finalUsername == "" {
		cmd.Usage()
		return fmt.Errorf("username required for %s auth generation", credType)
	}
	return ac.generatePasswordHash(finalUsername, finalPassword, credType)
}
// Description returns the one-line summary shown in command listings.
func (ac *AuthCommand) Description() string {
	const summary = "Generate authentication credentials (passwords, tokens, SCRAM)"
	return summary
}
// Help returns the detailed usage text for the auth subcommand.
// Fix: the "Examples:" heading was printed twice in the original text.
func (ac *AuthCommand) Help() string {
	return `Auth Command - Generate authentication credentials for LogWisp
Usage:
logwisp auth [options]
Authentication Types:
HTTP/HTTPS Sources & Sinks (TLS required):
- Basic Auth: Username/password with Argon2id hashing
- Bearer Token: Random cryptographic tokens
TCP Sources & Sinks (No TLS):
- SCRAM: Argon2-SCRAM-SHA256 for plaintext connections
Options:
-u, --user <name> Username for credential generation
-p, --password <pass> Password (will prompt if not provided)
-t, --type <type> Auth type: "basic", "scram", or "token"
-b, --basic Generate basic auth credentials (HTTP/HTTPS)
-s, --scram Generate SCRAM credentials (TCP)
-k, --token Generate random bearer token
-l, --length <bytes> Token length in bytes (default: 32)
Examples:
# Generate basic auth hash for HTTP/HTTPS (with TLS)
logwisp auth -u admin -b
logwisp auth --user=admin --basic
# Generate SCRAM credentials for TCP (without TLS)
logwisp auth -u tcpuser -s
logwisp auth --user=tcpuser --type=scram
# Generate 64-byte bearer token
logwisp auth -k -l 64
logwisp auth --token --length=64
# Convert existing basic auth to SCRAM (HTTPS to TCP conversion)
logwisp auth -u admin -m --phc='$argon2id$v=19$m=65536...' --password='secret'
Output:
The command outputs configuration snippets ready to paste into logwisp.toml
and the raw credential values for external auth files.
Security Notes:
- Basic auth and tokens require TLS encryption for HTTP connections
- SCRAM provides authentication but NOT encryption for TCP connections
- Use strong passwords (12+ characters with mixed case, numbers, symbols)
- Store credentials securely and never commit them to version control
`
}
// generatePasswordHash prompts for a password when one was not supplied on
// the command line, then dispatches to the requested credential generator.
func (ac *AuthCommand) generatePasswordHash(username, password, credType string) error {
	if password == "" {
		prompted, err := ac.promptForPassword()
		if err != nil {
			return err
		}
		password = prompted
	}
	switch credType {
	case "scram":
		return ac.generateScramAuth(username, password)
	case "basic":
		return ac.generateBasicAuth(username, password)
	}
	return fmt.Errorf("invalid credential type: %s", credType)
}
// promptForPassword reads the password twice and rejects mismatched entries.
func (ac *AuthCommand) promptForPassword() (string, error) {
	first := ac.promptPassword("Enter password: ")
	second := ac.promptPassword("Confirm password: ")
	if first != second {
		return "", fmt.Errorf("passwords don't match")
	}
	return first, nil
}
// promptPassword writes prompt to the error stream and reads a password
// from the terminal with echo disabled.
//
// On read failure (e.g. stdin is not a terminal) it prints the error and
// terminates the process with exit code 1 rather than returning an error.
func (ac *AuthCommand) promptPassword(prompt string) string {
	fmt.Fprint(ac.errOut, prompt)
	password, err := term.ReadPassword(syscall.Stdin)
	// Emit a newline: the user's Enter key is not echoed by ReadPassword.
	fmt.Fprintln(ac.errOut)
	if err != nil {
		fmt.Fprintf(ac.errOut, "Failed to read password: %v\n", err)
		os.Exit(1)
	}
	return string(password)
}
// generateBasicAuth creates Argon2id hash for HTTP basic auth and prints a
// ready-to-paste TOML snippet plus a users-file line ("user:phc-hash").
func (ac *AuthCommand) generateBasicAuth(username, password string) error {
	// Generate salt (fresh random salt per invocation)
	salt := make([]byte, core.Argon2SaltLen)
	if _, err := rand.Read(salt); err != nil {
		return fmt.Errorf("failed to generate salt: %w", err)
	}
	// Generate Argon2id hash using the project-wide Argon2 cost parameters
	cred, err := auth.DeriveCredential(username, password, salt,
		core.Argon2Time, core.Argon2Memory, core.Argon2Threads)
	if err != nil {
		return fmt.Errorf("failed to derive credential: %w", err)
	}
	// Output configuration snippets
	fmt.Fprintln(ac.output, "\n# Basic Auth Configuration (HTTP sources/sinks)")
	fmt.Fprintln(ac.output, "# REQUIRES HTTPS/TLS for security")
	fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:")
	fmt.Fprintln(ac.output, "")
	fmt.Fprintln(ac.output, "[pipelines.auth]")
	fmt.Fprintln(ac.output, `type = "basic"`)
	fmt.Fprintln(ac.output, "")
	fmt.Fprintln(ac.output, "[[pipelines.auth.basic_auth.users]]")
	fmt.Fprintf(ac.output, "username = %q\n", username)
	fmt.Fprintf(ac.output, "password_hash = %q\n\n", cred.PHCHash)
	// Raw "user:hash" form for an external users file
	fmt.Fprintln(ac.output, "# For external users file:")
	fmt.Fprintf(ac.output, "%s:%s\n", username, cred.PHCHash)
	return nil
}
// generateScramAuth creates Argon2id-SCRAM-SHA256 credentials for TCP
// sources/sinks and prints a ready-to-paste TOML snippet. Binary key
// material (stored key, server key, salt) is base64-encoded for TOML.
func (ac *AuthCommand) generateScramAuth(username, password string) error {
	// Generate salt (fresh random salt per invocation)
	salt := make([]byte, core.Argon2SaltLen)
	if _, err := rand.Read(salt); err != nil {
		return fmt.Errorf("failed to generate salt: %w", err)
	}
	// Use internal auth package to derive SCRAM credentials
	cred, err := auth.DeriveCredential(username, password, salt,
		core.Argon2Time, core.Argon2Memory, core.Argon2Threads)
	if err != nil {
		return fmt.Errorf("failed to derive SCRAM credential: %w", err)
	}
	// Output SCRAM configuration
	fmt.Fprintln(ac.output, "\n# SCRAM Auth Configuration (TCP sources/sinks)")
	fmt.Fprintln(ac.output, "# Provides authentication but NOT encryption")
	fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:")
	fmt.Fprintln(ac.output, "")
	fmt.Fprintln(ac.output, "[pipelines.auth]")
	fmt.Fprintln(ac.output, `type = "scram"`)
	fmt.Fprintln(ac.output, "")
	fmt.Fprintln(ac.output, "[[pipelines.auth.scram_auth.users]]")
	fmt.Fprintf(ac.output, "username = %q\n", username)
	fmt.Fprintf(ac.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey))
	fmt.Fprintf(ac.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey))
	fmt.Fprintf(ac.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt))
	// Argon2 cost parameters are emitted so the server can re-derive keys.
	fmt.Fprintf(ac.output, "argon_time = %d\n", cred.ArgonTime)
	fmt.Fprintf(ac.output, "argon_memory = %d\n", cred.ArgonMemory)
	fmt.Fprintf(ac.output, "argon_threads = %d\n\n", cred.ArgonThreads)
	fmt.Fprintln(ac.output, "# Note: SCRAM provides authentication only.")
	fmt.Fprintln(ac.output, "# Use TLS/mTLS for encryption if needed.")
	return nil
}
// generateToken emits a random bearer token of the given byte length in
// both URL-safe base64 (config-ready, unpadded) and hex encodings.
//
// Lengths below 16 bytes are allowed but warned about; lengths outside
// (0, 512] are rejected. Fix: a non-positive length previously reached
// make([]byte, length), which panics for negative values.
func (ac *AuthCommand) generateToken(length int) error {
	if length <= 0 {
		return fmt.Errorf("token length must be positive, got %d", length)
	}
	if length < 16 {
		fmt.Fprintln(ac.errOut, "Warning: tokens < 16 bytes are cryptographically weak")
	}
	if length > 512 {
		return fmt.Errorf("token length exceeds maximum (512 bytes)")
	}
	token := make([]byte, length)
	if _, err := rand.Read(token); err != nil {
		return fmt.Errorf("failed to generate random bytes: %w", err)
	}
	b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token)
	hex := fmt.Sprintf("%x", token)
	fmt.Fprintln(ac.output, "\n# Token Configuration")
	fmt.Fprintln(ac.output, "# Add to logwisp.toml:")
	fmt.Fprintf(ac.output, "tokens = [%q]\n\n", b64)
	fmt.Fprintln(ac.output, "# Generated Token:")
	fmt.Fprintf(ac.output, "Base64: %s\n", b64)
	fmt.Fprintf(ac.output, "Hex:    %s\n", hex)
	return nil
}
// migrateToScram converts basic auth PHC hash to SCRAM credentials and
// prints the resulting TOML snippet (same layout as generateScramAuth).
// The plaintext password is required because the SCRAM keys must be
// re-derived; the PHC hash alone is not sufficient.
func (ac *AuthCommand) migrateToScram(username, password, phcHash string) error {
	// CHANGED: Moved from internal/auth to CLI command layer
	cred, err := auth.MigrateFromPHC(username, password, phcHash)
	if err != nil {
		return fmt.Errorf("migration failed: %w", err)
	}
	// Output SCRAM configuration (reuse format from generateScramAuth)
	fmt.Fprintln(ac.output, "\n# Migrated SCRAM Credentials")
	fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:")
	fmt.Fprintln(ac.output, "")
	fmt.Fprintln(ac.output, "[pipelines.auth]")
	fmt.Fprintln(ac.output, `type = "scram"`)
	fmt.Fprintln(ac.output, "")
	fmt.Fprintln(ac.output, "[[pipelines.auth.scram_auth.users]]")
	fmt.Fprintf(ac.output, "username = %q\n", username)
	fmt.Fprintf(ac.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey))
	fmt.Fprintf(ac.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey))
	fmt.Fprintf(ac.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt))
	fmt.Fprintf(ac.output, "argon_time = %d\n", cred.ArgonTime)
	fmt.Fprintf(ac.output, "argon_memory = %d\n", cred.ArgonMemory)
	fmt.Fprintf(ac.output, "argon_threads = %d\n", cred.ArgonThreads)
	return nil
}

View File

@ -0,0 +1,123 @@
// FILE: src/cmd/logwisp/commands/help.go
package commands
import (
"fmt"
"sort"
"strings"
)
// generalHelpTemplate is the top-level usage text; the single %s verb is
// replaced with the formatted command list (see formatCommandList).
const generalHelpTemplate = `LogWisp: A flexible log transport and processing tool.
Usage:
logwisp [command] [options]
logwisp [options]
Commands:
%s
Application Options:
-c, --config <path> Path to configuration file (default: logwisp.toml)
-h, --help Display this help message and exit
-v, --version Display version information and exit
-b, --background Run LogWisp in the background as a daemon
-q, --quiet Suppress all console output, including errors
Runtime Options:
--disable-status-reporter Disable the periodic status reporter
--config-auto-reload Enable config reload on file change
For command-specific help:
logwisp help <command>
logwisp <command> --help
Configuration Sources (Precedence: CLI > Env > File > Defaults):
- CLI flags override all other settings
- Environment variables override file settings
- TOML configuration file is the primary method
Examples:
# Generate password for admin user
logwisp auth -u admin
# Start service with custom config
logwisp -c /etc/logwisp/prod.toml
# Run in background with config reload
logwisp -b --config-auto-reload
For detailed configuration options, please refer to the documentation.
`
// HelpCommand handles help display, delegating command lookups to the
// router it was constructed with.
type HelpCommand struct {
	// router provides access to all registered command handlers.
	router *CommandRouter
}

// NewHelpCommand creates a new help command bound to router.
func NewHelpCommand(router *CommandRouter) *HelpCommand {
	return &HelpCommand{router: router}
}
// Execute prints help text: command-specific help when a command name is
// given, otherwise the general overview with the full command list.
func (c *HelpCommand) Execute(args []string) error {
	if len(args) == 0 || args[0] == "" {
		// No command named: show general help with the command list.
		fmt.Printf(generalHelpTemplate, c.formatCommandList())
		return nil
	}
	name := args[0]
	handler, ok := c.router.GetCommand(name)
	if !ok {
		return fmt.Errorf("unknown command: %s", name)
	}
	fmt.Print(handler.Help())
	return nil
}
// formatCommandList renders the registered commands as an alphabetized,
// column-aligned list, one command per line.
func (c *HelpCommand) formatCommandList() string {
	commands := c.router.GetCommands()
	// Collect names and track the widest for description alignment.
	names := make([]string, 0, len(commands))
	widest := 0
	for name := range commands {
		names = append(names, name)
		if l := len(name); l > widest {
			widest = l
		}
	}
	sort.Strings(names)
	// Build "  name<pad>description" lines joined by newlines.
	var b strings.Builder
	for i, name := range names {
		if i > 0 {
			b.WriteByte('\n')
		}
		b.WriteString("  ")
		b.WriteString(name)
		b.WriteString(strings.Repeat(" ", widest-len(name)+2))
		b.WriteString(commands[name].Description())
	}
	return b.String()
}
// Description returns the one-line summary shown in command listings.
func (c *HelpCommand) Description() string {
	const summary = "Display help information"
	return summary
}
// Help returns the usage text for the help subcommand itself.
func (c *HelpCommand) Help() string {
	return `Help Command - Display help information
Usage:
logwisp help Show general help
logwisp help <command> Show help for a specific command
Examples:
logwisp help # Show general help
logwisp help auth # Show auth command help
logwisp auth --help # Alternative way to get command help
`
}

View File

@ -0,0 +1,118 @@
// FILE: src/cmd/logwisp/commands/router.go
package commands
import (
	"fmt"
	"os"
	"sort"
)
// Handler defines the interface for subcommands.
type Handler interface {
	// Execute runs the command with its remaining CLI arguments
	// (everything after the command name).
	Execute(args []string) error
	// Description returns a one-line summary for command listings.
	Description() string
	// Help returns the full multi-line usage text for the command.
	Help() string
}

// CommandRouter handles subcommand routing before main app initialization.
type CommandRouter struct {
	// commands maps a command name to its handler.
	commands map[string]Handler
}
// NewCommandRouter creates the router and registers every built-in
// subcommand. The help command receives the router so it can list and
// look up its siblings.
func NewCommandRouter() *CommandRouter {
	r := &CommandRouter{commands: map[string]Handler{}}
	r.commands["auth"] = NewAuthCommand()
	r.commands["tls"] = NewTLSCommand()
	r.commands["version"] = NewVersionCommand()
	r.commands["help"] = NewHelpCommand(r)
	return r
}
// Route checks for and executes subcommands.
//
// It returns (true, err) when a subcommand handled the invocation, and
// (false, nil) when args carry no subcommand so the main application
// should continue. An unknown non-flag first argument yields an error.
//
// Fix: an empty-string first argument previously panicked on cmdName[0];
// it is now treated like a flag and left for the main app.
func (r *CommandRouter) Route(args []string) (bool, error) {
	if len(args) < 2 {
		return false, nil // No command specified, let main app continue
	}
	cmdName := args[1]

	// Special case: a help flag at any position shows help.
	for _, arg := range args[1:] {
		if arg == "-h" || arg == "--help" {
			// After a valid command, show that command's help.
			if handler, exists := r.commands[cmdName]; exists && cmdName != "help" {
				fmt.Print(handler.Help())
				return true, nil
			}
			// Otherwise show general help.
			return true, r.commands["help"].Execute(nil)
		}
	}

	// Check if this is a known command.
	handler, exists := r.commands[cmdName]
	if !exists {
		// A non-empty, non-flag argument looks like a mistyped command.
		// The length guard prevents an index panic on an empty argument.
		if len(cmdName) > 0 && cmdName[0] != '-' {
			return false, fmt.Errorf("unknown command: %s\n\nRun 'logwisp help' for usage", cmdName)
		}
		// Flags (and empty arguments) are left for the main app to parse.
		return false, nil
	}

	// Execute the command with the remaining arguments.
	return true, handler.Execute(args[2:])
}
// GetCommand returns the handler registered under name, and whether one
// exists.
func (r *CommandRouter) GetCommand(name string) (Handler, bool) {
	handler, ok := r.commands[name]
	return handler, ok
}
// GetCommands returns all registered commands.
//
// NOTE(review): this returns the router's internal map, so a caller that
// mutates it changes the routing table; the callers visible in this file
// only read it — confirm before relying on mutation isolation.
func (r *CommandRouter) GetCommands() map[string]Handler {
	return r.commands
}
// ShowCommands displays available subcommands on stderr in stable
// alphabetical order. Fix: the original ranged directly over the map,
// so the listing order changed from run to run (Go map iteration order
// is randomized).
func (r *CommandRouter) ShowCommands() {
	names := make([]string, 0, len(r.commands))
	for name := range r.commands {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		fmt.Fprintf(os.Stderr, "  %-10s %s\n", name, r.commands[name].Description())
	}
	fmt.Fprintln(os.Stderr, "\nUse 'logwisp <command> --help' for command-specific help")
}
// Helper functions to merge short and long options
// coalesceString returns the first non-empty string among values, or ""
// when every value is empty (or none are given).
func coalesceString(values ...string) string {
	for _, candidate := range values {
		if candidate != "" {
			return candidate
		}
	}
	return ""
}
// coalesceInt prefers primary when it differs from defaultVal, otherwise
// secondary. (When secondary also equals defaultVal, returning secondary
// is equivalent to returning defaultVal.)
func coalesceInt(primary, secondary, defaultVal int) int {
	if primary != defaultVal {
		return primary
	}
	return secondary
}
// coalesceBool reports whether any of the given values is true.
func coalesceBool(values ...bool) bool {
	result := false
	for _, v := range values {
		result = result || v
	}
	return result
}

View File

@ -1,5 +1,5 @@
// FILE: src/internal/tls/generator.go // FILE: src/cmd/logwisp/commands/tls.go
package tls package commands
import ( import (
"crypto/rand" "crypto/rand"
@ -17,40 +17,50 @@ import (
"time" "time"
) )
type CertGeneratorCommand struct { type TLSCommand struct {
output io.Writer output io.Writer
errOut io.Writer errOut io.Writer
} }
func NewCertGeneratorCommand() *CertGeneratorCommand { func NewTLSCommand() *TLSCommand {
return &CertGeneratorCommand{ return &TLSCommand{
output: os.Stdout, output: os.Stdout,
errOut: os.Stderr, errOut: os.Stderr,
} }
} }
func (cg *CertGeneratorCommand) Execute(args []string) error { func (tc *TLSCommand) Execute(args []string) error {
cmd := flag.NewFlagSet("tls", flag.ContinueOnError) cmd := flag.NewFlagSet("tls", flag.ContinueOnError)
cmd.SetOutput(cg.errOut) cmd.SetOutput(tc.errOut)
// Subcommands // Certificate type flags
var ( var (
genCA = cmd.Bool("ca", false, "Generate CA certificate") genCA = cmd.Bool("ca", false, "Generate CA certificate")
genServer = cmd.Bool("server", false, "Generate server certificate") genServer = cmd.Bool("server", false, "Generate server certificate")
genClient = cmd.Bool("client", false, "Generate client certificate") genClient = cmd.Bool("client", false, "Generate client certificate")
selfSign = cmd.Bool("self-signed", false, "Generate self-signed certificate") selfSign = cmd.Bool("self-signed", false, "Generate self-signed certificate")
// Common options // Common options - short forms
commonName = cmd.String("cn", "", "Common name (required)") commonName = cmd.String("cn", "", "Common name (required)")
org = cmd.String("org", "LogWisp", "Organization") org = cmd.String("o", "LogWisp", "Organization")
country = cmd.String("country", "US", "Country code") country = cmd.String("c", "US", "Country code")
validDays = cmd.Int("days", 365, "Validity period in days") validDays = cmd.Int("d", 365, "Validity period in days")
keySize = cmd.Int("bits", 2048, "RSA key size") keySize = cmd.Int("b", 2048, "RSA key size")
// Server/Client specific // Common options - long forms
hosts = cmd.String("hosts", "", "Comma-separated hostnames/IPs (server cert)") commonNameLong = cmd.String("common-name", "", "Common name (required)")
caFile = cmd.String("ca-cert", "", "CA certificate file (for signing)") orgLong = cmd.String("org", "LogWisp", "Organization")
caKeyFile = cmd.String("ca-key", "", "CA key file (for signing)") countryLong = cmd.String("country", "US", "Country code")
validDaysLong = cmd.Int("days", 365, "Validity period in days")
keySizeLong = cmd.Int("bits", 2048, "RSA key size")
// Server/Client specific - short forms
hosts = cmd.String("h", "", "Comma-separated hostnames/IPs")
caFile = cmd.String("ca-cert", "", "CA certificate file")
caKey = cmd.String("ca-key", "", "CA key file")
// Server/Client specific - long forms
hostsLong = cmd.String("hosts", "", "Comma-separated hostnames/IPs")
// Output files // Output files
certOut = cmd.String("cert-out", "", "Output certificate file") certOut = cmd.String("cert-out", "", "Output certificate file")
@ -58,51 +68,135 @@ func (cg *CertGeneratorCommand) Execute(args []string) error {
) )
cmd.Usage = func() { cmd.Usage = func() {
fmt.Fprintln(cg.errOut, "Generate TLS certificates for LogWisp") fmt.Fprintln(tc.errOut, "Generate TLS certificates for LogWisp")
fmt.Fprintln(cg.errOut, "\nUsage: logwisp tls [options]") fmt.Fprintln(tc.errOut, "\nUsage: logwisp tls [options]")
fmt.Fprintln(cg.errOut, "\nExamples:") fmt.Fprintln(tc.errOut, "\nExamples:")
fmt.Fprintln(cg.errOut, " # Generate self-signed certificate") fmt.Fprintln(tc.errOut, " # Generate self-signed certificate")
fmt.Fprintln(cg.errOut, " logwisp tls --self-signed --cn localhost --hosts localhost,127.0.0.1") fmt.Fprintln(tc.errOut, " logwisp tls --self-signed --cn localhost --hosts localhost,127.0.0.1")
fmt.Fprintln(cg.errOut, " ") fmt.Fprintln(tc.errOut, " ")
fmt.Fprintln(cg.errOut, " # Generate CA certificate") fmt.Fprintln(tc.errOut, " # Generate CA certificate")
fmt.Fprintln(cg.errOut, " logwisp tls --ca --cn \"LogWisp CA\" --cert-out ca.crt --key-out ca.key") fmt.Fprintln(tc.errOut, " logwisp tls --ca --cn \"LogWisp CA\" --cert-out ca.crt --key-out ca.key")
fmt.Fprintln(cg.errOut, " ") fmt.Fprintln(tc.errOut, " ")
fmt.Fprintln(cg.errOut, " # Generate server certificate signed by CA") fmt.Fprintln(tc.errOut, " # Generate server certificate signed by CA")
fmt.Fprintln(cg.errOut, " logwisp tls --server --cn server.example.com --hosts server.example.com \\") fmt.Fprintln(tc.errOut, " logwisp tls --server --cn server.example.com --hosts server.example.com \\")
fmt.Fprintln(cg.errOut, " --ca-cert ca.crt --ca-key ca.key") fmt.Fprintln(tc.errOut, " --ca-cert ca.crt --ca-key ca.key")
fmt.Fprintln(cg.errOut, "\nOptions:") fmt.Fprintln(tc.errOut, "\nOptions:")
cmd.PrintDefaults() cmd.PrintDefaults()
fmt.Fprintln(cg.errOut) fmt.Fprintln(tc.errOut)
} }
if err := cmd.Parse(args); err != nil { if err := cmd.Parse(args); err != nil {
return err return err
} }
// Check for unparsed arguments
if cmd.NArg() > 0 {
return fmt.Errorf("unexpected argument(s): %s", strings.Join(cmd.Args(), " "))
}
// Merge short and long options
finalCN := coalesceString(*commonName, *commonNameLong)
finalOrg := coalesceString(*org, *orgLong, "LogWisp")
finalCountry := coalesceString(*country, *countryLong, "US")
finalDays := coalesceInt(*validDays, *validDaysLong, 365)
finalKeySize := coalesceInt(*keySize, *keySizeLong, 2048)
finalHosts := coalesceString(*hosts, *hostsLong)
finalCAFile := *caFile // no short form
finalCAKey := *caKey // no short form
finalCertOut := *certOut // no short form
finalKeyOut := *keyOut // no short form
// Validate common name // Validate common name
if *commonName == "" { if finalCN == "" {
cmd.Usage() cmd.Usage()
return fmt.Errorf("common name (--cn) is required") return fmt.Errorf("common name (--cn) is required")
} }
// Validate RSA key size
if finalKeySize != 2048 && finalKeySize != 3072 && finalKeySize != 4096 {
return fmt.Errorf("invalid key size: %d (valid: 2048, 3072, 4096)", finalKeySize)
}
// Route to appropriate generator // Route to appropriate generator
switch { switch {
case *genCA: case *genCA:
return cg.generateCA(*commonName, *org, *country, *validDays, *keySize, *certOut, *keyOut) return tc.generateCA(finalCN, finalOrg, finalCountry, finalDays, finalKeySize, finalCertOut, finalKeyOut)
case *selfSign: case *selfSign:
return cg.generateSelfSigned(*commonName, *org, *country, *hosts, *validDays, *keySize, *certOut, *keyOut) return tc.generateSelfSigned(finalCN, finalOrg, finalCountry, finalHosts, finalDays, finalKeySize, finalCertOut, finalKeyOut)
case *genServer: case *genServer:
return cg.generateServerCert(*commonName, *org, *country, *hosts, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut) return tc.generateServerCert(finalCN, finalOrg, finalCountry, finalHosts, finalCAFile, finalCAKey, finalDays, finalKeySize, finalCertOut, finalKeyOut)
case *genClient: case *genClient:
return cg.generateClientCert(*commonName, *org, *country, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut) return tc.generateClientCert(finalCN, finalOrg, finalCountry, finalCAFile, finalCAKey, finalDays, finalKeySize, finalCertOut, finalKeyOut)
default: default:
cmd.Usage() cmd.Usage()
return fmt.Errorf("specify certificate type: --ca, --self-signed, --server, or --client") return fmt.Errorf("specify certificate type: --ca, --self-signed, --server, or --client")
} }
} }
func (tc *TLSCommand) Description() string {
return "Generate TLS certificates (CA, server, client, self-signed)"
}
func (tc *TLSCommand) Help() string {
return `TLS Command - Generate TLS certificates for LogWisp
Usage:
logwisp tls [options]
Certificate Types:
--ca Generate Certificate Authority (CA) certificate
--server Generate server certificate (requires CA or self-signed)
--client Generate client certificate (for mTLS)
--self-signed Generate self-signed certificate (single cert for testing)
Common Options:
--cn, --common-name <name> Common Name (required)
-o, --org <organization> Organization name (default: "LogWisp")
-c, --country <code> Country code (default: "US")
-d, --days <number> Validity period in days (default: 365)
-b, --bits <size> RSA key size (default: 2048)
Server Certificate Options:
-h, --hosts <list> Comma-separated hostnames/IPs
Example: "localhost,10.0.0.1,example.com"
--ca-cert <file> CA certificate file (for signing)
--ca-key <file> CA key file (for signing)
Output Options:
--cert-out <file> Output certificate file (default: stdout)
--key-out <file> Output private key file (default: stdout)
Examples:
# Generate self-signed certificate for testing
logwisp tls --self-signed --cn localhost --hosts "localhost,127.0.0.1" \
--cert-out server.crt --key-out server.key
# Generate CA certificate
logwisp tls --ca --cn "LogWisp CA" --days 3650 \
--cert-out ca.crt --key-out ca.key
# Generate server certificate signed by CA
logwisp tls --server --cn "logwisp.example.com" \
--hosts "logwisp.example.com,10.0.0.100" \
--ca-cert ca.crt --ca-key ca.key \
--cert-out server.crt --key-out server.key
# Generate client certificate for mTLS
logwisp tls --client --cn "client1" \
--ca-cert ca.crt --ca-key ca.key \
--cert-out client.crt --key-out client.key
Security Notes:
- Keep private keys secure and never share them
- Use 2048-bit RSA minimum, 3072 or 4096 for higher security
- For production, use certificates from a trusted CA
- Self-signed certificates are only for development/testing
- Rotate certificates before expiration
`
}
// Create and manage private CA // Create and manage private CA
func (cg *CertGeneratorCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error { func (tc *TLSCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error {
// Generate RSA key // Generate RSA key
priv, err := rsa.GenerateKey(rand.Reader, bits) priv, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil { if err != nil {
@ -178,7 +272,7 @@ func parseHosts(hostList string) ([]string, []net.IP) {
} }
// Generate self-signed certificate // Generate self-signed certificate
func (cg *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error { func (tc *TLSCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error {
// 1. Generate an RSA private key with the specified bit size // 1. Generate an RSA private key with the specified bit size
priv, err := rsa.GenerateKey(rand.Reader, bits) priv, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil { if err != nil {
@ -245,7 +339,7 @@ func (cg *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts strin
} }
// Generate server cert with CA // Generate server cert with CA
func (cg *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { func (tc *TLSCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error {
caCert, caKey, err := loadCA(caFile, caKeyFile) caCert, caKey, err := loadCA(caFile, caKeyFile)
if err != nil { if err != nil {
return err return err
@ -308,7 +402,7 @@ func (cg *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFi
} }
// Generate client cert with CA // Generate client cert with CA
func (cg *CertGeneratorCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { func (tc *TLSCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error {
caCert, caKey, err := loadCA(caFile, caKeyFile) caCert, caKey, err := loadCA(caFile, caKeyFile)
if err != nil { if err != nil {
return err return err

View File

@ -0,0 +1,41 @@
// FILE: src/cmd/logwisp/commands/version.go
package commands
import (
"fmt"
"logwisp/src/internal/version"
)
// VersionCommand handles version display
type VersionCommand struct{}
// NewVersionCommand creates a new version command
func NewVersionCommand() *VersionCommand {
return &VersionCommand{}
}
func (c *VersionCommand) Execute(args []string) error {
fmt.Println(version.String())
return nil
}
func (c *VersionCommand) Description() string {
return "Show version information"
}
func (c *VersionCommand) Help() string {
return `Version Command - Show LogWisp version information
Usage:
logwisp version
logwisp -v
logwisp --version
Output includes:
- Version number
- Build date
- Git commit hash (if available)
- Go version used for compilation
`
}

View File

@ -1,59 +0,0 @@
// FILE: logwisp/src/cmd/logwisp/help.go
package main
import (
"fmt"
"os"
)
const helpText = `LogWisp: A flexible log transport and processing tool.
Usage:
logwisp [command] [options]
logwisp [options]
Commands:
auth Generate authentication credentials
version Display version information
Application Control:
-c, --config <path> Path to configuration file (default: logwisp.toml)
-h, --help Display this help message and exit
-v, --version Display version information and exit
-b, --background Run LogWisp in the background as a daemon
-q, --quiet Suppress all console output, including errors
Runtime Behavior:
--disable-status-reporter Disable the periodic status reporter
--config-auto-reload Enable config reload on file change
For command-specific help:
logwisp <command> --help
Configuration Sources (Precedence: CLI > Env > File > Defaults):
- CLI flags override all other settings
- Environment variables override file settings
- TOML configuration file is the primary method
Examples:
# Generate password for admin user
logwisp auth -u admin
# Start service with custom config
logwisp -c /etc/logwisp/prod.toml
# Run in background
logwisp -b --config-auto-reload
For detailed configuration options, please refer to the documentation.
`
// Scans arguments for help flags and prints help text if found.
func CheckAndDisplayHelp(args []string) {
for _, arg := range args {
if arg == "-h" || arg == "--help" {
fmt.Fprint(os.Stdout, helpText)
os.Exit(0)
}
}
}

View File

@ -11,6 +11,7 @@ import (
"syscall" "syscall"
"time" "time"
"logwisp/src/cmd/logwisp/commands"
"logwisp/src/internal/config" "logwisp/src/internal/config"
"logwisp/src/internal/version" "logwisp/src/internal/version"
@ -22,12 +23,22 @@ var logger *log.Logger
func main() { func main() {
// Handle subcommands before any config loading // Handle subcommands before any config loading
// This prevents flag conflicts with lixenwraith/config // This prevents flag conflicts with lixenwraith/config
router := NewCommandRouter() router := commands.NewCommandRouter()
if router.Route(os.Args) != nil { handled, err := router.Route(os.Args)
// Subcommand was handled, exit already called
return if err != nil {
// Command execution error
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
} }
if handled {
// Command was successfully handled
os.Exit(0)
}
// No subcommand, continue with main application
// Emulates nohup // Emulates nohup
signal.Ignore(syscall.SIGHUP) signal.Ignore(syscall.SIGHUP)
@ -158,8 +169,6 @@ func main() {
select { select {
case <-done: case <-done:
// Save configuration after graceful shutdown (no reload manager in static mode)
saveConfigurationOnExit(cfg, nil, logger)
logger.Info("msg", "Shutdown complete") logger.Info("msg", "Shutdown complete")
case <-shutdownCtx.Done(): case <-shutdownCtx.Done():
logger.Error("msg", "Shutdown timeout exceeded - forcing exit") logger.Error("msg", "Shutdown timeout exceeded - forcing exit")
@ -172,9 +181,6 @@ func main() {
// Wait for context cancellation // Wait for context cancellation
<-ctx.Done() <-ctx.Done()
// Save configuration before final shutdown, handled by reloadManager
saveConfigurationOnExit(cfg, reloadManager, logger)
// Shutdown is handled by ReloadManager.Shutdown() in defer // Shutdown is handled by ReloadManager.Shutdown() in defer
logger.Info("msg", "Shutdown complete") logger.Info("msg", "Shutdown complete")
} }
@ -187,47 +193,3 @@ func shutdownLogger() {
} }
} }
} }
// Saves the configuration to file on exist
func saveConfigurationOnExit(cfg *config.Config, reloadManager *ReloadManager, logger *log.Logger) {
// Only save if explicitly enabled and we have a valid path
if !cfg.ConfigSaveOnExit || cfg.ConfigFile == "" {
return
}
// Create a context with timeout for save operation
saveCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Perform save in goroutine to respect timeout
done := make(chan error, 1)
go func() {
var err error
if reloadManager != nil && reloadManager.lcfg != nil {
// Use existing lconfig instance from reload manager
// This ensures we save through the same configuration system
err = reloadManager.lcfg.Save(cfg.ConfigFile)
} else {
// Static mode: create temporary lconfig for saving
err = cfg.SaveToFile(cfg.ConfigFile)
}
done <- err
}()
select {
case err := <-done:
if err != nil {
logger.Error("msg", "Failed to save configuration on exit",
"path", cfg.ConfigFile,
"error", err)
// Don't fail the exit on save error
} else {
logger.Info("msg", "Configuration saved successfully",
"path", cfg.ConfigFile)
}
case <-saveCtx.Done():
logger.Error("msg", "Configuration save timeout exceeded",
"path", cfg.ConfigFile,
"timeout", "5s")
}
}

View File

@ -338,14 +338,6 @@ func (rm *ReloadManager) stopStatusReporter() {
} }
} }
// Wrapper to save the config
func (rm *ReloadManager) SaveConfig(path string) error {
if rm.lcfg == nil {
return fmt.Errorf("no lconfig instance available")
}
return rm.lcfg.Save(path)
}
// Stops the reload manager // Stops the reload manager
func (rm *ReloadManager) Shutdown() { func (rm *ReloadManager) Shutdown() {
rm.logger.Info("msg", "Shutting down reload manager") rm.logger.Info("msg", "Shutting down reload manager")

View File

@ -114,81 +114,76 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
for i, sinkCfg := range cfg.Sinks { for i, sinkCfg := range cfg.Sinks {
switch sinkCfg.Type { switch sinkCfg.Type {
case "tcp": case "tcp":
if port, ok := sinkCfg.Options["port"].(int64); ok { if sinkCfg.TCP != nil {
host := "0.0.0.0" // Get host or default to 0.0.0.0 host := "0.0.0.0"
if h, ok := sinkCfg.Options["host"].(string); ok && h != "" { if sinkCfg.TCP.Host != "" {
host = h host = sinkCfg.TCP.Host
} }
logger.Info("msg", "TCP endpoint configured", logger.Info("msg", "TCP endpoint configured",
"component", "main", "component", "main",
"pipeline", cfg.Name, "pipeline", cfg.Name,
"sink_index", i, "sink_index", i,
"listen", fmt.Sprintf("%s:%d", host, port)) "listen", fmt.Sprintf("%s:%d", host, sinkCfg.TCP.Port))
// Display net limit info if configured // Display net limit info if configured
if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok { if sinkCfg.TCP.NetLimit != nil && sinkCfg.TCP.NetLimit.Enabled {
if enabled, ok := nl["enabled"].(bool); ok && enabled { logger.Info("msg", "TCP net limiting enabled",
logger.Info("msg", "TCP net limiting enabled", "pipeline", cfg.Name,
"pipeline", cfg.Name, "sink_index", i,
"sink_index", i, "requests_per_second", sinkCfg.TCP.NetLimit.RequestsPerSecond,
"requests_per_second", nl["requests_per_second"], "burst_size", sinkCfg.TCP.NetLimit.BurstSize)
"burst_size", nl["burst_size"])
}
} }
} }
case "http": case "http":
if port, ok := sinkCfg.Options["port"].(int64); ok { if sinkCfg.HTTP != nil {
host := "0.0.0.0" host := "0.0.0.0"
if h, ok := sinkCfg.Options["host"].(string); ok && h != "" { if sinkCfg.HTTP.Host != "" {
host = h host = sinkCfg.HTTP.Host
} }
streamPath := "/stream" streamPath := "/stream"
statusPath := "/status" statusPath := "/status"
if path, ok := sinkCfg.Options["stream_path"].(string); ok { if sinkCfg.HTTP.StreamPath != "" {
streamPath = path streamPath = sinkCfg.HTTP.StreamPath
} }
if path, ok := sinkCfg.Options["status_path"].(string); ok { if sinkCfg.HTTP.StatusPath != "" {
statusPath = path statusPath = sinkCfg.HTTP.StatusPath
} }
logger.Info("msg", "HTTP endpoints configured", logger.Info("msg", "HTTP endpoints configured",
"pipeline", cfg.Name, "pipeline", cfg.Name,
"sink_index", i, "sink_index", i,
"listen", fmt.Sprintf("%s:%d", host, port), "listen", fmt.Sprintf("%s:%d", host, sinkCfg.HTTP.Port),
"stream_url", fmt.Sprintf("http://%s:%d%s", host, port, streamPath), "stream_url", fmt.Sprintf("http://%s:%d%s", host, sinkCfg.HTTP.Port, streamPath),
"status_url", fmt.Sprintf("http://%s:%d%s", host, port, statusPath)) "status_url", fmt.Sprintf("http://%s:%d%s", host, sinkCfg.HTTP.Port, statusPath))
// Display net limit info if configured // Display net limit info if configured
if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok { if sinkCfg.HTTP.NetLimit != nil && sinkCfg.HTTP.NetLimit.Enabled {
if enabled, ok := nl["enabled"].(bool); ok && enabled { logger.Info("msg", "HTTP net limiting enabled",
logger.Info("msg", "HTTP net limiting enabled", "pipeline", cfg.Name,
"pipeline", cfg.Name, "sink_index", i,
"sink_index", i, "requests_per_second", sinkCfg.HTTP.NetLimit.RequestsPerSecond,
"requests_per_second", nl["requests_per_second"], "burst_size", sinkCfg.HTTP.NetLimit.BurstSize)
"burst_size", nl["burst_size"])
}
} }
} }
case "file": case "file":
if dir, ok := sinkCfg.Options["directory"].(string); ok { if sinkCfg.File != nil {
name, _ := sinkCfg.Options["name"].(string)
logger.Info("msg", "File sink configured", logger.Info("msg", "File sink configured",
"pipeline", cfg.Name, "pipeline", cfg.Name,
"sink_index", i, "sink_index", i,
"directory", dir, "directory", sinkCfg.File.Directory,
"name", name) "name", sinkCfg.File.Name)
} }
case "console": case "console":
if target, ok := sinkCfg.Options["target"].(string); ok { if sinkCfg.Console != nil {
logger.Info("msg", "Console sink configured", logger.Info("msg", "Console sink configured",
"pipeline", cfg.Name, "pipeline", cfg.Name,
"sink_index", i, "sink_index", i,
"target", target) "target", sinkCfg.Console.Target)
} }
} }
} }
@ -197,10 +192,10 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
for i, sourceCfg := range cfg.Sources { for i, sourceCfg := range cfg.Sources {
switch sourceCfg.Type { switch sourceCfg.Type {
case "http": case "http":
if port, ok := sourceCfg.Options["port"].(int64); ok { if sourceCfg.HTTP != nil {
host := "0.0.0.0" host := "0.0.0.0"
if h, ok := sourceCfg.Options["host"].(string); ok && h != "" { if sourceCfg.HTTP.Host != "" {
host = h host = sourceCfg.HTTP.Host
} }
displayHost := host displayHost := host
@ -209,22 +204,22 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
} }
ingestPath := "/ingest" ingestPath := "/ingest"
if path, ok := sourceCfg.Options["ingest_path"].(string); ok { if sourceCfg.HTTP.IngestPath != "" {
ingestPath = path ingestPath = sourceCfg.HTTP.IngestPath
} }
logger.Info("msg", "HTTP source configured", logger.Info("msg", "HTTP source configured",
"pipeline", cfg.Name, "pipeline", cfg.Name,
"source_index", i, "source_index", i,
"listen", fmt.Sprintf("%s:%d", host, port), "listen", fmt.Sprintf("%s:%d", host, sourceCfg.HTTP.Port),
"ingest_url", fmt.Sprintf("http://%s:%d%s", displayHost, port, ingestPath)) "ingest_url", fmt.Sprintf("http://%s:%d%s", displayHost, sourceCfg.HTTP.Port, ingestPath))
} }
case "tcp": case "tcp":
if port, ok := sourceCfg.Options["port"].(int64); ok { if sourceCfg.TCP != nil {
host := "0.0.0.0" host := "0.0.0.0"
if h, ok := sourceCfg.Options["host"].(string); ok && h != "" { if sourceCfg.TCP.Host != "" {
host = h host = sourceCfg.TCP.Host
} }
displayHost := host displayHost := host
@ -235,19 +230,24 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
logger.Info("msg", "TCP source configured", logger.Info("msg", "TCP source configured",
"pipeline", cfg.Name, "pipeline", cfg.Name,
"source_index", i, "source_index", i,
"listen", fmt.Sprintf("%s:%d", host, port), "listen", fmt.Sprintf("%s:%d", host, sourceCfg.TCP.Port),
"endpoint", fmt.Sprintf("%s:%d", displayHost, port)) "endpoint", fmt.Sprintf("%s:%d", displayHost, sourceCfg.TCP.Port))
} }
// TODO: missing other types of source, to be added case "directory":
} if sourceCfg.Directory != nil {
} logger.Info("msg", "Directory source configured",
"pipeline", cfg.Name,
"source_index", i,
"path", sourceCfg.Directory.Path,
"pattern", sourceCfg.Directory.Pattern)
}
// Display authentication information case "stdin":
if cfg.Auth != nil && cfg.Auth.Type != "none" { logger.Info("msg", "Stdin source configured",
logger.Info("msg", "Authentication enabled", "pipeline", cfg.Name,
"pipeline", cfg.Name, "source_index", i)
"auth_type", cfg.Auth.Type) }
} }
// Display filter information // Display filter information

View File

@ -5,7 +5,6 @@ import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -13,7 +12,6 @@ import (
"logwisp/src/internal/config" "logwisp/src/internal/config"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
"golang.org/x/time/rate"
) )
// Prevent unbounded map growth // Prevent unbounded map growth
@ -21,66 +19,50 @@ const maxAuthTrackedIPs = 10000
// Handles all authentication methods for a pipeline // Handles all authentication methods for a pipeline
type Authenticator struct { type Authenticator struct {
config *config.AuthConfig config *config.ServerAuthConfig
logger *log.Logger logger *log.Logger
bearerTokens map[string]bool // token -> valid tokens map[string]bool // token -> valid
mu sync.RWMutex mu sync.RWMutex
// Session tracking // Session tracking
sessions map[string]*Session sessions map[string]*Session
sessionMu sync.RWMutex sessionMu sync.RWMutex
// Brute-force protection
ipAuthAttempts map[string]*ipAuthState
authMu sync.RWMutex
}
// Per-IP auth attempt tracking
type ipAuthState struct {
limiter *rate.Limiter
failCount int
lastAttempt time.Time
blockedUntil time.Time
} }
// Represents an authenticated connection // Represents an authenticated connection
type Session struct { type Session struct {
ID string ID string
Username string Username string
Method string // basic, bearer, mtls Method string // basic, token, mtls
RemoteAddr string RemoteAddr string
CreatedAt time.Time CreatedAt time.Time
LastActivity time.Time LastActivity time.Time
} }
// Creates a new authenticator from config // Creates a new authenticator from config
func NewAuthenticator(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) { func NewAuthenticator(cfg *config.ServerAuthConfig, logger *log.Logger) (*Authenticator, error) {
// SCRAM is handled by ScramManager in sources // SCRAM is handled by ScramManager in sources
if cfg == nil || cfg.Type == "none" || cfg.Type == "scram" { if cfg == nil || cfg.Type == "none" || cfg.Type == "scram" {
return nil, nil return nil, nil
} }
a := &Authenticator{ a := &Authenticator{
config: cfg, config: cfg,
logger: logger, logger: logger,
bearerTokens: make(map[string]bool), tokens: make(map[string]bool),
sessions: make(map[string]*Session), sessions: make(map[string]*Session),
ipAuthAttempts: make(map[string]*ipAuthState),
} }
// Initialize Bearer tokens // Initialize tokens
if cfg.Type == "bearer" && cfg.BearerAuth != nil { if cfg.Type == "token" && cfg.Token != nil {
for _, token := range cfg.BearerAuth.Tokens { for _, token := range cfg.Token.Tokens {
a.bearerTokens[token] = true a.tokens[token] = true
} }
} }
// Start session cleanup // Start session cleanup
go a.sessionCleanup() go a.sessionCleanup()
// Start auth attempt cleanup
go a.authAttemptCleanup()
logger.Info("msg", "Authenticator initialized", logger.Info("msg", "Authenticator initialized",
"component", "auth", "component", "auth",
"type", cfg.Type) "type", cfg.Type)
@ -88,129 +70,6 @@ func NewAuthenticator(cfg *config.AuthConfig, logger *log.Logger) (*Authenticato
return a, nil return a, nil
} }
// Check and enforce rate limits
func (a *Authenticator) checkRateLimit(remoteAddr string) error {
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
ip = remoteAddr // Fallback for malformed addresses
}
a.authMu.Lock()
defer a.authMu.Unlock()
state, exists := a.ipAuthAttempts[ip]
now := time.Now()
if !exists {
// Check map size limit before creating new entry
if len(a.ipAuthAttempts) >= maxAuthTrackedIPs {
// Evict an old entry using simplified LRU
// Sample 20 random entries and evict the oldest
const sampleSize = 20
var oldestIP string
oldestTime := now
// Build sample
sampled := 0
for sampledIP, sampledState := range a.ipAuthAttempts {
if sampledState.lastAttempt.Before(oldestTime) {
oldestIP = sampledIP
oldestTime = sampledState.lastAttempt
}
sampled++
if sampled >= sampleSize {
break
}
}
// Evict the oldest from our sample
if oldestIP != "" {
delete(a.ipAuthAttempts, oldestIP)
a.logger.Debug("msg", "Evicted old auth attempt state",
"component", "auth",
"evicted_ip", oldestIP,
"last_seen", oldestTime)
}
}
// Create new state for this IP
// 5 attempts per minute, burst of 3
state = &ipAuthState{
limiter: rate.NewLimiter(rate.Every(12*time.Second), 3),
lastAttempt: now,
}
a.ipAuthAttempts[ip] = state
}
// Check if IP is temporarily blocked
if now.Before(state.blockedUntil) {
remaining := state.blockedUntil.Sub(now)
a.logger.Warn("msg", "IP temporarily blocked",
"component", "auth",
"ip", ip,
"remaining", remaining)
// Sleep to slow down even blocked attempts
time.Sleep(2 * time.Second)
return fmt.Errorf("temporarily blocked, try again in %v", remaining.Round(time.Second))
}
// Check rate limit
if !state.limiter.Allow() {
state.failCount++
// Only set new blockedUntil if not already blocked
// This prevents indefinite block extension
if state.blockedUntil.IsZero() || now.After(state.blockedUntil) {
// Progressive blocking: 2^failCount minutes
blockMinutes := 1 << min(state.failCount, 6) // Cap at 64 minutes
state.blockedUntil = now.Add(time.Duration(blockMinutes) * time.Minute)
a.logger.Warn("msg", "Rate limit exceeded, blocking IP",
"component", "auth",
"ip", ip,
"fail_count", state.failCount,
"block_duration", time.Duration(blockMinutes)*time.Minute)
}
return fmt.Errorf("rate limit exceeded")
}
state.lastAttempt = now
return nil
}
// Record failed attempt
func (a *Authenticator) recordFailure(remoteAddr string) {
ip, _, _ := net.SplitHostPort(remoteAddr)
if ip == "" {
ip = remoteAddr
}
a.authMu.Lock()
defer a.authMu.Unlock()
if state, exists := a.ipAuthAttempts[ip]; exists {
state.failCount++
state.lastAttempt = time.Now()
}
}
// Reset failure count on success
func (a *Authenticator) recordSuccess(remoteAddr string) {
ip, _, _ := net.SplitHostPort(remoteAddr)
if ip == "" {
ip = remoteAddr
}
a.authMu.Lock()
defer a.authMu.Unlock()
if state, exists := a.ipAuthAttempts[ip]; exists {
state.failCount = 0
state.blockedUntil = time.Time{}
}
}
// Handles HTTP authentication headers // Handles HTTP authentication headers
func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) { func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) {
if a == nil || a.config.Type == "none" { if a == nil || a.config.Type == "none" {
@ -222,77 +81,27 @@ func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Sessio
}, nil }, nil
} }
// Check rate limit
if err := a.checkRateLimit(remoteAddr); err != nil {
return nil, err
}
var session *Session var session *Session
var err error var err error
switch a.config.Type { switch a.config.Type {
case "bearer": case "token":
session, err = a.authenticateBearer(authHeader, remoteAddr) session, err = a.authenticateToken(authHeader, remoteAddr)
default: default:
err = fmt.Errorf("unsupported auth type: %s", a.config.Type) err = fmt.Errorf("unsupported auth type: %s", a.config.Type)
} }
if err != nil { if err != nil {
a.recordFailure(remoteAddr)
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
return nil, err return nil, err
} }
a.recordSuccess(remoteAddr)
return session, nil return session, nil
} }
// Handles TCP connection authentication func (a *Authenticator) authenticateToken(authHeader, remoteAddr string) (*Session, error) {
func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string) (*Session, error) { if !strings.HasPrefix(authHeader, "Token") {
if a == nil || a.config.Type == "none" { return nil, fmt.Errorf("invalid token auth header")
return &Session{
ID: generateSessionID(),
Method: "none",
RemoteAddr: remoteAddr,
CreatedAt: time.Now(),
}, nil
}
// Check rate limit first
if err := a.checkRateLimit(remoteAddr); err != nil {
return nil, err
}
var session *Session
var err error
// TCP auth protocol: AUTH <method> <credentials>
switch strings.ToLower(method) {
case "token":
if a.config.Type != "bearer" {
err = fmt.Errorf("token auth not configured")
} else {
session, err = a.validateToken(credentials, remoteAddr)
}
default:
err = fmt.Errorf("unsupported auth method: %s", method)
}
if err != nil {
a.recordFailure(remoteAddr)
// Add delay on failure
time.Sleep(500 * time.Millisecond)
return nil, err
}
a.recordSuccess(remoteAddr)
return session, nil
}
func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) {
if !strings.HasPrefix(authHeader, "Bearer ") {
return nil, fmt.Errorf("invalid bearer auth header")
} }
token := authHeader[7:] token := authHeader[7:]
@ -302,7 +111,7 @@ func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Sess
func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error) { func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error) {
// Check static tokens first // Check static tokens first
a.mu.RLock() a.mu.RLock()
isValid := a.bearerTokens[token] isValid := a.tokens[token]
a.mu.RUnlock() a.mu.RUnlock()
if !isValid { if !isValid {
@ -311,7 +120,7 @@ func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error
session := &Session{ session := &Session{
ID: generateSessionID(), ID: generateSessionID(),
Method: "bearer", Method: "token",
RemoteAddr: remoteAddr, RemoteAddr: remoteAddr,
CreatedAt: time.Now(), CreatedAt: time.Now(),
LastActivity: time.Now(), LastActivity: time.Now(),
@ -352,27 +161,6 @@ func (a *Authenticator) sessionCleanup() {
} }
} }
// authAttemptCleanup periodically evicts stale per-IP auth attempt state.
// Runs forever on its goroutine; entries idle for more than an hour are dropped.
func (a *Authenticator) authAttemptCleanup() {
	t := time.NewTicker(5 * time.Minute)
	defer t.Stop()
	for range t.C {
		a.authMu.Lock()
		// Anything whose last attempt predates this cutoff is stale.
		cutoff := time.Now().Add(-time.Hour)
		for ip, state := range a.ipAuthAttempts {
			if !state.lastAttempt.Before(cutoff) {
				continue
			}
			delete(a.ipAuthAttempts, ip)
			a.logger.Debug("msg", "Cleaned up auth attempt state",
				"component", "auth",
				"ip", ip)
		}
		a.authMu.Unlock()
	}
}
func generateSessionID() string { func generateSessionID() string {
b := make([]byte, 32) b := make([]byte, 32)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(b); err != nil {
@ -418,6 +206,6 @@ func (a *Authenticator) GetStats() map[string]any {
"enabled": true, "enabled": true,
"type": a.config.Type, "type": a.config.Type,
"active_sessions": sessionCount, "active_sessions": sessionCount,
"static_tokens": len(a.bearerTokens), "static_tokens": len(a.tokens),
} }
} }

View File

@ -1,207 +0,0 @@
// FILE: src/internal/auth/generator.go
package auth
import (
"crypto/rand"
"encoding/base64"
"flag"
"fmt"
"io"
"os"
"syscall"
"logwisp/src/internal/scram"
"golang.org/x/crypto/argon2"
"golang.org/x/term"
)
// Argon2id parameters
// Fixed cost parameters used for password hashing (interactive-login class).
const (
	argon2Time = 3 // number of passes (t)
	argon2Memory = 64 * 1024 // 64 MB memory cost (m), expressed in KiB
	argon2Threads = 4 // parallelism lanes (p)
	argon2SaltLen = 16 // random salt length in bytes
	argon2KeyLen = 32 // derived hash length in bytes
)
// AuthGeneratorCommand implements the `logwisp auth` sub-command, which
// generates credential material (basic-auth hashes, SCRAM keys, bearer
// tokens) and prints ready-to-paste TOML snippets.
type AuthGeneratorCommand struct {
	output io.Writer // destination for generated config snippets (stdout by default)
	errOut io.Writer // destination for usage text, prompts, and warnings (stderr by default)
}

// NewAuthGeneratorCommand returns a generator wired to os.Stdout/os.Stderr.
func NewAuthGeneratorCommand() *AuthGeneratorCommand {
	return &AuthGeneratorCommand{
		output: os.Stdout,
		errOut: os.Stderr,
	}
}
// Execute runs the `auth` sub-command against the given CLI arguments.
// It either emits a random bearer token (-t) or derives credentials for a
// username with the selected auth type ("basic" for HTTP, "scram" for TCP),
// writing TOML snippets to ag.output and diagnostics to ag.errOut.
//
// Fix: a `-h`/`--help` request makes FlagSet.Parse print usage and return
// flag.ErrHelp; previously that sentinel was propagated as a failure.
func (ag *AuthGeneratorCommand) Execute(args []string) error {
	cmd := flag.NewFlagSet("auth", flag.ContinueOnError)
	cmd.SetOutput(ag.errOut)
	var (
		username = cmd.String("u", "", "Username")
		password = cmd.String("p", "", "Password (will prompt if not provided)")
		authType = cmd.String("type", "basic", "Auth type: basic (HTTP) or scram (TCP)")
		genToken = cmd.Bool("t", false, "Generate random bearer token")
		tokenLen = cmd.Int("l", 32, "Token length in bytes (min 16, max 512)")
	)
	cmd.Usage = func() {
		fmt.Fprintln(ag.errOut, "Generate authentication credentials for LogWisp")
		fmt.Fprintln(ag.errOut, "\nUsage: logwisp auth [options]")
		fmt.Fprintln(ag.errOut, "\nExamples:")
		fmt.Fprintln(ag.errOut, " # Generate basic auth hash for HTTP sources/sinks")
		fmt.Fprintln(ag.errOut, " logwisp auth -u admin -type basic")
		fmt.Fprintln(ag.errOut, " ")
		fmt.Fprintln(ag.errOut, " # Generate SCRAM credentials for TCP sources/sinks")
		fmt.Fprintln(ag.errOut, " logwisp auth -u admin -type scram")
		fmt.Fprintln(ag.errOut, " ")
		fmt.Fprintln(ag.errOut, " # Generate 64-byte bearer token")
		fmt.Fprintln(ag.errOut, " logwisp auth -t -l 64")
		fmt.Fprintln(ag.errOut, "\nOptions:")
		cmd.PrintDefaults()
		fmt.Fprintln(ag.errOut)
	}
	if err := cmd.Parse(args); err != nil {
		if err == flag.ErrHelp {
			// Usage was already printed; help is not an error.
			return nil
		}
		return err
	}
	// Token generation is standalone; username/type flags are ignored.
	if *genToken {
		return ag.generateToken(*tokenLen)
	}
	if *username == "" {
		cmd.Usage()
		return fmt.Errorf("username required for credential generation")
	}
	switch *authType {
	case "basic":
		return ag.generateBasicAuth(*username, *password)
	case "scram":
		return ag.generateScramAuth(*username, *password)
	default:
		return fmt.Errorf("invalid auth type: %s (use 'basic' or 'scram')", *authType)
	}
}
// generateBasicAuth derives an Argon2id hash for an HTTP basic-auth user and
// prints it as a [pipelines.auth] TOML snippet in PHC string format
// ($argon2id$v=...$m=...,t=...,p=...$salt$hash). Prompts twice for the
// password when -p was not supplied.
func (ag *AuthGeneratorCommand) generateBasicAuth(username, password string) error {
	if password == "" {
		first := ag.promptPassword("Enter password: ")
		second := ag.promptPassword("Confirm password: ")
		if first != second {
			return fmt.Errorf("passwords don't match")
		}
		password = first
	}
	// Fresh random salt per credential.
	salt := make([]byte, argon2SaltLen)
	if _, err := rand.Read(salt); err != nil {
		return fmt.Errorf("failed to generate salt: %w", err)
	}
	digest := argon2.IDKey([]byte(password), salt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)
	// PHC format embeds the parameters so verifiers can re-derive the hash.
	phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
		argon2.Version, argon2Memory, argon2Time, argon2Threads,
		base64.RawStdEncoding.EncodeToString(salt),
		base64.RawStdEncoding.EncodeToString(digest))
	// Emit the static portion of the snippet in one write.
	fmt.Fprint(ag.output, "\n# Basic Auth Configuration (HTTP sources/sinks)\n"+
		"# REQUIRES HTTPS/TLS for security\n"+
		"# Add to logwisp.toml under [[pipelines]]:\n"+
		"\n"+
		"[pipelines.auth]\n"+
		`type = "basic"`+"\n"+
		"\n"+
		"[[pipelines.auth.basic_auth.users]]\n")
	fmt.Fprintf(ag.output, "username = %q\n", username)
	fmt.Fprintf(ag.output, "password_hash = %q\n\n", phcHash)
	return nil
}
// generateScramAuth derives server-side SCRAM credentials (stored key,
// server key, salt) for a TCP source/sink user and prints them as a TOML
// snippet. Prompts twice for the password when -p was not supplied.
func (ag *AuthGeneratorCommand) generateScramAuth(username, password string) error {
	// Get password if not provided
	if password == "" {
		pass1 := ag.promptPassword("Enter password: ")
		pass2 := ag.promptPassword("Confirm password: ")
		if pass1 != pass2 {
			return fmt.Errorf("passwords don't match")
		}
		password = pass1
	}
	// Generate salt
	salt := make([]byte, 16)
	if _, err := rand.Read(salt); err != nil {
		return fmt.Errorf("failed to generate salt: %w", err)
	}
	// Derive SCRAM credential
	// Argon2 params (3, 65536, 4) match the argon2Time/argon2Memory/argon2Threads
	// constants above; they are emitted with the credential so the server can
	// re-derive keys with identical cost settings.
	cred, err := scram.DeriveCredential(username, password, salt, 3, 65536, 4)
	if err != nil {
		return fmt.Errorf("failed to derive SCRAM credential: %w", err)
	}
	// Output SCRAM configuration (keys and salt are standard base64)
	fmt.Fprintln(ag.output, "\n# SCRAM Auth Configuration (for TCP sources/sinks)")
	fmt.Fprintln(ag.output, "# Add to logwisp.toml:")
	fmt.Fprintln(ag.output, "[[pipelines.auth.scram_auth.users]]")
	fmt.Fprintf(ag.output, "username = %q\n", username)
	fmt.Fprintf(ag.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey))
	fmt.Fprintf(ag.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey))
	fmt.Fprintf(ag.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt))
	fmt.Fprintf(ag.output, "argon_time = %d\n", cred.ArgonTime)
	fmt.Fprintf(ag.output, "argon_memory = %d\n", cred.ArgonMemory)
	fmt.Fprintf(ag.output, "argon_threads = %d\n\n", cred.ArgonThreads)
	return nil
}
// generateToken emits a cryptographically random bearer token of the given
// byte length, printed both as unpadded URL-safe base64 (the form used in
// the `tokens` config key) and as hex.
//
// Lengths below 16 bytes are permitted but warned about; lengths above 512
// bytes are rejected.
func (ag *AuthGeneratorCommand) generateToken(length int) error {
	if length < 16 {
		fmt.Fprintln(ag.errOut, "Warning: tokens < 16 bytes are cryptographically weak")
	}
	if length > 512 {
		return fmt.Errorf("token length exceeds maximum (512 bytes)")
	}
	token := make([]byte, length)
	if _, err := rand.Read(token); err != nil {
		return fmt.Errorf("failed to generate random bytes: %w", err)
	}
	// RawURLEncoding is exactly URLEncoding.WithPadding(NoPadding); use the
	// stdlib's predefined encoding instead of rebuilding it on every call.
	b64 := base64.RawURLEncoding.EncodeToString(token)
	hexStr := fmt.Sprintf("%x", token)
	fmt.Fprintln(ag.output, "\n# Bearer Token Configuration")
	fmt.Fprintln(ag.output, "# Add to logwisp.toml:")
	fmt.Fprintf(ag.output, "tokens = [%q]\n\n", b64)
	fmt.Fprintln(ag.output, "# Generated Token:")
	fmt.Fprintf(ag.output, "Base64: %s\n", b64)
	fmt.Fprintf(ag.output, "Hex: %s\n", hexStr)
	return nil
}
// promptPassword prints prompt to errOut and reads a password from the
// terminal with echo disabled, returning it as a string.
//
// NOTE(review): on read failure this calls os.Exit(1) directly, bypassing
// any deferred cleanup in callers — consider returning an error instead.
// NOTE(review): term.ReadPassword takes an int fd; syscall.Stdin is an int
// on Unix but a Handle on Windows — confirm if Windows builds are targeted.
func (ag *AuthGeneratorCommand) promptPassword(prompt string) string {
	fmt.Fprint(ag.errOut, prompt)
	password, err := term.ReadPassword(syscall.Stdin)
	fmt.Fprintln(ag.errOut) // terminate the prompt line (echo was off)
	if err != nil {
		fmt.Fprintf(ag.errOut, "Failed to read password: %v\n", err)
		os.Exit(1)
	}
	return string(password)
}

View File

@ -1,5 +1,5 @@
// FILE: src/internal/scram/client.go // FILE: src/internal/auth/scram_client.go
package scram package auth
import ( import (
"crypto/rand" "crypto/rand"
@ -12,7 +12,7 @@ import (
) )
// Client handles SCRAM client-side authentication // Client handles SCRAM client-side authentication
type Client struct { type ScramClient struct {
Username string Username string
Password string Password string
@ -23,16 +23,16 @@ type Client struct {
serverKey []byte serverKey []byte
} }
// NewClient creates SCRAM client // NewScramClient creates SCRAM client
func NewClient(username, password string) *Client { func NewScramClient(username, password string) *ScramClient {
return &Client{ return &ScramClient{
Username: username, Username: username,
Password: password, Password: password,
} }
} }
// StartAuthentication generates ClientFirst message // StartAuthentication generates ClientFirst message
func (c *Client) StartAuthentication() (*ClientFirst, error) { func (c *ScramClient) StartAuthentication() (*ClientFirst, error) {
// Generate client nonce // Generate client nonce
nonce := make([]byte, 32) nonce := make([]byte, 32)
if _, err := rand.Read(nonce); err != nil { if _, err := rand.Read(nonce); err != nil {
@ -47,7 +47,7 @@ func (c *Client) StartAuthentication() (*ClientFirst, error) {
} }
// ProcessServerFirst handles server challenge // ProcessServerFirst handles server challenge
func (c *Client) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) { func (c *ScramClient) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) {
c.serverFirst = msg c.serverFirst = msg
// Decode salt // Decode salt
@ -83,7 +83,7 @@ func (c *Client) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) {
} }
// VerifyServerFinal validates server signature // VerifyServerFinal validates server signature
func (c *Client) VerifyServerFinal(msg *ServerFinal) error { func (c *ScramClient) VerifyServerFinal(msg *ServerFinal) error {
if c.authMessage == "" || c.serverKey == nil { if c.authMessage == "" || c.serverKey == nil {
return fmt.Errorf("invalid handshake state") return fmt.Errorf("invalid handshake state")
} }

View File

@ -1,5 +1,5 @@
// FILE: src/internal/scram/credential.go // FILE: src/internal/auth/scram_credential.go
package scram package auth
import ( import (
"crypto/hmac" "crypto/hmac"
@ -9,6 +9,8 @@ import (
"fmt" "fmt"
"strings" "strings"
"logwisp/src/internal/core"
"golang.org/x/crypto/argon2" "golang.org/x/crypto/argon2"
) )
@ -31,7 +33,13 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
} }
// Derive salted password using Argon2id // Derive salted password using Argon2id
saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, 32) saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, core.Argon2KeyLen)
// Construct PHC format for basic auth compatibility
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
hashB64 := base64.RawStdEncoding.EncodeToString(saltedPassword)
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
argon2.Version, memory, time, threads, saltB64, hashB64)
// Derive keys // Derive keys
clientKey := computeHMAC(saltedPassword, []byte("Client Key")) clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
@ -46,6 +54,7 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
ArgonThreads: threads, ArgonThreads: threads,
StoredKey: storedKey[:], StoredKey: storedKey[:],
ServerKey: serverKey, ServerKey: serverKey,
PHCHash: phcHash,
}, nil }, nil
} }

View File

@ -0,0 +1,83 @@
// FILE: src/internal/auth/scram_manager.go
package auth
import (
"crypto/rand"
"encoding/base64"
"fmt"
"logwisp/src/internal/config"
)
// ScramManager provides high-level SCRAM operations on top of a ScramServer.
// NOTE(review): the original comment mentions rate limiting, but no rate
// limiting is visible in this type's methods — confirm or drop the claim.
type ScramManager struct {
	server *ScramServer // underlying handshake state machine holding credentials
}
// NewScramManager creates a SCRAM manager and loads every user from the
// SCRAM config section. Users with an invalid base64 stored key, server
// key, or salt are skipped (the manager is still returned).
func NewScramManager(scramAuthCfg *config.ScramAuthConfig) *ScramManager {
	manager := &ScramManager{
		server: NewScramServer(),
	}
	// decodeAll base64-decodes each field, reporting whether all succeeded.
	// Collapses the previous three copy-pasted decode-or-skip stanzas.
	decodeAll := func(fields ...string) ([][]byte, bool) {
		out := make([][]byte, 0, len(fields))
		for _, f := range fields {
			b, err := base64.StdEncoding.DecodeString(f)
			if err != nil {
				return nil, false
			}
			out = append(out, b)
		}
		return out, true
	}
	for _, user := range scramAuthCfg.Users {
		decoded, ok := decodeAll(user.StoredKey, user.ServerKey, user.Salt)
		if !ok {
			// Skip user with any invalid base64 field
			continue
		}
		manager.server.AddCredential(&Credential{
			Username:     user.Username,
			StoredKey:    decoded[0],
			ServerKey:    decoded[1],
			Salt:         decoded[2],
			ArgonTime:    user.ArgonTime,
			ArgonMemory:  user.ArgonMemory,
			ArgonThreads: user.ArgonThreads,
		})
	}
	return manager
}
// RegisterUser derives a SCRAM credential for username from password, using
// a fresh 16-byte random salt and the server's default Argon2 parameters,
// then stores it on the server.
func (sm *ScramManager) RegisterUser(username, password string) error {
	srv := sm.server
	salt := make([]byte, 16)
	if _, err := rand.Read(salt); err != nil {
		return fmt.Errorf("salt generation failed: %w", err)
	}
	cred, err := DeriveCredential(username, password, salt,
		srv.DefaultTime, srv.DefaultMemory, srv.DefaultThreads)
	if err != nil {
		return err
	}
	srv.AddCredential(cred)
	return nil
}
// HandleClientFirst wraps server's HandleClientFirst (step 1 of the
// handshake: client identity + nonce in, server challenge out).
func (sm *ScramManager) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
	return sm.server.HandleClientFirst(msg)
}
// HandleClientFinal wraps server's HandleClientFinal (step 2 of the
// handshake: client proof in, server signature out on success).
func (sm *ScramManager) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
	return sm.server.HandleClientFinal(msg)
}

View File

@ -0,0 +1,38 @@
// FILE: src/internal/auth/scram_message.go
package auth
import (
"fmt"
)
// ClientFirst initiates authentication
// Wire format: JSON with single-letter keys to keep messages compact.
type ClientFirst struct {
	Username string `json:"u"`
	ClientNonce string `json:"n"` // client-chosen random nonce
}
// ServerFirst contains server challenge
// Carries the Argon2 cost parameters so the client can re-derive keys.
type ServerFirst struct {
	FullNonce string `json:"r"` // client_nonce + server_nonce
	Salt string `json:"s"` // base64
	ArgonTime uint32 `json:"t"`
	ArgonMemory uint32 `json:"m"`
	ArgonThreads uint8 `json:"p"`
}
// ClientFinal contains client proof
type ClientFinal struct {
	FullNonce string `json:"r"` // must echo the full nonce from ServerFirst
	ClientProof string `json:"p"` // base64
}
// ServerFinal contains server signature for mutual auth
type ServerFinal struct {
	ServerSignature string `json:"v"` // base64
	SessionID string `json:"sid,omitempty"` // issued only on success
}
// Marshal renders ServerFirst as a comma-separated key=value line.
// NOTE(review): the TCP protocol handler in this package sends JSON via
// json.Marshal, not this text form — confirm whether Marshal has callers
// or is dead code.
func (sf *ServerFirst) Marshal() string {
	return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d",
		sf.FullNonce, sf.Salt, sf.ArgonTime, sf.ArgonMemory, sf.ArgonThreads)
}

View File

@ -0,0 +1,117 @@
// FILE: src/internal/auth/scram_protocol.go
package auth
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/lixenwraith/log"
"github.com/panjf2000/gnet/v2"
)
// ScramProtocolHandler handles SCRAM message exchange for TCP
// Messages are newline-delimited lines of the form "<COMMAND> <json>".
type ScramProtocolHandler struct {
	manager *ScramManager // performs the actual handshake steps
	logger *log.Logger // NOTE(review): unused in the methods visible here
}
// NewScramProtocolHandler creates protocol handler
func NewScramProtocolHandler(manager *ScramManager, logger *log.Logger) *ScramProtocolHandler {
	return &ScramProtocolHandler{
		manager: manager,
		logger:  logger,
	}
}
// HandleAuthMessage processes a complete auth line from buffer.
// The line is "<COMMAND> <json>"; SCRAM-FIRST yields a challenge (not yet
// authenticated), SCRAM-PROOF yields a session on success. Responses are
// written back on conn asynchronously in all paths.
//
// NOTE(review): strings.Fields splits on any whitespace, so a JSON payload
// containing a space (e.g. inside a username) would be truncated at parts[1].
// json.Marshal output is compact, so this holds only for space-free values.
func (sph *ScramProtocolHandler) HandleAuthMessage(line []byte, conn gnet.Conn) (authenticated bool, session *Session, err error) {
	// Parse SCRAM messages
	parts := strings.Fields(string(line))
	if len(parts) < 2 {
		conn.AsyncWrite([]byte("SCRAM-FAIL Invalid message format\n"), nil)
		return false, nil, fmt.Errorf("invalid message format")
	}
	switch parts[0] {
	case "SCRAM-FIRST":
		// Parse ClientFirst JSON
		var clientFirst ClientFirst
		if err := json.Unmarshal([]byte(parts[1]), &clientFirst); err != nil {
			conn.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
			return false, nil, fmt.Errorf("invalid JSON")
		}
		// Process with SCRAM server
		serverFirst, err := sph.manager.HandleClientFirst(&clientFirst)
		if err != nil {
			// Still send challenge to prevent user enumeration
			// NOTE(review): if serverFirst is nil on error, this marshals to
			// "SCRAM-CHALLENGE null", which is distinguishable from a real
			// challenge and defeats the anti-enumeration intent — confirm
			// HandleClientFirst returns a decoy challenge alongside the error.
			response, _ := json.Marshal(serverFirst)
			conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
			return false, nil, err
		}
		// Send ServerFirst challenge
		response, _ := json.Marshal(serverFirst)
		conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
		return false, nil, nil // Not authenticated yet
	case "SCRAM-PROOF":
		// Parse ClientFinal JSON
		var clientFinal ClientFinal
		if err := json.Unmarshal([]byte(parts[1]), &clientFinal); err != nil {
			conn.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
			return false, nil, fmt.Errorf("invalid JSON")
		}
		// Verify proof
		serverFinal, err := sph.manager.HandleClientFinal(&clientFinal)
		if err != nil {
			conn.AsyncWrite([]byte("SCRAM-FAIL Authentication failed\n"), nil)
			return false, nil, err
		}
		// Authentication successful
		// NOTE(review): LastActivity is not set here, unlike sessions created
		// by the token path elsewhere — confirm whether session-expiry logic
		// depends on it.
		session = &Session{
			ID:         serverFinal.SessionID,
			Method:     "scram-sha-256",
			RemoteAddr: conn.RemoteAddr().String(),
			CreatedAt:  time.Now(),
		}
		// Send ServerFinal with signature
		response, _ := json.Marshal(serverFinal)
		conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-OK %s\n", response)), nil)
		return true, session, nil
	default:
		conn.AsyncWrite([]byte("SCRAM-FAIL Unknown command\n"), nil)
		return false, nil, fmt.Errorf("unknown command: %s", parts[0])
	}
}
// FormatSCRAMRequest serializes data as compact JSON and prepends the SCRAM
// command word, producing one newline-terminated protocol line for TCP.
func FormatSCRAMRequest(command string, data interface{}) (string, error) {
	payload, err := json.Marshal(data)
	if err != nil {
		return "", fmt.Errorf("failed to marshal %s: %w", command, err)
	}
	return command + " " + string(payload) + "\n", nil
}
// ParseSCRAMResponse splits a SCRAM protocol response line into its command
// word and optional data payload (everything after the first space).
//
// Fix: strings.SplitN always returns at least one element, so the old
// `len(parts) < 1` guard was unreachable and an empty/blank response came
// back as ("", "", nil) instead of an error. Blank input now errors.
func ParseSCRAMResponse(response string) (command string, data string, err error) {
	response = strings.TrimSpace(response)
	if response == "" {
		return "", "", fmt.Errorf("empty response")
	}
	// Cut splits on the first space only; data stays empty when absent.
	command, data, _ = strings.Cut(response, " ")
	return command, data, nil
}

View File

@ -1,5 +1,5 @@
// FILE: src/internal/scram/server.go // FILE: src/internal/auth/scram_server.go
package scram package auth
import ( import (
"crypto/rand" "crypto/rand"
@ -9,14 +9,17 @@ import (
"fmt" "fmt"
"sync" "sync"
"time" "time"
"logwisp/src/internal/core"
) )
// Server handles SCRAM authentication // Server handles SCRAM authentication
type Server struct { type ScramServer struct {
credentials map[string]*Credential credentials map[string]*Credential
handshakes map[string]*HandshakeState handshakes map[string]*HandshakeState
mu sync.RWMutex mu sync.RWMutex
// TODO: configurability useful? to be included in config or refactor to use core.const directly for simplicity
// Default Argon2 params for new registrations // Default Argon2 params for new registrations
DefaultTime uint32 DefaultTime uint32
DefaultMemory uint32 DefaultMemory uint32
@ -29,32 +32,30 @@ type HandshakeState struct {
ClientNonce string ClientNonce string
ServerNonce string ServerNonce string
FullNonce string FullNonce string
AuthMessage string
Credential *Credential Credential *Credential
CreatedAt time.Time CreatedAt time.Time
ClientProof []byte
} }
// NewServer creates SCRAM server // NewScramServer creates SCRAM server
func NewServer() *Server { func NewScramServer() *ScramServer {
return &Server{ return &ScramServer{
credentials: make(map[string]*Credential), credentials: make(map[string]*Credential),
handshakes: make(map[string]*HandshakeState), handshakes: make(map[string]*HandshakeState),
DefaultTime: 3, DefaultTime: core.Argon2Time,
DefaultMemory: 64 * 1024, DefaultMemory: core.Argon2Memory,
DefaultThreads: 4, DefaultThreads: core.Argon2Threads,
} }
} }
// AddCredential registers user credential // AddCredential registers user credential
func (s *Server) AddCredential(cred *Credential) { func (s *ScramServer) AddCredential(cred *Credential) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.credentials[cred.Username] = cred s.credentials[cred.Username] = cred
} }
// HandleClientFirst processes initial auth request // HandleClientFirst processes initial auth request
func (s *Server) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) { func (s *ScramServer) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -103,7 +104,7 @@ func (s *Server) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
} }
// HandleClientFinal verifies client proof // HandleClientFinal verifies client proof
func (s *Server) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) { func (s *ScramServer) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -157,7 +158,7 @@ func (s *Server) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
}, nil }, nil
} }
func (s *Server) cleanupHandshakes() { func (s *ScramServer) cleanupHandshakes() {
cutoff := time.Now().Add(-60 * time.Second) cutoff := time.Now().Add(-60 * time.Second)
for nonce, state := range s.handshakes { for nonce, state := range s.handshakes {
if state.CreatedAt.Before(cutoff) { if state.CreatedAt.Before(cutoff) {
@ -171,9 +172,3 @@ func generateNonce() string {
rand.Read(b) rand.Read(b)
return base64.StdEncoding.EncodeToString(b) return base64.StdEncoding.EncodeToString(b)
} }
func generateSessionID() string {
b := make([]byte, 24)
rand.Read(b)
return base64.URLEncoding.EncodeToString(b)
}

View File

@ -1,81 +0,0 @@
// FILE: logwisp/src/internal/config/auth.go
package config
import (
"fmt"
)
// AuthConfig selects and configures one authentication scheme; only the
// sub-config matching Type is consulted.
type AuthConfig struct {
	// Authentication type: "none", "basic", "scram", "bearer", "mtls"
	Type string `toml:"type"`
	BasicAuth *BasicAuthConfig `toml:"basic_auth"`
	ScramAuth *ScramAuthConfig `toml:"scram_auth"`
	BearerAuth *BearerAuthConfig `toml:"bearer_auth"`
}
// BasicAuthConfig holds HTTP basic-auth users and the realm string.
type BasicAuthConfig struct {
	Users []BasicAuthUser `toml:"users"`
	Realm string `toml:"realm"`
}
// BasicAuthUser pairs a username with its Argon2 PHC-format password hash.
type BasicAuthUser struct {
	Username string `toml:"username"`
	PasswordHash string `toml:"password_hash"` // Argon2
}
// ScramAuthConfig holds SCRAM users for TCP sources/sinks.
type ScramAuthConfig struct {
	Users []ScramUser `toml:"users"`
}
// ScramUser carries the derived SCRAM key material plus the Argon2 cost
// parameters used to derive it, as emitted by `logwisp auth -type scram`.
type ScramUser struct {
	Username string `toml:"username"`
	StoredKey string `toml:"stored_key"` // base64
	ServerKey string `toml:"server_key"` // base64
	Salt string `toml:"salt"` // base64
	ArgonTime uint32 `toml:"argon_time"`
	ArgonMemory uint32 `toml:"argon_memory"`
	ArgonThreads uint8 `toml:"argon_threads"`
}
// BearerAuthConfig lists accepted static bearer tokens.
type BearerAuthConfig struct {
	// Static tokens
	Tokens []string `toml:"tokens"`
	// TODO: Maybe future development
	// // JWT validation
	// JWT *JWTConfig `toml:"jwt"`
}
// TODO: Maybe future development
// type JWTConfig struct {
// JWKSURL string `toml:"jwks_url"`
// SigningKey string `toml:"signing_key"`
// Issuer string `toml:"issuer"`
// Audience string `toml:"audience"`
// }
// validateAuth checks a pipeline's auth section. A nil section is valid
// (auth disabled). Types "basic", "scram", and "bearer" require their
// matching sub-config; "none" and "mtls" carry none of their own.
func validateAuth(pipelineName string, auth *AuthConfig) error {
	if auth == nil {
		return nil
	}
	// Membership test over a fixed set of constants: a switch avoids
	// allocating a throwaway map on every call.
	switch auth.Type {
	case "none", "basic", "scram", "bearer", "mtls":
		// recognized type
	default:
		return fmt.Errorf("pipeline '%s': invalid auth type: %s", pipelineName, auth.Type)
	}
	if auth.Type == "basic" && auth.BasicAuth == nil {
		return fmt.Errorf("pipeline '%s': basic auth type specified but config missing", pipelineName)
	}
	if auth.Type == "scram" && auth.ScramAuth == nil {
		return fmt.Errorf("pipeline '%s': scram auth type specified but config missing", pipelineName)
	}
	if auth.Type == "bearer" && auth.BearerAuth == nil {
		return fmt.Errorf("pipeline '%s': bearer auth type specified but config missing", pipelineName)
	}
	return nil
}

View File

@ -1,6 +1,8 @@
// FILE: logwisp/src/internal/config/config.go // FILE: logwisp/src/internal/config/config.go
package config package config
// --- LogWisp Configuration Options ---
type Config struct { type Config struct {
// Top-level flags for application control // Top-level flags for application control
Background bool `toml:"background"` Background bool `toml:"background"`
@ -10,7 +12,6 @@ type Config struct {
// Runtime behavior flags // Runtime behavior flags
DisableStatusReporter bool `toml:"disable_status_reporter"` DisableStatusReporter bool `toml:"disable_status_reporter"`
ConfigAutoReload bool `toml:"config_auto_reload"` ConfigAutoReload bool `toml:"config_auto_reload"`
ConfigSaveOnExit bool `toml:"config_save_on_exit"`
// Internal flag indicating demonized child process // Internal flag indicating demonized child process
BackgroundDaemon bool `toml:"background-daemon"` BackgroundDaemon bool `toml:"background-daemon"`
@ -22,3 +23,364 @@ type Config struct {
Logging *LogConfig `toml:"logging"` Logging *LogConfig `toml:"logging"`
Pipelines []PipelineConfig `toml:"pipelines"` Pipelines []PipelineConfig `toml:"pipelines"`
} }
// --- Logging Options ---
// Represents logging configuration for LogWisp
type LogConfig struct {
	// Output mode: "file", "stdout", "stderr", "split", "all", "none"
	Output string `toml:"output"`
	// Log level: "debug", "info", "warn", "error"
	Level string `toml:"level"`
	// File output settings (when Output includes "file" or "all")
	File *LogFileConfig `toml:"file"`
	// Console output settings
	Console *LogConsoleConfig `toml:"console"`
}
// LogFileConfig controls rotation and retention of LogWisp's own log files.
type LogFileConfig struct {
	// Directory for log files
	Directory string `toml:"directory"`
	// Base name for log files
	Name string `toml:"name"`
	// Maximum size per log file in MB
	MaxSizeMB int64 `toml:"max_size_mb"`
	// Maximum total size of all logs in MB
	MaxTotalSizeMB int64 `toml:"max_total_size_mb"`
	// Log retention in hours (0 = disabled)
	RetentionHours float64 `toml:"retention_hours"`
}
// LogConsoleConfig controls console output routing and encoding.
type LogConsoleConfig struct {
	// Target for console output: "stdout", "stderr", "split"
	// "split": info/debug to stdout, warn/error to stderr
	Target string `toml:"target"`
	// Format: "txt" or "json"
	Format string `toml:"format"`
}
// --- Pipeline Options ---
// PipelineConfig describes one log pipeline: where entries come from
// (Sources), the transformations applied (RateLimit, Filters, Format),
// and where they go (Sinks).
type PipelineConfig struct {
	Name string `toml:"name"`
	Sources []SourceConfig `toml:"sources"`
	RateLimit *RateLimitConfig `toml:"rate_limit"`
	Filters []FilterConfig `toml:"filters"`
	Format *FormatConfig `toml:"format"`
	Sinks []SinkConfig `toml:"sinks"`
	// Auth *ServerAuthConfig `toml:"auth"` // Global auth for pipeline
}
// Common configuration structs used across components
// NetLimitConfig bounds network usage: connection counts (global, per-IP,
// per-user, per-token), request rate, and IP allow/deny lists.
type NetLimitConfig struct {
	Enabled bool `toml:"enabled"`
	MaxConnections int64 `toml:"max_connections"`
	RequestsPerSecond float64 `toml:"requests_per_second"`
	BurstSize int64 `toml:"burst_size"`
	ResponseMessage string `toml:"response_message"`
	ResponseCode int64 `toml:"response_code"` // Default: 429
	MaxConnectionsPerIP int64 `toml:"max_connections_per_ip"`
	MaxConnectionsPerUser int64 `toml:"max_connections_per_user"`
	MaxConnectionsPerToken int64 `toml:"max_connections_per_token"`
	MaxConnectionsTotal int64 `toml:"max_connections_total"`
	IPWhitelist []string `toml:"ip_whitelist"`
	IPBlacklist []string `toml:"ip_blacklist"`
}
// TLSConfig configures TLS for servers and clients, including optional
// mutual-TLS client certificate verification.
type TLSConfig struct {
	Enabled bool `toml:"enabled"`
	CertFile string `toml:"cert_file"`
	KeyFile string `toml:"key_file"`
	CAFile string `toml:"ca_file"`
	ServerName string `toml:"server_name"` // for client verification
	SkipVerify bool `toml:"skip_verify"`
	// Client certificate authentication
	ClientAuth bool `toml:"client_auth"`
	ClientCAFile string `toml:"client_ca_file"`
	VerifyClientCert bool `toml:"verify_client_cert"`
	// TLS version constraints
	MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3"
	MaxVersion string `toml:"max_version"`
	// Cipher suites (comma-separated list)
	CipherSuites string `toml:"cipher_suites"`
}
// HeartbeatConfig controls periodic keep-alive messages on streaming sinks.
type HeartbeatConfig struct {
	Enabled bool `toml:"enabled"`
	Interval int64 `toml:"interval_ms"`
	IncludeTimestamp bool `toml:"include_timestamp"`
	IncludeStats bool `toml:"include_stats"`
	Format string `toml:"format"`
}
// ClientAuthConfig holds credentials a client-side sink presents to a
// remote server; which fields apply depends on Type.
type ClientAuthConfig struct {
	Type string `toml:"type"` // "none", "basic", "token", "scram"
	Username string `toml:"username"`
	Password string `toml:"password"`
	Token string `toml:"token"`
}
// --- Source Options ---
// SourceConfig selects one input type; exactly one options pointer is
// expected to be populated, matching Type.
type SourceConfig struct {
	Type string `toml:"type"`
	// Polymorphic - only one populated based on type
	Directory *DirectorySourceOptions `toml:"directory,omitempty"`
	Stdin *StdinSourceOptions `toml:"stdin,omitempty"`
	HTTP *HTTPSourceOptions `toml:"http,omitempty"`
	TCP *TCPSourceOptions `toml:"tcp,omitempty"`
}
// DirectorySourceOptions watches a directory for log files matching Pattern.
type DirectorySourceOptions struct {
	Path string `toml:"path"`
	Pattern string `toml:"pattern"` // glob pattern
	CheckIntervalMS int64 `toml:"check_interval_ms"`
	Recursive bool `toml:"recursive"`
	FollowSymlinks bool `toml:"follow_symlinks"`
	DeleteAfterRead bool `toml:"delete_after_read"`
	MoveToDirectory string `toml:"move_to_directory"` // move after processing
}
// StdinSourceOptions reads log entries from standard input.
type StdinSourceOptions struct {
	BufferSize int64 `toml:"buffer_size"`
}
// HTTPSourceOptions runs an HTTP ingest server accepting pushed entries.
type HTTPSourceOptions struct {
	Host string `toml:"host"`
	Port int64 `toml:"port"`
	IngestPath string `toml:"ingest_path"`
	BufferSize int64 `toml:"buffer_size"`
	MaxRequestBodySize int64 `toml:"max_body_size"`
	ReadTimeout int64 `toml:"read_timeout_ms"`
	WriteTimeout int64 `toml:"write_timeout_ms"`
	NetLimit *NetLimitConfig `toml:"net_limit"`
	TLS *TLSConfig `toml:"tls"`
	Auth *ServerAuthConfig `toml:"auth"`
}
// TCPSourceOptions runs a raw TCP listener accepting pushed entries.
type TCPSourceOptions struct {
	Host string `toml:"host"`
	Port int64 `toml:"port"`
	BufferSize int64 `toml:"buffer_size"`
	ReadTimeout int64 `toml:"read_timeout_ms"`
	KeepAlive bool `toml:"keep_alive"`
	KeepAlivePeriod int64 `toml:"keep_alive_period_ms"`
	NetLimit *NetLimitConfig `toml:"net_limit"`
	Auth *ServerAuthConfig `toml:"auth"`
}
// --- Sink Options ---
// SinkConfig selects one output type; exactly one options pointer is
// expected to be populated, matching Type.
type SinkConfig struct {
	Type string `toml:"type"`
	// Polymorphic - only one populated based on type
	Console *ConsoleSinkOptions `toml:"console,omitempty"`
	File *FileSinkOptions `toml:"file,omitempty"`
	HTTP *HTTPSinkOptions `toml:"http,omitempty"`
	TCP *TCPSinkOptions `toml:"tcp,omitempty"`
	HTTPClient *HTTPClientSinkOptions `toml:"http_client,omitempty"`
	TCPClient *TCPClientSinkOptions `toml:"tcp_client,omitempty"`
}
// ConsoleSinkOptions writes entries to the process's own stdout/stderr.
type ConsoleSinkOptions struct {
	Target string `toml:"target"` // "stdout", "stderr", "split"
	Colorize bool `toml:"colorize"`
	BufferSize int64 `toml:"buffer_size"`
}
// FileSinkOptions writes entries to rotating files with size/age retention.
type FileSinkOptions struct {
	Directory string `toml:"directory"`
	Name string `toml:"name"`
	// Extension string `toml:"extension"`
	MaxSizeMB int64 `toml:"max_size_mb"`
	MaxTotalSizeMB int64 `toml:"max_total_size_mb"`
	MinDiskFreeMB int64 `toml:"min_disk_free_mb"`
	RetentionHours float64 `toml:"retention_hours"`
	BufferSize int64 `toml:"buffer_size"`
	FlushInterval int64 `toml:"flush_interval_ms"`
}
// HTTPSinkOptions runs an HTTP server that streams entries to subscribers.
type HTTPSinkOptions struct {
	Host string `toml:"host"`
	Port int64 `toml:"port"`
	StreamPath string `toml:"stream_path"`
	StatusPath string `toml:"status_path"`
	BufferSize int64 `toml:"buffer_size"`
	WriteTimeout int64 `toml:"write_timeout_ms"`
	Heartbeat *HeartbeatConfig `toml:"heartbeat"`
	NetLimit *NetLimitConfig `toml:"net_limit"`
	TLS *TLSConfig `toml:"tls"`
	Auth *ServerAuthConfig `toml:"auth"`
}
// TCPSinkOptions runs a TCP server that streams entries to subscribers.
type TCPSinkOptions struct {
	Host string `toml:"host"`
	Port int64 `toml:"port"`
	BufferSize int64 `toml:"buffer_size"`
	WriteTimeout int64 `toml:"write_timeout_ms"`
	KeepAlive bool `toml:"keep_alive"`
	KeepAlivePeriod int64 `toml:"keep_alive_period_ms"`
	Heartbeat *HeartbeatConfig `toml:"heartbeat"`
	NetLimit *NetLimitConfig `toml:"net_limit"`
	Auth *ServerAuthConfig `toml:"auth"`
}
// HTTPClientSinkOptions pushes batched entries to a remote HTTP endpoint
// with retry/backoff.
type HTTPClientSinkOptions struct {
	URL string `toml:"url"`
	Headers map[string]string `toml:"headers"`
	BufferSize int64 `toml:"buffer_size"`
	BatchSize int64 `toml:"batch_size"`
	BatchDelayMS int64 `toml:"batch_delay_ms"`
	Timeout int64 `toml:"timeout_seconds"`
	MaxRetries int64 `toml:"max_retries"`
	RetryDelayMS int64 `toml:"retry_delay_ms"`
	RetryBackoff float64 `toml:"retry_backoff"`
	InsecureSkipVerify bool `toml:"insecure_skip_verify"`
	TLS *TLSConfig `toml:"tls"`
	Auth *ClientAuthConfig `toml:"auth"`
}
// TCPClientSinkOptions maintains an outbound TCP connection to a remote
// collector, reconnecting with exponential backoff.
type TCPClientSinkOptions struct {
	Host string `toml:"host"`
	Port int64 `toml:"port"`
	BufferSize int64 `toml:"buffer_size"`
	DialTimeout int64 `toml:"dial_timeout_seconds"`
	WriteTimeout int64 `toml:"write_timeout_seconds"`
	ReadTimeout int64 `toml:"read_timeout_seconds"`
	KeepAlive int64 `toml:"keep_alive_seconds"`
	ReconnectDelayMS int64 `toml:"reconnect_delay_ms"`
	MaxReconnectDelayMS int64 `toml:"max_reconnect_delay_ms"`
	ReconnectBackoff float64 `toml:"reconnect_backoff"`
	Auth *ClientAuthConfig `toml:"auth"`
}
// --- Rate Limit Options ---
// Defines the action to take when a rate limit is exceeded.
type RateLimitPolicy int
const (
// PolicyPass allows all logs through, effectively disabling the limiter.
PolicyPass RateLimitPolicy = iota
// PolicyDrop drops logs that exceed the rate limit.
PolicyDrop
)
// Defines the configuration for pipeline-level rate limiting.
type RateLimitConfig struct {
// Rate is the number of log entries allowed per second. Default: 0 (disabled).
Rate float64 `toml:"rate"`
// Burst is the maximum number of log entries that can be sent in a short burst. Defaults to the Rate.
Burst float64 `toml:"burst"`
// Policy defines the action to take when the limit is exceeded. "pass" or "drop".
Policy string `toml:"policy"`
// MaxEntrySizeBytes is the maximum allowed size for a single log entry. 0 = no limit.
MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"`
}
// --- Filter Options ---
// FilterType represents the filter type: whether matching entries are kept
// or dropped.
type FilterType string
const (
	FilterTypeInclude FilterType = "include" // Whitelist - only matching logs pass
	FilterTypeExclude FilterType = "exclude" // Blacklist - matching logs are dropped
)
// FilterLogic represents how multiple patterns are combined.
type FilterLogic string
const (
	FilterLogicOr FilterLogic = "or" // Match any pattern
	FilterLogicAnd FilterLogic = "and" // Match all patterns
)
// FilterConfig represents one filter stage of a pipeline: its type, the
// logic used to combine patterns, and the regular-expression patterns
// themselves (Go regexp syntax; validated at load time).
type FilterConfig struct {
	Type FilterType `toml:"type"`
	Logic FilterLogic `toml:"logic"`
	Patterns []string `toml:"patterns"`
}
// --- Formatter Options ---
// FormatConfig selects the output formatter for a pipeline. At most one of
// the per-type option structs is populated, matching Type.
type FormatConfig struct {
	// Format configuration - polymorphic like sources/sinks
	Type string `toml:"type"` // "json", "text", "raw"
	// Only one will be populated based on format type
	JSONFormatOptions *JSONFormatterOptions `toml:"json_format,omitempty"`
	TextFormatOptions *TextFormatterOptions `toml:"text_format,omitempty"`
	RawFormatOptions *RawFormatterOptions `toml:"raw_format,omitempty"`
}
// JSONFormatterOptions controls JSON output: pretty-printing and the JSON
// field names emitted for each log entry component.
type JSONFormatterOptions struct {
	Pretty bool `toml:"pretty"`
	TimestampField string `toml:"timestamp_field"`
	LevelField string `toml:"level_field"`
	MessageField string `toml:"message_field"`
	SourceField string `toml:"source_field"`
}
// TextFormatterOptions controls templated plain-text output.
type TextFormatterOptions struct {
	Template string `toml:"template"`
	TimestampFormat string `toml:"timestamp_format"`
}
// RawFormatterOptions controls pass-through output of entries.
type RawFormatterOptions struct {
	AddNewLine bool `toml:"add_new_line"`
}
// --- Server-side Auth (for sources) ---
// BasicAuthConfig lists the accepted HTTP Basic credentials and the realm
// reported to clients.
type BasicAuthConfig struct {
	Users []BasicAuthUser `toml:"users"`
	Realm string `toml:"realm"`
}
// BasicAuthUser is one username with its Argon2 password hash.
type BasicAuthUser struct {
	Username string `toml:"username"`
	PasswordHash string `toml:"password_hash"` // Argon2
}
// ScramAuthConfig lists users for SCRAM challenge-response authentication.
type ScramAuthConfig struct {
	Users []ScramUser `toml:"users"`
}
// ScramUser holds the server-side SCRAM verifier material for one user.
// The key and salt fields are base64-encoded; the Argon* fields are the
// Argon2 KDF parameters used when the keys were derived.
type ScramUser struct {
	Username string `toml:"username"`
	StoredKey string `toml:"stored_key"` // base64
	ServerKey string `toml:"server_key"` // base64
	Salt string `toml:"salt"` // base64
	ArgonTime uint32 `toml:"argon_time"`
	ArgonMemory uint32 `toml:"argon_memory"`
	ArgonThreads uint8 `toml:"argon_threads"`
}
// TokenAuthConfig lists the bearer tokens a source accepts.
type TokenAuthConfig struct {
	Tokens []string `toml:"tokens"`
}
// ServerAuthConfig selects the authentication scheme for sources that accept
// network connections. Only the struct matching Type should be set.
type ServerAuthConfig struct {
	Type string `toml:"type"` // "none", "basic", "token", "scram"
	Basic *BasicAuthConfig `toml:"basic,omitempty"`
	Token *TokenAuthConfig `toml:"token,omitempty"`
	Scram *ScramAuthConfig `toml:"scram,omitempty"`
}

View File

@ -1,65 +0,0 @@
// FILE: logwisp/src/internal/config/filter.go
package config
import (
"fmt"
"regexp"
)
// FilterType represents the filter type: whether matching entries are kept
// or dropped.
type FilterType string
const (
	FilterTypeInclude FilterType = "include" // Whitelist - only matching logs pass
	FilterTypeExclude FilterType = "exclude" // Blacklist - matching logs are dropped
)
// FilterLogic represents how multiple patterns are combined.
type FilterLogic string
const (
	FilterLogicOr FilterLogic = "or" // Match any pattern
	FilterLogicAnd FilterLogic = "and" // Match all patterns
)
// FilterConfig represents one filter stage of a pipeline: its type, the
// logic used to combine patterns, and the regular-expression patterns
// themselves (Go regexp syntax; validated by validateFilter).
type FilterConfig struct {
	Type FilterType `toml:"type"`
	Logic FilterLogic `toml:"logic"`
	Patterns []string `toml:"patterns"`
}
// validateFilter checks one filter entry of a pipeline: its type, its
// pattern-combining logic, and that every pattern compiles as a Go regular
// expression. An empty type or logic is accepted (defaults apply elsewhere),
// and an empty pattern list is a valid filter that passes everything.
func validateFilter(pipelineName string, filterIndex int, cfg *FilterConfig) error {
	// Type must be one of the two known kinds, or empty.
	if cfg.Type != FilterTypeInclude && cfg.Type != FilterTypeExclude && cfg.Type != "" {
		return fmt.Errorf("pipeline '%s' filter[%d]: invalid type '%s' (must be 'include' or 'exclude')",
			pipelineName, filterIndex, cfg.Type)
	}
	// Logic must be "or", "and", or empty.
	if cfg.Logic != FilterLogicOr && cfg.Logic != FilterLogicAnd && cfg.Logic != "" {
		return fmt.Errorf("pipeline '%s' filter[%d]: invalid logic '%s' (must be 'or' or 'and')",
			pipelineName, filterIndex, cfg.Logic)
	}
	// Every pattern must be a valid regex; a zero-length list is fine.
	for i, pattern := range cfg.Patterns {
		if _, err := regexp.Compile(pattern); err != nil {
			return fmt.Errorf("pipeline '%s' filter[%d] pattern[%d] '%s': invalid regex: %w",
				pipelineName, filterIndex, i, pattern, err)
		}
	}
	return nil
}

View File

@ -1,58 +0,0 @@
// FILE: logwisp/src/internal/config/ratelimit.go
package config
import (
"fmt"
"strings"
)
// RateLimitPolicy defines the action to take when a rate limit is exceeded.
type RateLimitPolicy int
const (
	// PolicyPass allows all logs through, effectively disabling the limiter.
	PolicyPass RateLimitPolicy = iota
	// PolicyDrop drops logs that exceed the rate limit.
	PolicyDrop
)
// RateLimitConfig defines the configuration for pipeline-level rate limiting.
type RateLimitConfig struct {
	// Rate is the number of log entries allowed per second. Default: 0 (disabled).
	Rate float64 `toml:"rate"`
	// Burst is the maximum number of log entries that can be sent in a short burst. Defaults to the Rate.
	Burst float64 `toml:"burst"`
	// Policy defines the action to take when the limit is exceeded. "pass" or "drop".
	Policy string `toml:"policy"`
	// MaxEntrySizeBytes is the maximum allowed size for a single log entry. 0 = no limit.
	MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"`
}
// validateRateLimit checks a pipeline's optional rate-limit section.
// A nil config means rate limiting is disabled and is always valid.
// Numeric fields must be non-negative; the policy name is matched
// case-insensitively and may be empty (default applies elsewhere).
func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error {
	if cfg == nil {
		return nil
	}
	switch {
	case cfg.Rate < 0:
		return fmt.Errorf("pipeline '%s': rate limit rate cannot be negative", pipelineName)
	case cfg.Burst < 0:
		return fmt.Errorf("pipeline '%s': rate limit burst cannot be negative", pipelineName)
	case cfg.MaxEntrySizeBytes < 0:
		return fmt.Errorf("pipeline '%s': max entry size bytes cannot be negative", pipelineName)
	}
	// Policy comparison is case-insensitive; empty selects the default.
	policy := strings.ToLower(cfg.Policy)
	if policy != "" && policy != "pass" && policy != "drop" {
		return fmt.Errorf("pipeline '%s': invalid rate limit policy '%s' (must be 'pass' or 'drop')",
			pipelineName, cfg.Policy)
	}
	return nil
}

View File

@ -11,6 +11,8 @@ import (
lconfig "github.com/lixenwraith/config" lconfig "github.com/lixenwraith/config"
) )
var configManager *lconfig.Config
func defaults() *Config { func defaults() *Config {
return &Config{ return &Config{
// Top-level flag defaults // Top-level flag defaults
@ -21,41 +23,46 @@ func defaults() *Config {
// Runtime behavior defaults // Runtime behavior defaults
DisableStatusReporter: false, DisableStatusReporter: false,
ConfigAutoReload: false, ConfigAutoReload: false,
ConfigSaveOnExit: false,
// Child process indicator // Child process indicator
BackgroundDaemon: false, BackgroundDaemon: false,
// Existing defaults // Existing defaults
Logging: DefaultLogConfig(), Logging: &LogConfig{
Output: "stdout",
Level: "info",
File: &LogFileConfig{
Directory: "./log",
Name: "logwisp",
MaxSizeMB: 100,
MaxTotalSizeMB: 1000,
RetentionHours: 168, // 7 days
},
Console: &LogConsoleConfig{
Target: "stdout",
Format: "txt",
},
},
Pipelines: []PipelineConfig{ Pipelines: []PipelineConfig{
{ {
Name: "default", Name: "default",
Sources: []SourceConfig{ Sources: []SourceConfig{
{ {
Type: "directory", Type: "directory",
Options: map[string]any{ Directory: &DirectorySourceOptions{
"path": "./", Path: "./",
"pattern": "*.log", Pattern: "*.log",
"check_interval_ms": int64(100), CheckIntervalMS: int64(100),
}, },
}, },
}, },
Sinks: []SinkConfig{ Sinks: []SinkConfig{
{ {
Type: "http", Type: "console",
Options: map[string]any{ Console: &ConsoleSinkOptions{
"port": int64(8080), Target: "stdout",
"buffer_size": int64(1000), Colorize: false,
"stream_path": "/stream", BufferSize: 100,
"status_path": "/status",
"heartbeat": map[string]any{
"enabled": true,
"interval_seconds": int64(30),
"include_timestamp": true,
"include_stats": false,
"format": "comment",
},
}, },
}, },
}, },
@ -68,18 +75,30 @@ func defaults() *Config {
func Load(args []string) (*Config, error) { func Load(args []string) (*Config, error) {
configPath, isExplicit := resolveConfigPath(args) configPath, isExplicit := resolveConfigPath(args)
// Build configuration with all sources // Build configuration with all sources
// Create target config instance that will be populated
finalConfig := &Config{}
// The builder now handles loading, populating the target struct, and validation
cfg, err := lconfig.NewBuilder(). cfg, err := lconfig.NewBuilder().
WithDefaults(defaults()). WithTarget(finalConfig). // Typed target struct
WithEnvPrefix("LOGWISP_"). WithDefaults(defaults()). // Default values
WithEnvTransform(customEnvTransform).
WithArgs(args).
WithFile(configPath).
WithSources( WithSources(
lconfig.SourceCLI, lconfig.SourceCLI,
lconfig.SourceEnv, lconfig.SourceEnv,
lconfig.SourceFile, lconfig.SourceFile,
lconfig.SourceDefault, lconfig.SourceDefault,
). ).
WithEnvTransform(customEnvTransform). // Convert '.' to '_' in env separation
WithEnvPrefix("LOGWISP_"). // Environment variable prefix
WithArgs(args). // Command-line arguments
WithFile(configPath). // TOML config file
WithFileFormat("toml"). // Explicit format
WithTypedValidator(validateConfig). // Centralized validation
WithSecurityOptions(lconfig.SecurityOptions{
PreventPathTraversal: true,
MaxFileSize: 10 * 1024 * 1024, // 10MB max config
}).
Build() Build()
if err != nil { if err != nil {
@ -88,42 +107,28 @@ func Load(args []string) (*Config, error) {
if isExplicit { if isExplicit {
return nil, fmt.Errorf("config file not found: %s", configPath) return nil, fmt.Errorf("config file not found: %s", configPath)
} }
// If the default config file is not found, it's not an error // If the default config file is not found, it's not an error, default/cli/env will be used
} else { } else {
return nil, fmt.Errorf("failed to load config: %w", err) return nil, fmt.Errorf("failed to load or validate config: %w", err)
} }
} }
// Scan into final config struct - using new interface // Store the config file path for hot reload
finalConfig := &Config{} finalConfig.ConfigFile = configPath
if err := cfg.Scan(finalConfig); err != nil {
return nil, fmt.Errorf("failed to scan config: %w", err) // Store the manager for hot reload
if cfg != nil {
configManager = cfg
} }
// Set config file path if it exists return finalConfig, nil
if _, err := os.Stat(configPath); err == nil {
finalConfig.ConfigFile = configPath
}
// Ensure critical fields are not nil
if finalConfig.Logging == nil {
finalConfig.Logging = DefaultLogConfig()
}
// Apply console target overrides if needed
if err := applyConsoleTargetOverrides(finalConfig); err != nil {
return nil, fmt.Errorf("failed to apply console target overrides: %w", err)
}
// Validate configuration
return finalConfig, finalConfig.validate()
} }
// Returns the configuration file path // Returns the configuration file path
func resolveConfigPath(args []string) (path string, isExplicit bool) { func resolveConfigPath(args []string) (path string, isExplicit bool) {
// 1. Check for --config flag in command-line arguments (highest precedence) // 1. Check for --config flag in command-line arguments (highest precedence)
for i, arg := range args { for i, arg := range args {
if (arg == "--config" || arg == "-c") && i+1 < len(args) { if arg == "-c" {
return args[i+1], true return args[i+1], true
} }
if strings.HasPrefix(arg, "--config=") { if strings.HasPrefix(arg, "--config=") {
@ -161,37 +166,3 @@ func customEnvTransform(path string) string {
// env = "LOGWISP_" + env // already added by WithEnvPrefix // env = "LOGWISP_" + env // already added by WithEnvPrefix
return env return env
} }
// Centralizes console target configuration
func applyConsoleTargetOverrides(cfg *Config) error {
// Check environment variable for console target override
consoleTarget := os.Getenv("LOGWISP_CONSOLE_TARGET")
if consoleTarget == "" {
return nil
}
// Validate console target value
validTargets := map[string]bool{
"stdout": true,
"stderr": true,
"split": true,
}
if !validTargets[consoleTarget] {
return fmt.Errorf("invalid LOGWISP_CONSOLE_TARGET value: %s", consoleTarget)
}
// Apply to console sinks
for i, pipeline := range cfg.Pipelines {
for j, sink := range pipeline.Sinks {
if sink.Type == "console" {
if sink.Options == nil {
cfg.Pipelines[i].Sinks[j].Options = make(map[string]any)
}
// Set target for split mode handling
cfg.Pipelines[i].Sinks[j].Options["target"] = consoleTarget
}
}
}
return nil
}

View File

@ -1,99 +0,0 @@
// FILE: logwisp/src/internal/config/logging.go
package config
import "fmt"
// LogConfig represents the internal logging configuration for LogWisp itself
// (distinct from the log pipelines it processes).
type LogConfig struct {
	// Output mode: "file", "stdout", "stderr", "split", "all", "none"
	Output string `toml:"output"`
	// Log level: "debug", "info", "warn", "error"
	Level string `toml:"level"`
	// File output settings (when Output includes "file" or "all")
	File *LogFileConfig `toml:"file"`
	// Console output settings
	Console *LogConsoleConfig `toml:"console"`
}
// LogFileConfig holds rotation and retention settings for file output.
type LogFileConfig struct {
	// Directory for log files
	Directory string `toml:"directory"`
	// Base name for log files
	Name string `toml:"name"`
	// Maximum size per log file in MB
	MaxSizeMB int64 `toml:"max_size_mb"`
	// Maximum total size of all logs in MB
	MaxTotalSizeMB int64 `toml:"max_total_size_mb"`
	// Log retention in hours (0 = disabled)
	RetentionHours float64 `toml:"retention_hours"`
}
// LogConsoleConfig holds settings for console (stdout/stderr) output.
type LogConsoleConfig struct {
	// Target for console output: "stdout", "stderr", "split"
	// "split": info/debug to stdout, warn/error to stderr
	Target string `toml:"target"`
	// Format: "txt" or "json"
	Format string `toml:"format"`
}
// DefaultLogConfig returns sensible logging defaults: text output to stdout
// at info level, with file-rotation settings prepared in case file output is
// later enabled.
func DefaultLogConfig() *LogConfig {
	fileDefaults := &LogFileConfig{
		Directory:      "./log",
		Name:           "logwisp",
		MaxSizeMB:      100,
		MaxTotalSizeMB: 1000,
		RetentionHours: 168, // 7 days
	}
	consoleDefaults := &LogConsoleConfig{
		Target: "stdout",
		Format: "txt",
	}
	return &LogConfig{
		Output:  "stdout",
		Level:   "info",
		File:    fileDefaults,
		Console: consoleDefaults,
	}
}
// validateLogConfig verifies LogWisp's own logging configuration: the output
// mode, the log level, and — when a console section is present — the console
// target and format. An empty console format is accepted; an empty target is
// not.
func validateLogConfig(cfg *LogConfig) error {
	switch cfg.Output {
	case "file", "stdout", "stderr", "split", "all", "none":
		// valid output mode
	default:
		return fmt.Errorf("invalid log output mode: %s", cfg.Output)
	}
	switch cfg.Level {
	case "debug", "info", "warn", "error":
		// valid level
	default:
		return fmt.Errorf("invalid log level: %s", cfg.Level)
	}
	if cfg.Console == nil {
		return nil
	}
	switch cfg.Console.Target {
	case "stdout", "stderr", "split":
		// valid target
	default:
		return fmt.Errorf("invalid console target: %s", cfg.Console.Target)
	}
	switch cfg.Console.Format {
	case "txt", "json", "":
		// valid format (empty selects the default)
	default:
		return fmt.Errorf("invalid console format: %s", cfg.Console.Format)
	}
	return nil
}

View File

@ -1,416 +0,0 @@
// FILE: logwisp/src/internal/config/pipeline.go
package config
import (
"fmt"
"net"
"net/url"
"path/filepath"
"strings"
)
// PipelineConfig represents a data processing pipeline: one or more sources
// flowing through optional rate limiting, filters, and formatting into one
// or more sinks.
type PipelineConfig struct {
	// Pipeline identifier (used in logs and metrics)
	Name string `toml:"name"`
	// Data sources for this pipeline
	Sources []SourceConfig `toml:"sources"`
	// Rate limiting
	RateLimit *RateLimitConfig `toml:"rate_limit"`
	// Filter configuration
	Filters []FilterConfig `toml:"filters"`
	// Log formatting configuration
	Format string `toml:"format"`
	FormatOptions map[string]any `toml:"format_options"`
	// Output sinks for this pipeline
	Sinks []SinkConfig `toml:"sinks"`
	// Authentication/Authorization (applies to network sinks)
	Auth *AuthConfig `toml:"auth"`
}
// SourceConfig represents an input data source. Options is a free-form map
// whose keys depend on Type; it is checked by validateSource.
type SourceConfig struct {
	// Source type
	Type string `toml:"type"`
	// Type-specific configuration options
	Options map[string]any `toml:"options"`
}
// SinkConfig represents an output destination. Options is a free-form map
// whose keys depend on Type; it is checked by validateSink.
type SinkConfig struct {
	// Sink type
	Type string `toml:"type"`
	// Type-specific configuration options
	Options map[string]any `toml:"options"`
}
// validateSource checks one source entry of a pipeline against its declared
// type ("directory", "stdin", "http", "tcp"). Options is an untyped map, so
// each option is probed with a type assertion; options that are absent or of
// an unexpected dynamic type are generally skipped rather than rejected
// (check_interval_ms being the exception — a present-but-wrong type errors).
func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) error {
	if cfg.Type == "" {
		return fmt.Errorf("pipeline '%s' source[%d]: missing type", pipelineName, sourceIndex)
	}
	switch cfg.Type {
	case "directory":
		// Validate path
		path, ok := cfg.Options["path"].(string)
		if !ok || path == "" {
			return fmt.Errorf("pipeline '%s' source[%d]: directory source requires 'path' option",
				pipelineName, sourceIndex)
		}
		// Check for directory traversal
		if strings.Contains(path, "..") {
			return fmt.Errorf("pipeline '%s' source[%d]: path contains directory traversal",
				pipelineName, sourceIndex)
		}
		// Validate pattern
		if pattern, ok := cfg.Options["pattern"].(string); ok && pattern != "" {
			// Try to compile as glob pattern (will be converted to regex internally)
			if strings.Count(pattern, "*") == 0 && strings.Count(pattern, "?") == 0 {
				// If no wildcards, ensure it's a valid filename
				if filepath.Base(pattern) != pattern {
					return fmt.Errorf("pipeline '%s' source[%d]: pattern contains path separators",
						pipelineName, sourceIndex)
				}
			}
		}
		// Validate check interval; TOML integers decode as int64
		if interval, ok := cfg.Options["check_interval_ms"]; ok {
			if intVal, ok := interval.(int64); ok {
				if intVal < 10 {
					return fmt.Errorf("pipeline '%s' source[%d]: check interval too small: %d ms (min: 10ms)",
						pipelineName, sourceIndex, intVal)
				}
			} else {
				return fmt.Errorf("pipeline '%s' source[%d]: invalid check_interval_ms type",
					pipelineName, sourceIndex)
			}
		}
	case "stdin":
		// Validate buffer size
		if bufSize, ok := cfg.Options["buffer_size"].(int64); ok {
			if bufSize < 1 {
				return fmt.Errorf("pipeline '%s' source[%d]: stdin buffer_size must be positive: %d",
					pipelineName, sourceIndex, bufSize)
			}
		}
	case "http":
		// Validate host; empty host is allowed (bind-all), otherwise must be a literal IP
		if host, ok := cfg.Options["host"].(string); ok && host != "" {
			if net.ParseIP(host) == nil {
				return fmt.Errorf("pipeline '%s' source[%d]: invalid IP address: %s",
					pipelineName, sourceIndex, host)
			}
		}
		// Validate port (required for network sources)
		port, ok := cfg.Options["port"].(int64)
		if !ok || port < 1 || port > 65535 {
			return fmt.Errorf("pipeline '%s' source[%d]: invalid or missing HTTP port",
				pipelineName, sourceIndex)
		}
		// Validate path
		if path, ok := cfg.Options["ingest_path"].(string); ok {
			if !strings.HasPrefix(path, "/") {
				return fmt.Errorf("pipeline '%s' source[%d]: ingest path must start with /: %s",
					pipelineName, sourceIndex, path)
			}
		}
		// Validate net_limit
		if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
			if err := validateNetLimitOptions("HTTP source", pipelineName, sourceIndex, nl); err != nil {
				return err
			}
		}
		// Validate TLS
		if tls, ok := cfg.Options["tls"].(map[string]any); ok {
			if err := validateTLSOptions("HTTP source", pipelineName, sourceIndex, tls); err != nil {
				return err
			}
		}
	case "tcp":
		// Validate host (same rules as HTTP)
		if host, ok := cfg.Options["host"].(string); ok && host != "" {
			if net.ParseIP(host) == nil {
				return fmt.Errorf("pipeline '%s' source[%d]: invalid IP address: %s",
					pipelineName, sourceIndex, host)
			}
		}
		// Validate port
		port, ok := cfg.Options["port"].(int64)
		if !ok || port < 1 || port > 65535 {
			return fmt.Errorf("pipeline '%s' source[%d]: invalid or missing TCP port",
				pipelineName, sourceIndex)
		}
		// Validate net_limit
		if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
			if err := validateNetLimitOptions("TCP source", pipelineName, sourceIndex, nl); err != nil {
				return err
			}
		}
	default:
		return fmt.Errorf("pipeline '%s' source[%d]: unknown source type '%s'",
			pipelineName, sourceIndex, cfg.Type)
	}
	return nil
}
// validateSink checks one sink entry of a pipeline against its declared type
// ("http", "tcp", "http_client", "tcp_client", "file", "console").
// allPorts maps listening port -> a human-readable owner label and is shared
// across all sinks of the config; server sinks register their port here, so
// a duplicate port anywhere in the config is reported as a conflict. Note
// this mutates allPorts even if a later check on the same sink fails.
func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts map[int64]string) error {
	if cfg.Type == "" {
		return fmt.Errorf("pipeline '%s' sink[%d]: missing type", pipelineName, sinkIndex)
	}
	switch cfg.Type {
	case "http":
		// Extract and validate HTTP configuration
		port, ok := cfg.Options["port"].(int64)
		if !ok || port < 1 || port > 65535 {
			return fmt.Errorf("pipeline '%s' sink[%d]: invalid or missing HTTP port",
				pipelineName, sinkIndex)
		}
		// Validate host; empty means bind-all, otherwise must be a literal IP
		if host, ok := cfg.Options["host"].(string); ok && host != "" {
			if net.ParseIP(host) == nil {
				return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
					pipelineName, sinkIndex, host)
			}
		}
		// Check port conflicts
		if existing, exists := allPorts[port]; exists {
			return fmt.Errorf("pipeline '%s' sink[%d]: HTTP port %d already used by %s",
				pipelineName, sinkIndex, port, existing)
		}
		allPorts[port] = fmt.Sprintf("%s-http[%d]", pipelineName, sinkIndex)
		// Validate buffer size
		if bufSize, ok := cfg.Options["buffer_size"].(int64); ok {
			if bufSize < 1 {
				return fmt.Errorf("pipeline '%s' sink[%d]: HTTP buffer size must be positive: %d",
					pipelineName, sinkIndex, bufSize)
			}
		}
		// Validate paths (must be absolute URL paths)
		if streamPath, ok := cfg.Options["stream_path"].(string); ok {
			if !strings.HasPrefix(streamPath, "/") {
				return fmt.Errorf("pipeline '%s' sink[%d]: stream path must start with /: %s",
					pipelineName, sinkIndex, streamPath)
			}
		}
		if statusPath, ok := cfg.Options["status_path"].(string); ok {
			if !strings.HasPrefix(statusPath, "/") {
				return fmt.Errorf("pipeline '%s' sink[%d]: status path must start with /: %s",
					pipelineName, sinkIndex, statusPath)
			}
		}
		// Validate heartbeat
		if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
			if err := validateHeartbeatOptions("HTTP", pipelineName, sinkIndex, hb); err != nil {
				return err
			}
		}
		// Validate TLS if present
		if tls, ok := cfg.Options["tls"].(map[string]any); ok {
			if err := validateTLSOptions("HTTP", pipelineName, sinkIndex, tls); err != nil {
				return err
			}
		}
		// Validate net limit
		if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
			if err := validateNetLimitOptions("HTTP", pipelineName, sinkIndex, nl); err != nil {
				return err
			}
		}
	case "tcp":
		// Extract and validate TCP configuration
		port, ok := cfg.Options["port"].(int64)
		if !ok || port < 1 || port > 65535 {
			return fmt.Errorf("pipeline '%s' sink[%d]: invalid or missing TCP port",
				pipelineName, sinkIndex)
		}
		// Validate host
		if host, ok := cfg.Options["host"].(string); ok && host != "" {
			if net.ParseIP(host) == nil {
				return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
					pipelineName, sinkIndex, host)
			}
		}
		// Check port conflicts
		if existing, exists := allPorts[port]; exists {
			return fmt.Errorf("pipeline '%s' sink[%d]: TCP port %d already used by %s",
				pipelineName, sinkIndex, port, existing)
		}
		allPorts[port] = fmt.Sprintf("%s-tcp[%d]", pipelineName, sinkIndex)
		// Validate buffer size
		if bufSize, ok := cfg.Options["buffer_size"].(int64); ok {
			if bufSize < 1 {
				return fmt.Errorf("pipeline '%s' sink[%d]: TCP buffer size must be positive: %d",
					pipelineName, sinkIndex, bufSize)
			}
		}
		// Validate heartbeat
		if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
			if err := validateHeartbeatOptions("TCP", pipelineName, sinkIndex, hb); err != nil {
				return err
			}
		}
		// Validate net limit
		if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
			if err := validateNetLimitOptions("TCP", pipelineName, sinkIndex, nl); err != nil {
				return err
			}
		}
	case "http_client":
		// Validate URL (required)
		urlStr, ok := cfg.Options["url"].(string)
		if !ok || urlStr == "" {
			return fmt.Errorf("pipeline '%s' sink[%d]: http_client sink requires 'url' option",
				pipelineName, sinkIndex)
		}
		// Validate URL format
		parsedURL, err := url.Parse(urlStr)
		if err != nil {
			return fmt.Errorf("pipeline '%s' sink[%d]: invalid URL: %w",
				pipelineName, sinkIndex, err)
		}
		if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
			return fmt.Errorf("pipeline '%s' sink[%d]: URL must use http or https scheme",
				pipelineName, sinkIndex)
		}
		// Validate batch size
		if batchSize, ok := cfg.Options["batch_size"].(int64); ok {
			if batchSize < 1 {
				return fmt.Errorf("pipeline '%s' sink[%d]: batch_size must be positive: %d",
					pipelineName, sinkIndex, batchSize)
			}
		}
		// Validate timeout
		if timeout, ok := cfg.Options["timeout_seconds"].(int64); ok {
			if timeout < 1 {
				return fmt.Errorf("pipeline '%s' sink[%d]: timeout_seconds must be positive: %d",
					pipelineName, sinkIndex, timeout)
			}
		}
	case "tcp_client":
		// Added validation for TCP client sink
		// Validate address (required, "host:port" form)
		address, ok := cfg.Options["address"].(string)
		if !ok || address == "" {
			return fmt.Errorf("pipeline '%s' sink[%d]: tcp_client sink requires 'address' option",
				pipelineName, sinkIndex)
		}
		// Validate address format
		_, _, err := net.SplitHostPort(address)
		if err != nil {
			return fmt.Errorf("pipeline '%s' sink[%d]: invalid address format (expected host:port): %w",
				pipelineName, sinkIndex, err)
		}
		// Validate timeouts
		if dialTimeout, ok := cfg.Options["dial_timeout_seconds"].(int64); ok {
			if dialTimeout < 1 {
				return fmt.Errorf("pipeline '%s' sink[%d]: dial_timeout_seconds must be positive: %d",
					pipelineName, sinkIndex, dialTimeout)
			}
		}
		if writeTimeout, ok := cfg.Options["write_timeout_seconds"].(int64); ok {
			if writeTimeout < 1 {
				return fmt.Errorf("pipeline '%s' sink[%d]: write_timeout_seconds must be positive: %d",
					pipelineName, sinkIndex, writeTimeout)
			}
		}
	case "file":
		// Validate directory (required)
		directory, ok := cfg.Options["directory"].(string)
		if !ok || directory == "" {
			return fmt.Errorf("pipeline '%s' sink[%d]: file sink requires 'directory' option",
				pipelineName, sinkIndex)
		}
		// Validate filename (required)
		name, ok := cfg.Options["name"].(string)
		if !ok || name == "" {
			return fmt.Errorf("pipeline '%s' sink[%d]: file sink requires 'name' option",
				pipelineName, sinkIndex)
		}
		// Validate size options
		if maxSize, ok := cfg.Options["max_size_mb"].(int64); ok {
			if maxSize < 1 {
				return fmt.Errorf("pipeline '%s' sink[%d]: max_size_mb must be positive: %d",
					pipelineName, sinkIndex, maxSize)
			}
		}
		if maxTotalSize, ok := cfg.Options["max_total_size_mb"].(int64); ok {
			if maxTotalSize < 0 {
				return fmt.Errorf("pipeline '%s' sink[%d]: max_total_size_mb cannot be negative: %d",
					pipelineName, sinkIndex, maxTotalSize)
			}
		}
		if minDiskFree, ok := cfg.Options["min_disk_free_mb"].(int64); ok {
			if minDiskFree < 0 {
				return fmt.Errorf("pipeline '%s' sink[%d]: min_disk_free_mb cannot be negative: %d",
					pipelineName, sinkIndex, minDiskFree)
			}
		}
		// Validate retention period
		if retention, ok := cfg.Options["retention_hours"].(float64); ok {
			if retention < 0 {
				return fmt.Errorf("pipeline '%s' sink[%d]: retention_hours cannot be negative: %f",
					pipelineName, sinkIndex, retention)
			}
		}
	case "console":
		// No specific validation needed for console sinks
	default:
		return fmt.Errorf("pipeline '%s' sink[%d]: unknown sink type '%s'",
			pipelineName, sinkIndex, cfg.Type)
	}
	return nil
}

View File

@ -1,33 +0,0 @@
// FILE: logwisp/src/internal/config/saver.go
package config
import (
"fmt"
lconfig "github.com/lixenwraith/config"
)
// SaveToFile saves the configuration to the specified file path (TOML).
// Returns an error if path is empty, the builder cannot be constructed, or
// the write fails.
func (c *Config) SaveToFile(path string) error {
	if path == "" {
		return fmt.Errorf("cannot save config: path is empty")
	}
	// Create a temporary lconfig instance just for saving
	// This avoids the need to track lconfig throughout the application
	// NOTE(review): WithFile(path) may cause Build to read an existing file
	// at 'path' into the target c before Save — confirm lconfig's builder
	// semantics don't clobber in-memory values here.
	lcfg, err := lconfig.NewBuilder().
		WithFile(path).
		WithTarget(c).
		WithFileFormat("toml").
		Build()
	if err != nil {
		return fmt.Errorf("failed to create config builder: %w", err)
	}
	// Use lconfig's Save method which handles atomic writes
	if err := lcfg.Save(path); err != nil {
		return fmt.Errorf("failed to save config: %w", err)
	}
	return nil
}

View File

@ -1,203 +0,0 @@
// FILE: logwisp/src/internal/config/server.go
package config
import (
"fmt"
"net"
"strings"
)
// TCPConfig holds settings for a TCP streaming server endpoint.
type TCPConfig struct {
	Enabled bool `toml:"enabled"`
	Port int64 `toml:"port"`
	BufferSize int64 `toml:"buffer_size"`
	// Net limiting
	NetLimit *NetLimitConfig `toml:"net_limit"`
	// Heartbeat
	Heartbeat *HeartbeatConfig `toml:"heartbeat"`
}
// HTTPConfig holds settings for an HTTP streaming server endpoint.
type HTTPConfig struct {
	Enabled bool `toml:"enabled"`
	Port int64 `toml:"port"`
	BufferSize int64 `toml:"buffer_size"`
	// Endpoint paths
	StreamPath string `toml:"stream_path"`
	StatusPath string `toml:"status_path"`
	// TLS Configuration
	TLS *TLSConfig `toml:"tls"`
	// Net limiting
	NetLimit *NetLimitConfig `toml:"net_limit"`
	// Heartbeat
	Heartbeat *HeartbeatConfig `toml:"heartbeat"`
}
// HeartbeatConfig controls periodic keep-alive messages sent to connected
// clients. Format is "json" or "comment" (see validateHeartbeatOptions).
type HeartbeatConfig struct {
	Enabled bool `toml:"enabled"`
	IntervalSeconds int64 `toml:"interval_seconds"`
	IncludeTimestamp bool `toml:"include_timestamp"`
	IncludeStats bool `toml:"include_stats"`
	Format string `toml:"format"`
}
// NetLimitConfig configures per-server network limiting: IPv4 access lists,
// a token-bucket request limiter, and connection-count caps.
type NetLimitConfig struct {
	// Enable net limiting
	Enabled bool `toml:"enabled"`
	// IP Access Control Lists
	IPWhitelist []string `toml:"ip_whitelist"`
	IPBlacklist []string `toml:"ip_blacklist"`
	// Requests per second per client
	RequestsPerSecond float64 `toml:"requests_per_second"`
	// Burst size (token bucket)
	BurstSize int64 `toml:"burst_size"`
	// Response when net limited
	ResponseCode int64 `toml:"response_code"` // Default: 429
	ResponseMessage string `toml:"response_message"` // Default: "Net limit exceeded"
	// Connection limits
	MaxConnectionsPerIP int64 `toml:"max_connections_per_ip"`
	MaxConnectionsPerUser int64 `toml:"max_connections_per_user"`
	MaxConnectionsPerToken int64 `toml:"max_connections_per_token"`
	MaxConnectionsTotal int64 `toml:"max_connections_total"`
}
// validateHeartbeatOptions checks a sink's heartbeat option map. Rules only
// apply when the heartbeat is explicitly enabled: a missing, disabled, or
// wrongly-typed "enabled" key makes the whole map valid. An enabled
// heartbeat requires a positive int64 interval_seconds; "format", when
// present as a string, must be "json" or "comment".
func validateHeartbeatOptions(serverType, pipelineName string, sinkIndex int, hb map[string]any) error {
	enabled, ok := hb["enabled"].(bool)
	if !ok || !enabled {
		return nil
	}
	if interval, ok := hb["interval_seconds"].(int64); !ok || interval < 1 {
		return fmt.Errorf("pipeline '%s' sink[%d] %s: heartbeat interval must be positive",
			pipelineName, sinkIndex, serverType)
	}
	format, ok := hb["format"].(string)
	if !ok {
		return nil
	}
	switch format {
	case "json", "comment":
		return nil
	default:
		return fmt.Errorf("pipeline '%s' sink[%d] %s: heartbeat format must be 'json' or 'comment': %s",
			pipelineName, sinkIndex, serverType, format)
	}
}
// validateNetLimitOptions checks a net_limit option map for a source or sink.
// Validation only applies when "enabled" is a true bool; otherwise the map is
// accepted as-is. When enabled: IP list entries must be IPv4 addresses/CIDRs
// (non-string entries are silently skipped), requests_per_second and
// burst_size are mandatory and positive, response_code (if set non-zero)
// must be 4xx/5xx, and per-IP/user/token connection caps may not exceed the
// total cap — that cross-check only runs when all four caps are present and
// positive.
func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, nl map[string]any) error {
	if enabled, ok := nl["enabled"].(bool); !ok || !enabled {
		return nil
	}
	// Validate IP lists if present
	if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
		for i, entry := range ipWhitelist {
			entryStr, ok := entry.(string)
			if !ok {
				continue
			}
			if err := validateIPv4Entry(entryStr); err != nil {
				return fmt.Errorf("pipeline '%s' sink[%d] %s: whitelist[%d] %v",
					pipelineName, sinkIndex, serverType, i, err)
			}
		}
	}
	if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
		for i, entry := range ipBlacklist {
			entryStr, ok := entry.(string)
			if !ok {
				continue
			}
			if err := validateIPv4Entry(entryStr); err != nil {
				return fmt.Errorf("pipeline '%s' sink[%d] %s: blacklist[%d] %v",
					pipelineName, sinkIndex, serverType, i, err)
			}
		}
	}
	// Validate requests per second (required when enabled)
	rps, ok := nl["requests_per_second"].(float64)
	if !ok || rps <= 0 {
		return fmt.Errorf("pipeline '%s' sink[%d] %s: requests_per_second must be positive",
			pipelineName, sinkIndex, serverType)
	}
	// Validate burst size (required when enabled)
	burst, ok := nl["burst_size"].(int64)
	if !ok || burst < 1 {
		return fmt.Errorf("pipeline '%s' sink[%d] %s: burst_size must be at least 1",
			pipelineName, sinkIndex, serverType)
	}
	// Validate response code (0 means "use default")
	if respCode, ok := nl["response_code"].(int64); ok {
		if respCode > 0 && (respCode < 400 || respCode >= 600) {
			return fmt.Errorf("pipeline '%s' sink[%d] %s: response_code must be 4xx or 5xx: %d",
				pipelineName, sinkIndex, serverType, respCode)
		}
	}
	// Validate connection limits (cross-check only when all four are set and positive)
	maxPerIP, perIPOk := nl["max_connections_per_ip"].(int64)
	maxPerUser, perUserOk := nl["max_connections_per_user"].(int64)
	maxPerToken, perTokenOk := nl["max_connections_per_token"].(int64)
	maxTotal, totalOk := nl["max_connections_total"].(int64)
	if perIPOk && perUserOk && perTokenOk && totalOk &&
		maxPerIP > 0 && maxPerUser > 0 && maxPerToken > 0 && maxTotal > 0 {
		if maxPerIP > maxTotal {
			return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_ip (%d) cannot exceed max_connections_total (%d)",
				pipelineName, sinkIndex, serverType, maxPerIP, maxTotal)
		}
		if maxPerUser > maxTotal {
			return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_user (%d) cannot exceed max_connections_total (%d)",
				pipelineName, sinkIndex, serverType, maxPerUser, maxTotal)
		}
		if maxPerToken > maxTotal {
			return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_token (%d) cannot exceed max_connections_total (%d)",
				pipelineName, sinkIndex, serverType, maxPerToken, maxTotal)
		}
	}
	return nil
}
// validateIPv4Entry ensures entry is either a single IPv4 address or an IPv4
// CIDR block. IPv6 addresses and networks are rejected because the limiter
// is IPv4-only.
func validateIPv4Entry(entry string) error {
	// No "/" means a plain address: it must parse and be representable as IPv4.
	if !strings.Contains(entry, "/") {
		ip := net.ParseIP(entry)
		switch {
		case ip == nil:
			return fmt.Errorf("invalid IP address: %s", entry)
		case ip.To4() == nil:
			return fmt.Errorf("IPv6 not supported (IPv4-only): %s", entry)
		}
		return nil
	}
	// CIDR form: parse, then confirm both the address and the mask are IPv4.
	ipAddr, ipNet, err := net.ParseCIDR(entry)
	if err != nil {
		return fmt.Errorf("invalid CIDR: %s", entry)
	}
	if ipAddr.To4() == nil {
		return fmt.Errorf("IPv6 CIDR not supported (IPv4-only): %s", entry)
	}
	// An IPv4 mask spans exactly 32 bits.
	if _, bits := ipNet.Mask.Size(); bits != 32 {
		return fmt.Errorf("invalid IPv4 CIDR mask (got %d bits, expected 32): %s", bits, entry)
	}
	return nil
}

View File

@ -1,82 +0,0 @@
// FILE: logwisp/src/internal/config/tls.go
package config
import (
"fmt"
"os"
)
// TLSConfig carries TLS settings decoded from a sink/source TOML table.
// It is shared by server-side listeners and client-side connections;
// which fields are honored depends on the consumer.
type TLSConfig struct {
	Enabled  bool   `toml:"enabled"`
	CertFile string `toml:"cert_file"` // server certificate (PEM path)
	KeyFile  string `toml:"key_file"`  // server private key (PEM path)
	// Client certificate authentication
	ClientAuth       bool   `toml:"client_auth"`      // require client certificates
	ClientCAFile     string `toml:"client_ca_file"`   // CA bundle used to verify client certs
	VerifyClientCert bool   `toml:"verify_client_cert"`
	// Option to skip verification for clients
	// NOTE(review): disables server cert verification on the client side;
	// intended for testing only — confirm it is never set in production configs.
	InsecureSkipVerify bool `toml:"insecure_skip_verify"`
	// CA file for client to trust specific server certificates
	CAFile string `toml:"ca_file"`
	// TLS version constraints
	MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3"
	MaxVersion string `toml:"max_version"`
	// Cipher suites (comma-separated list)
	CipherSuites string `toml:"cipher_suites"`
}
// validateTLSOptions validates a sink's TLS option map. Checks run
// cheapest-first: required fields, version strings, version ordering,
// then filesystem access for certificate material. Returns nil when
// TLS is disabled or the map carries no enabled flag.
func validateTLSOptions(serverType, pipelineName string, sinkIndex int, tls map[string]any) error {
	enabled, ok := tls["enabled"].(bool)
	if !ok || !enabled {
		return nil
	}
	// Certificate and key paths are mandatory when TLS is enabled
	certFile, certOk := tls["cert_file"].(string)
	keyFile, keyOk := tls["key_file"].(string)
	if !certOk || certFile == "" || !keyOk || keyFile == "" {
		return fmt.Errorf("pipeline '%s' sink[%d] %s: TLS enabled but cert/key files not specified",
			pipelineName, sinkIndex, serverType)
	}
	// Validate TLS version strings before touching the filesystem.
	// The rank doubles as an ordering for the min<=max cross-check.
	versionRank := map[string]int{"TLS1.0": 0, "TLS1.1": 1, "TLS1.2": 2, "TLS1.3": 3}
	minVer, _ := tls["min_version"].(string)
	maxVer, _ := tls["max_version"].(string)
	if minVer != "" {
		if _, valid := versionRank[minVer]; !valid {
			return fmt.Errorf("pipeline '%s' sink[%d] %s: invalid min TLS version: %s",
				pipelineName, sinkIndex, serverType, minVer)
		}
	}
	if maxVer != "" {
		if _, valid := versionRank[maxVer]; !valid {
			return fmt.Errorf("pipeline '%s' sink[%d] %s: invalid max TLS version: %s",
				pipelineName, sinkIndex, serverType, maxVer)
		}
	}
	// BUG FIX: min_version > max_version used to pass validation and only
	// fail later when the TLS listener was configured.
	if minVer != "" && maxVer != "" && versionRank[minVer] > versionRank[maxVer] {
		return fmt.Errorf("pipeline '%s' sink[%d] %s: min TLS version %s exceeds max TLS version %s",
			pipelineName, sinkIndex, serverType, minVer, maxVer)
	}
	// Validate that certificate files exist and are readable
	if _, err := os.Stat(certFile); err != nil {
		return fmt.Errorf("pipeline '%s' sink[%d] %s: cert_file is not accessible: %w",
			pipelineName, sinkIndex, serverType, err)
	}
	if _, err := os.Stat(keyFile); err != nil {
		return fmt.Errorf("pipeline '%s' sink[%d] %s: key_file is not accessible: %w",
			pipelineName, sinkIndex, serverType, err)
	}
	// Client certificate auth requires a CA bundle to verify against
	if clientAuth, ok := tls["client_auth"].(bool); ok && clientAuth {
		caFile, caOk := tls["client_ca_file"].(string)
		if !caOk || caFile == "" {
			return fmt.Errorf("pipeline '%s' sink[%d] %s: client auth enabled but CA file not specified",
				pipelineName, sinkIndex, serverType)
		}
		if _, err := os.Stat(caFile); err != nil {
			return fmt.Errorf("pipeline '%s' sink[%d] %s: client_ca_file is not accessible: %w",
				pipelineName, sinkIndex, serverType, err)
		}
	}
	return nil
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,13 @@
// FILE: logwisp/src/internal/core/const.go
package core
// Argon2id parameters
const (
	Argon2Time    = 3         // iteration count (time cost)
	Argon2Memory  = 64 * 1024 // 64 MB
	Argon2Threads = 4         // degree of parallelism
	Argon2SaltLen = 16        // salt length in bytes
	Argon2KeyLen  = 32        // derived key length in bytes
)

// DefaultTokenLength is the default token length — presumably bytes of
// entropy before encoding; confirm at the token generation site.
const DefaultTokenLength = 32

View File

@ -1,80 +0,0 @@
// FILE: logwisp/src/internal/filter/chain_test.go
package filter
import (
"testing"
"logwisp/src/internal/config"
"logwisp/src/internal/core"
"github.com/stretchr/testify/assert"
)
// TestNewChain verifies chain construction for valid filter configs and
// that the failing filter's index is surfaced when a pattern fails to compile.
func TestNewChain(t *testing.T) {
	logger := newTestLogger()
	t.Run("Success", func(t *testing.T) {
		cfgs := []config.FilterConfig{
			{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
			{Type: config.FilterTypeExclude, Patterns: []string{"banana"}},
		}
		c, err := NewChain(cfgs, logger)
		assert.NoError(t, err)
		assert.NotNil(t, c)
		assert.Len(t, c.filters, 2)
	})
	t.Run("ErrorInvalidRegexInChain", func(t *testing.T) {
		cfgs := []config.FilterConfig{
			{Patterns: []string{"apple"}},
			{Patterns: []string{"["}}, // unterminated character class
		}
		c, err := NewChain(cfgs, logger)
		assert.Error(t, err)
		assert.Nil(t, c)
		assert.Contains(t, err.Error(), "filter[1]")
	})
}
// TestChain_Apply verifies chain semantics: an empty chain passes
// everything, and an entry passes only if every filter accepts it.
func TestChain_Apply(t *testing.T) {
	logger := newTestLogger()
	sample := core.LogEntry{Message: "an apple a day"}
	t.Run("EmptyChain", func(t *testing.T) {
		c, err := NewChain([]config.FilterConfig{}, logger)
		assert.NoError(t, err)
		assert.True(t, c.Apply(sample))
	})
	t.Run("AllFiltersPass", func(t *testing.T) {
		c, err := NewChain([]config.FilterConfig{
			{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
			{Type: config.FilterTypeInclude, Patterns: []string{"day"}},
			{Type: config.FilterTypeExclude, Patterns: []string{"banana"}},
		}, logger)
		assert.NoError(t, err)
		assert.True(t, c.Apply(sample))
	})
	t.Run("OneFilterFails", func(t *testing.T) {
		c, err := NewChain([]config.FilterConfig{
			{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
			{Type: config.FilterTypeExclude, Patterns: []string{"day"}}, // rejects the entry
			{Type: config.FilterTypeInclude, Patterns: []string{"a"}},
		}, logger)
		assert.NoError(t, err)
		assert.False(t, c.Apply(sample))
	})
	t.Run("FirstFilterFails", func(t *testing.T) {
		c, err := NewChain([]config.FilterConfig{
			{Type: config.FilterTypeInclude, Patterns: []string{"banana"}}, // no match
			{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
		}, logger)
		assert.NoError(t, err)
		assert.False(t, c.Apply(sample))
	})
}

View File

@ -1,159 +0,0 @@
// FILE: logwisp/src/internal/filter/filter_test.go
package filter
import (
"logwisp/src/internal/config"
"logwisp/src/internal/core"
"testing"
"github.com/lixenwraith/log"
"github.com/stretchr/testify/assert"
)
// newTestLogger returns a fresh logger instance for use in tests.
func newTestLogger() *log.Logger {
	l := log.NewLogger()
	return l
}
// TestNewFilter covers default type/logic application, explicit
// configuration, and regex compilation failure.
func TestNewFilter(t *testing.T) {
	logger := newTestLogger()
	t.Run("SuccessWithDefaults", func(t *testing.T) {
		flt, err := NewFilter(config.FilterConfig{Patterns: []string{"test"}}, logger)
		assert.NoError(t, err)
		assert.NotNil(t, flt)
		// Omitted type and logic fall back to include/OR
		assert.Equal(t, config.FilterTypeInclude, flt.config.Type)
		assert.Equal(t, config.FilterLogicOr, flt.config.Logic)
	})
	t.Run("SuccessWithCustomConfig", func(t *testing.T) {
		fc := config.FilterConfig{
			Type:     config.FilterTypeExclude,
			Logic:    config.FilterLogicAnd,
			Patterns: []string{"test", "pattern"},
		}
		flt, err := NewFilter(fc, logger)
		assert.NoError(t, err)
		assert.NotNil(t, flt)
		assert.Equal(t, config.FilterTypeExclude, flt.config.Type)
		assert.Equal(t, config.FilterLogicAnd, flt.config.Logic)
		assert.Len(t, flt.patterns, 2)
	})
	t.Run("ErrorInvalidRegex", func(t *testing.T) {
		flt, err := NewFilter(config.FilterConfig{Patterns: []string{"["}}, logger)
		assert.Error(t, err)
		assert.Nil(t, flt)
		assert.Contains(t, err.Error(), "invalid regex pattern")
	})
}
// TestFilter_Apply exercises the include/exclude x AND/OR decision
// matrix plus edge cases: no patterns, empty entries, and matching
// against the level, source, and combined fields of an entry.
func TestFilter_Apply(t *testing.T) {
	logger := newTestLogger()
	testCases := []struct {
		name     string
		cfg      config.FilterConfig
		entry    core.LogEntry
		expected bool // whether the entry should pass the filter
	}{
		// Include OR logic
		{
			name:     "IncludeOR_MatchOne",
			cfg:      config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicOr, Patterns: []string{"apple", "banana"}},
			entry:    core.LogEntry{Message: "this is an apple"},
			expected: true,
		},
		{
			name:     "IncludeOR_NoMatch",
			cfg:      config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicOr, Patterns: []string{"apple", "banana"}},
			entry:    core.LogEntry{Message: "this is a pear"},
			expected: false,
		},
		// Include AND logic
		{
			name:     "IncludeAND_MatchAll",
			cfg:      config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicAnd, Patterns: []string{"apple", "doctor"}},
			entry:    core.LogEntry{Message: "an apple keeps the doctor away"},
			expected: true,
		},
		{
			name:     "IncludeAND_MatchOne",
			cfg:      config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicAnd, Patterns: []string{"apple", "doctor"}},
			entry:    core.LogEntry{Message: "this is an apple"},
			expected: false,
		},
		// Exclude OR logic
		{
			name:     "ExcludeOR_MatchOne",
			cfg:      config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicOr, Patterns: []string{"error", "fatal"}},
			entry:    core.LogEntry{Message: "this is an error"},
			expected: false,
		},
		{
			name:     "ExcludeOR_NoMatch",
			cfg:      config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicOr, Patterns: []string{"error", "fatal"}},
			entry:    core.LogEntry{Message: "this is a warning"},
			expected: true,
		},
		// Exclude AND logic
		{
			name:     "ExcludeAND_MatchAll",
			cfg:      config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicAnd, Patterns: []string{"critical", "database"}},
			entry:    core.LogEntry{Message: "critical error in database"},
			expected: false,
		},
		{
			name:     "ExcludeAND_MatchOne",
			cfg:      config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicAnd, Patterns: []string{"critical", "database"}},
			entry:    core.LogEntry{Message: "critical error in app"},
			expected: true,
		},
		// Edge Cases
		{
			name:     "NoPatterns",
			cfg:      config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{}},
			entry:    core.LogEntry{Message: "any message"},
			expected: true,
		},
		{
			name:     "EmptyEntry_NoPatterns",
			cfg:      config.FilterConfig{Patterns: []string{}},
			entry:    core.LogEntry{},
			expected: true,
		},
		{
			name:     "EmptyEntry_DoesNotMatchSpace",
			cfg:      config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{" "}},
			entry:    core.LogEntry{Level: "", Source: "", Message: ""},
			expected: false, // CORRECTED: An empty entry results in an empty string, which doesn't match a space.
		},
		{
			name:     "MatchOnLevel",
			cfg:      config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"ERROR"}},
			entry:    core.LogEntry{Level: "ERROR", Message: "A message"},
			expected: true,
		},
		{
			name:     "MatchOnSource",
			cfg:      config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"database"}},
			entry:    core.LogEntry{Source: "database", Message: "A message"},
			expected: true,
		},
		{
			// Pattern anchored at start — implies filters match against a
			// "source level message" concatenation; confirm in Filter.Apply.
			name:     "MatchOnCombinedFields",
			cfg:      config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"^app ERROR"}},
			entry:    core.LogEntry{Source: "app", Level: "ERROR", Message: "A message"},
			expected: true,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			f, err := NewFilter(tc.cfg, logger)
			assert.NoError(t, err)
			assert.Equal(t, tc.expected, f.Apply(tc.entry))
		})
	}
}

View File

@ -3,6 +3,7 @@ package format
import ( import (
"fmt" "fmt"
"logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
@ -19,20 +20,15 @@ type Formatter interface {
} }
// Creates a new Formatter based on the provided configuration. // Creates a new Formatter based on the provided configuration.
func NewFormatter(name string, options map[string]any, logger *log.Logger) (Formatter, error) { func NewFormatter(cfg *config.FormatConfig, logger *log.Logger) (Formatter, error) {
// Default to raw if no format specified switch cfg.Type {
if name == "" {
name = "raw"
}
switch name {
case "json": case "json":
return NewJSONFormatter(options, logger) return NewJSONFormatter(cfg.JSONFormatOptions, logger)
case "txt": case "txt":
return NewTextFormatter(options, logger) return NewTextFormatter(cfg.TextFormatOptions, logger)
case "raw": case "raw", "":
return NewRawFormatter(options, logger) return NewRawFormatter(cfg.RawFormatOptions, logger)
default: default:
return nil, fmt.Errorf("unknown formatter type: %s", name) return nil, fmt.Errorf("unknown formatter type: %s", cfg.Type)
} }
} }

View File

@ -1,65 +0,0 @@
// FILE: logwisp/src/internal/format/format_test.go
package format
import (
"testing"
"github.com/lixenwraith/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestLogger returns a fresh logger instance for use in tests.
func newTestLogger() *log.Logger {
	l := log.NewLogger()
	return l
}
// TestNewFormatter checks the factory's name dispatch, the raw default
// for an empty name, and rejection of unknown format names.
func TestNewFormatter(t *testing.T) {
	logger := newTestLogger()
	cases := []struct {
		name        string
		formatName  string
		expected    string
		expectError bool
	}{
		{name: "JSONFormatter", formatName: "json", expected: "json"},
		{name: "TextFormatter", formatName: "txt", expected: "txt"},
		{name: "RawFormatter", formatName: "raw", expected: "raw"},
		{name: "DefaultToRaw", formatName: "", expected: "raw"},
		{name: "UnknownFormatter", formatName: "xml", expectError: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			f, err := NewFormatter(c.formatName, nil, logger)
			if c.expectError {
				assert.Error(t, err)
				assert.Nil(t, f)
				return
			}
			require.NoError(t, err)
			require.NotNil(t, f)
			assert.Equal(t, c.expected, f.Name())
		})
	}
}

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"time" "time"
"logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
@ -13,39 +14,15 @@ import (
// Produces structured JSON logs // Produces structured JSON logs
type JSONFormatter struct { type JSONFormatter struct {
pretty bool config *config.JSONFormatterOptions
timestampField string logger *log.Logger
levelField string
messageField string
sourceField string
logger *log.Logger
} }
// Creates a new JSON formatter // Creates a new JSON formatter
func NewJSONFormatter(options map[string]any, logger *log.Logger) (*JSONFormatter, error) { func NewJSONFormatter(opts *config.JSONFormatterOptions, logger *log.Logger) (*JSONFormatter, error) {
f := &JSONFormatter{ f := &JSONFormatter{
timestampField: "timestamp", config: opts,
levelField: "level", logger: logger,
messageField: "message",
sourceField: "source",
logger: logger,
}
// Extract options
if pretty, ok := options["pretty"].(bool); ok {
f.pretty = pretty
}
if field, ok := options["timestamp_field"].(string); ok && field != "" {
f.timestampField = field
}
if field, ok := options["level_field"].(string); ok && field != "" {
f.levelField = field
}
if field, ok := options["message_field"].(string); ok && field != "" {
f.messageField = field
}
if field, ok := options["source_field"].(string); ok && field != "" {
f.sourceField = field
} }
return f, nil return f, nil
@ -57,9 +34,9 @@ func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) {
output := make(map[string]any) output := make(map[string]any)
// First, populate with LogWisp metadata // First, populate with LogWisp metadata
output[f.timestampField] = entry.Time.Format(time.RFC3339Nano) output[f.config.TimestampField] = entry.Time.Format(time.RFC3339Nano)
output[f.levelField] = entry.Level output[f.config.LevelField] = entry.Level
output[f.sourceField] = entry.Source output[f.config.SourceField] = entry.Source
// Try to parse the message as JSON // Try to parse the message as JSON
var msgData map[string]any var msgData map[string]any
@ -68,21 +45,21 @@ func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) {
// LogWisp metadata takes precedence // LogWisp metadata takes precedence
for k, v := range msgData { for k, v := range msgData {
// Don't overwrite our standard fields // Don't overwrite our standard fields
if k != f.timestampField && k != f.levelField && k != f.sourceField { if k != f.config.TimestampField && k != f.config.LevelField && k != f.config.SourceField {
output[k] = v output[k] = v
} }
} }
// If the original JSON had these fields, log that we're overriding // If the original JSON had these fields, log that we're overriding
if _, hasTime := msgData[f.timestampField]; hasTime { if _, hasTime := msgData[f.config.TimestampField]; hasTime {
f.logger.Debug("msg", "Overriding timestamp from JSON message", f.logger.Debug("msg", "Overriding timestamp from JSON message",
"component", "json_formatter", "component", "json_formatter",
"original", msgData[f.timestampField], "original", msgData[f.config.TimestampField],
"logwisp", output[f.timestampField]) "logwisp", output[f.config.TimestampField])
} }
} else { } else {
// Message is not valid JSON - add as message field // Message is not valid JSON - add as message field
output[f.messageField] = entry.Message output[f.config.MessageField] = entry.Message
} }
// Add any additional fields from LogEntry.Fields // Add any additional fields from LogEntry.Fields
@ -101,7 +78,7 @@ func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) {
// Marshal to JSON // Marshal to JSON
var result []byte var result []byte
var err error var err error
if f.pretty { if f.config.Pretty {
result, err = json.MarshalIndent(output, "", " ") result, err = json.MarshalIndent(output, "", " ")
} else { } else {
result, err = json.Marshal(output) result, err = json.Marshal(output)
@ -147,7 +124,7 @@ func (f *JSONFormatter) FormatBatch(entries []core.LogEntry) ([]byte, error) {
// Marshal the entire batch as an array // Marshal the entire batch as an array
var result []byte var result []byte
var err error var err error
if f.pretty { if f.config.Pretty {
result, err = json.MarshalIndent(batch, "", " ") result, err = json.MarshalIndent(batch, "", " ")
} else { } else {
result, err = json.Marshal(batch) result, err = json.Marshal(batch)

View File

@ -1,129 +0,0 @@
// FILE: logwisp/src/internal/format/json_test.go
package format
import (
"encoding/json"
"strings"
"testing"
"time"
"logwisp/src/internal/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestJSONFormatter_Format verifies single-entry JSON output: default
// field names, pretty printing, merging of JSON-valued messages,
// precedence of LogEntry metadata over conflicting message keys, and
// custom field-name options.
func TestJSONFormatter_Format(t *testing.T) {
	logger := newTestLogger()
	testTime := time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC)
	entry := core.LogEntry{
		Time:    testTime,
		Source:  "test-app",
		Level:   "INFO",
		Message: "this is a test",
	}
	t.Run("BasicFormatting", func(t *testing.T) {
		formatter, err := NewJSONFormatter(nil, logger)
		require.NoError(t, err)
		output, err := formatter.Format(entry)
		require.NoError(t, err)
		var result map[string]interface{}
		err = json.Unmarshal(output, &result)
		require.NoError(t, err, "Output should be valid JSON")
		assert.Equal(t, testTime.Format(time.RFC3339Nano), result["timestamp"])
		assert.Equal(t, "INFO", result["level"])
		assert.Equal(t, "test-app", result["source"])
		assert.Equal(t, "this is a test", result["message"])
		assert.True(t, strings.HasSuffix(string(output), "\n"), "Output should end with a newline")
	})
	t.Run("PrettyFormatting", func(t *testing.T) {
		formatter, err := NewJSONFormatter(map[string]any{"pretty": true}, logger)
		require.NoError(t, err)
		output, err := formatter.Format(entry)
		require.NoError(t, err)
		// Indented output implies MarshalIndent was used
		assert.Contains(t, string(output), ` "level": "INFO"`)
		assert.True(t, strings.HasSuffix(string(output), "\n"))
	})
	t.Run("MessageIsJSON", func(t *testing.T) {
		// A JSON-object message is merged into the output rather than
		// nested under the message field
		jsonMessageEntry := entry
		jsonMessageEntry.Message = `{"user":"test","request_id":"abc-123"}`
		formatter, err := NewJSONFormatter(nil, logger)
		require.NoError(t, err)
		output, err := formatter.Format(jsonMessageEntry)
		require.NoError(t, err)
		var result map[string]interface{}
		err = json.Unmarshal(output, &result)
		require.NoError(t, err)
		assert.Equal(t, "test", result["user"])
		assert.Equal(t, "abc-123", result["request_id"])
		_, messageExists := result["message"]
		assert.False(t, messageExists, "message field should not exist when message is merged JSON")
	})
	t.Run("MessageIsJSONWithConflicts", func(t *testing.T) {
		jsonMessageEntry := entry
		jsonMessageEntry.Level = "INFO" // top-level
		jsonMessageEntry.Message = `{"level":"DEBUG","msg":"hello"}`
		formatter, err := NewJSONFormatter(nil, logger)
		require.NoError(t, err)
		output, err := formatter.Format(jsonMessageEntry)
		require.NoError(t, err)
		var result map[string]interface{}
		err = json.Unmarshal(output, &result)
		require.NoError(t, err)
		assert.Equal(t, "INFO", result["level"], "Top-level LogEntry field should take precedence")
	})
	t.Run("CustomFieldNames", func(t *testing.T) {
		// Renaming the timestamp field must also remove the default key
		options := map[string]any{"timestamp_field": "@timestamp"}
		formatter, err := NewJSONFormatter(options, logger)
		require.NoError(t, err)
		output, err := formatter.Format(entry)
		require.NoError(t, err)
		var result map[string]interface{}
		err = json.Unmarshal(output, &result)
		require.NoError(t, err)
		_, defaultExists := result["timestamp"]
		assert.False(t, defaultExists)
		assert.Equal(t, testTime.Format(time.RFC3339Nano), result["@timestamp"])
	})
}
// TestJSONFormatter_FormatBatch verifies that a batch of entries is
// emitted as one valid JSON array preserving entry order.
func TestJSONFormatter_FormatBatch(t *testing.T) {
	logger := newTestLogger()
	f, err := NewJSONFormatter(nil, logger)
	require.NoError(t, err)
	batch := []core.LogEntry{
		{Time: time.Now(), Level: "INFO", Message: "First message"},
		{Time: time.Now(), Level: "WARN", Message: "Second message"},
	}
	out, err := f.FormatBatch(batch)
	require.NoError(t, err)
	var decoded []map[string]interface{}
	err = json.Unmarshal(out, &decoded)
	require.NoError(t, err, "Batch output should be a valid JSON array")
	require.Len(t, decoded, 2)
	assert.Equal(t, "First message", decoded[0]["message"])
	assert.Equal(t, "WARN", decoded[1]["level"])
}

View File

@ -2,6 +2,7 @@
package format package format
import ( import (
"logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
@ -9,20 +10,26 @@ import (
// Outputs the log message as-is with a newline // Outputs the log message as-is with a newline
type RawFormatter struct { type RawFormatter struct {
config *config.RawFormatterOptions
logger *log.Logger logger *log.Logger
} }
// Creates a new raw formatter // Creates a new raw formatter
func NewRawFormatter(options map[string]any, logger *log.Logger) (*RawFormatter, error) { func NewRawFormatter(cfg *config.RawFormatterOptions, logger *log.Logger) (*RawFormatter, error) {
return &RawFormatter{ return &RawFormatter{
config: cfg,
logger: logger, logger: logger,
}, nil }, nil
} }
// Returns the message with a newline appended // Returns the message with a newline appended
func (f *RawFormatter) Format(entry core.LogEntry) ([]byte, error) { func (f *RawFormatter) Format(entry core.LogEntry) ([]byte, error) {
// Simply return the message with newline // TODO: Standardize not to add "\n" when processing raw, check lixenwraith/log for consistency
return append([]byte(entry.Message), '\n'), nil if f.config.AddNewLine {
return append([]byte(entry.Message), '\n'), nil
} else {
return []byte(entry.Message), nil
}
} }
// Returns the formatter name // Returns the formatter name

View File

@ -1,29 +0,0 @@
// FILE: logwisp/src/internal/format/raw_test.go
package format
import (
"testing"
"time"
"logwisp/src/internal/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRawFormatter_Format confirms the raw formatter emits the message
// verbatim with a trailing newline appended.
func TestRawFormatter_Format(t *testing.T) {
	logger := newTestLogger()
	f, err := NewRawFormatter(nil, logger)
	require.NoError(t, err)
	in := core.LogEntry{
		Time:    time.Now(),
		Message: "This is a raw log line.",
	}
	out, err := f.Format(in)
	require.NoError(t, err)
	assert.Equal(t, "This is a raw log line.\n", string(out))
}

View File

@ -4,6 +4,7 @@ package format
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"logwisp/src/internal/config"
"strings" "strings"
"text/template" "text/template"
"time" "time"
@ -15,41 +16,29 @@ import (
// Produces human-readable text logs using templates // Produces human-readable text logs using templates
type TextFormatter struct { type TextFormatter struct {
template *template.Template config *config.TextFormatterOptions
timestampFormat string template *template.Template
logger *log.Logger logger *log.Logger
} }
// Creates a new text formatter // Creates a new text formatter
func NewTextFormatter(options map[string]any, logger *log.Logger) (*TextFormatter, error) { func NewTextFormatter(opts *config.TextFormatterOptions, logger *log.Logger) (*TextFormatter, error) {
// Default template
templateStr := "[{{.Timestamp | FmtTime}}] [{{.Level | ToUpper}}] {{.Source}} - {{.Message}}{{ if .Fields }} {{.Fields}}{{ end }}"
if tmpl, ok := options["template"].(string); ok && tmpl != "" {
templateStr = tmpl
}
// Default timestamp format
timestampFormat := time.RFC3339
if tsFormat, ok := options["timestamp_format"].(string); ok && tsFormat != "" {
timestampFormat = tsFormat
}
f := &TextFormatter{ f := &TextFormatter{
timestampFormat: timestampFormat, config: opts,
logger: logger, logger: logger,
} }
// Create template with helper functions // Create template with helper functions
funcMap := template.FuncMap{ funcMap := template.FuncMap{
"FmtTime": func(t time.Time) string { "FmtTime": func(t time.Time) string {
return t.Format(f.timestampFormat) return t.Format(f.config.TimestampFormat)
}, },
"ToUpper": strings.ToUpper, "ToUpper": strings.ToUpper,
"ToLower": strings.ToLower, "ToLower": strings.ToLower,
"TrimSpace": strings.TrimSpace, "TrimSpace": strings.TrimSpace,
} }
tmpl, err := template.New("log").Funcs(funcMap).Parse(templateStr) tmpl, err := template.New("log").Funcs(funcMap).Parse(f.config.Template)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid template: %w", err) return nil, fmt.Errorf("invalid template: %w", err)
} }
@ -86,7 +75,7 @@ func (f *TextFormatter) Format(entry core.LogEntry) ([]byte, error) {
"error", err) "error", err)
fallback := fmt.Sprintf("[%s] [%s] %s - %s\n", fallback := fmt.Sprintf("[%s] [%s] %s - %s\n",
entry.Time.Format(f.timestampFormat), entry.Time.Format(f.config.TimestampFormat),
strings.ToUpper(entry.Level), strings.ToUpper(entry.Level),
entry.Source, entry.Source,
entry.Message) entry.Message)

View File

@ -1,81 +0,0 @@
// FILE: logwisp/src/internal/format/text_test.go
package format
import (
"fmt"
"strings"
"testing"
"time"
"logwisp/src/internal/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewTextFormatter verifies that a template referencing an unknown
// function is rejected at construction time.
func TestNewTextFormatter(t *testing.T) {
	logger := newTestLogger()
	t.Run("InvalidTemplate", func(t *testing.T) {
		opts := map[string]any{"template": "{{ .Timestamp | InvalidFunc }}"}
		_, err := NewTextFormatter(opts, logger)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid template")
	})
}
// TestTextFormatter_Format verifies text rendering with the default
// template, a custom template, a custom timestamp format, and the
// fallback of an empty level to INFO.
func TestTextFormatter_Format(t *testing.T) {
	logger := newTestLogger()
	testTime := time.Date(2023, 10, 27, 10, 30, 0, 0, time.UTC)
	entry := core.LogEntry{
		Time:    testTime,
		Source:  "api",
		Level:   "WARN",
		Message: "rate limit exceeded",
	}
	t.Run("DefaultTemplate", func(t *testing.T) {
		formatter, err := NewTextFormatter(nil, logger)
		require.NoError(t, err)
		output, err := formatter.Format(entry)
		require.NoError(t, err)
		// Default template: "[ts] [LEVEL] source - message"
		expected := fmt.Sprintf("[%s] [WARN] api - rate limit exceeded\n", testTime.Format(time.RFC3339))
		assert.Equal(t, expected, string(output))
	})
	t.Run("CustomTemplate", func(t *testing.T) {
		options := map[string]any{"template": "{{.Level}}:{{.Source}}:{{.Message}}"}
		formatter, err := NewTextFormatter(options, logger)
		require.NoError(t, err)
		output, err := formatter.Format(entry)
		require.NoError(t, err)
		// Trailing newline appended even though the template lacks one
		expected := "WARN:api:rate limit exceeded\n"
		assert.Equal(t, expected, string(output))
	})
	t.Run("CustomTimestampFormat", func(t *testing.T) {
		options := map[string]any{"timestamp_format": "2006-01-02"}
		formatter, err := NewTextFormatter(options, logger)
		require.NoError(t, err)
		output, err := formatter.Format(entry)
		require.NoError(t, err)
		assert.True(t, strings.HasPrefix(string(output), "[2023-10-27]"))
	})
	t.Run("EmptyLevelDefaultsToInfo", func(t *testing.T) {
		emptyLevelEntry := entry
		emptyLevelEntry.Level = ""
		formatter, err := NewTextFormatter(nil, logger)
		require.NoError(t, err)
		output, err := formatter.Format(emptyLevelEntry)
		require.NoError(t, err)
		assert.Contains(t, string(output), "[INFO]")
	})
}

View File

@ -34,7 +34,7 @@ const (
// NetLimiter manages net limiting for a transport // NetLimiter manages net limiting for a transport
type NetLimiter struct { type NetLimiter struct {
config config.NetLimitConfig config *config.NetLimitConfig
logger *log.Logger logger *log.Logger
// IP Access Control Lists // IP Access Control Lists
@ -89,7 +89,11 @@ type connTracker struct {
} }
// Creates a new net limiter // Creates a new net limiter
func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter { func NewNetLimiter(cfg *config.NetLimitConfig, logger *log.Logger) *NetLimiter {
if cfg == nil {
return nil
}
// Return nil only if nothing is configured // Return nil only if nothing is configured
hasACL := len(cfg.IPWhitelist) > 0 || len(cfg.IPBlacklist) > 0 hasACL := len(cfg.IPWhitelist) > 0 || len(cfg.IPBlacklist) > 0
hasRateLimit := cfg.Enabled hasRateLimit := cfg.Enabled
@ -120,7 +124,7 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
} }
// Parse IP lists // Parse IP lists
l.parseIPLists(cfg) l.parseIPLists()
// Start cleanup goroutine only if rate limiting is enabled // Start cleanup goroutine only if rate limiting is enabled
if cfg.Enabled { if cfg.Enabled {
@ -144,16 +148,16 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
} }
// parseIPLists parses and validates IP whitelist/blacklist // parseIPLists parses and validates IP whitelist/blacklist
func (l *NetLimiter) parseIPLists(cfg config.NetLimitConfig) { func (l *NetLimiter) parseIPLists() {
// Parse whitelist // Parse whitelist
for _, entry := range cfg.IPWhitelist { for _, entry := range l.config.IPWhitelist {
if ipNet := l.parseIPEntry(entry, "whitelist"); ipNet != nil { if ipNet := l.parseIPEntry(entry, "whitelist"); ipNet != nil {
l.ipWhitelist = append(l.ipWhitelist, ipNet) l.ipWhitelist = append(l.ipWhitelist, ipNet)
} }
} }
// Parse blacklist // Parse blacklist
for _, entry := range cfg.IPBlacklist { for _, entry := range l.config.IPBlacklist {
if ipNet := l.parseIPEntry(entry, "blacklist"); ipNet != nil { if ipNet := l.parseIPEntry(entry, "blacklist"); ipNet != nil {
l.ipBlacklist = append(l.ipBlacklist, ipNet) l.ipBlacklist = append(l.ipBlacklist, ipNet)
} }

View File

@ -1,117 +0,0 @@
// FILE: src/internal/scram/integration.go
package scram
import (
"crypto/rand"
"fmt"
"sync"
"time"
"golang.org/x/time/rate"
)
// ScramManager provides high-level SCRAM operations with rate limiting
type ScramManager struct {
	server   *Server                  // credential store and SCRAM exchange handler
	sessions map[string]*SessionInfo  // authenticated sessions keyed by session ID
	limiter  map[string]*rate.Limiter // per-remote-address auth rate limiters
	mu       sync.RWMutex             // guards sessions and limiter
}

// SessionInfo tracks authenticated sessions
type SessionInfo struct {
	Username     string
	RemoteAddr   string
	SessionID    string
	CreatedAt    time.Time
	LastActivity time.Time // refreshed on use; sessions idle past a cutoff are reaped
	Method       string    // "scram-sha-256"
}
// NewScramManager creates a SCRAM manager and starts its background
// session-expiry goroutine.
func NewScramManager() *ScramManager {
	mgr := &ScramManager{
		server:   NewServer(),
		sessions: map[string]*SessionInfo{},
		limiter:  map[string]*rate.Limiter{},
	}
	// Reap stale sessions in the background
	go mgr.cleanupLoop()
	return mgr
}
// RegisterUser derives and stores a credential for username from the
// given password using a freshly generated random salt.
func (sm *ScramManager) RegisterUser(username, password string) error {
	var salt [16]byte
	if _, err := rand.Read(salt[:]); err != nil {
		return fmt.Errorf("salt generation failed: %w", err)
	}
	c, err := DeriveCredential(username, password, salt[:],
		sm.server.DefaultTime, sm.server.DefaultMemory, sm.server.DefaultThreads)
	if err != nil {
		return err
	}
	sm.server.AddCredential(c)
	return nil
}
// GetRateLimiter returns the rate limiter for remoteAddr, creating one
// (10 attempts per minute, burst of 3) on first use. The limiter map is
// capped to bound memory under remote-address churn.
func (sm *ScramManager) GetRateLimiter(remoteAddr string) *rate.Limiter {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	if limiter, exists := sm.limiter[remoteAddr]; exists {
		return limiter
	}
	// 10 attempts per minute, burst of 3
	limiter := rate.NewLimiter(rate.Every(6*time.Second), 3)
	sm.limiter[remoteAddr] = limiter
	// Prevent unbounded growth. Go map iteration order is random, so this
	// evicts arbitrary entries (not the oldest).
	// BUG FIX: the just-created entry is skipped — previously it could be
	// evicted immediately, so the next call for this address would get a
	// fresh limiter with reset rate state.
	if len(sm.limiter) > 10000 {
		for addr := range sm.limiter {
			if addr == remoteAddr {
				continue
			}
			delete(sm.limiter, addr)
			if len(sm.limiter) < 8000 {
				break
			}
		}
	}
	return limiter
}
// cleanupLoop periodically deletes sessions whose last activity is more
// than 30 minutes in the past. Runs for the life of the process.
func (sm *ScramManager) cleanupLoop() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		sm.mu.Lock()
		deadline := time.Now().Add(-30 * time.Minute)
		for id, s := range sm.sessions {
			if s.LastActivity.Before(deadline) {
				delete(sm.sessions, id)
			}
		}
		sm.mu.Unlock()
	}
}
// HandleClientFirst wraps server's HandleClientFirst, producing the
// server challenge for a client's opening message.
func (sm *ScramManager) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
	return sm.server.HandleClientFirst(msg)
}

// HandleClientFinal wraps server's HandleClientFinal, verifying the
// client proof and returning the server signature.
func (sm *ScramManager) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
	return sm.server.HandleClientFinal(msg)
}

// AddCredential wraps server's AddCredential to register a
// pre-derived credential directly.
func (sm *ScramManager) AddCredential(cred *Credential) {
	sm.server.AddCredential(cred)
}

View File

@ -1,101 +0,0 @@
// FILE: src/internal/scram/message.go
package scram
import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
)
// ClientFirst initiates authentication
type ClientFirst struct {
	Username    string `json:"u"`
	ClientNonce string `json:"n"`
}

// ServerFirst contains server challenge
type ServerFirst struct {
	FullNonce    string `json:"r"` // client_nonce + server_nonce
	Salt         string `json:"s"` // base64
	ArgonTime    uint32 `json:"t"`
	ArgonMemory  uint32 `json:"m"`
	ArgonThreads uint8  `json:"p"`
}

// ClientFinal contains client proof
type ClientFinal struct {
	FullNonce   string `json:"r"`
	ClientProof string `json:"p"` // base64
}

// ServerFinal contains server signature for mutual auth
type ServerFinal struct {
	ServerSignature string `json:"v"` // base64
	SessionID       string `json:"sid,omitempty"`
}

// Marshal renders the message in the line-based "k=v,k=v" wire form used
// by the TCP protocol.
func (cf *ClientFirst) Marshal() string {
	return "u=" + cf.Username + ",n=" + cf.ClientNonce
}

// ParseClientFirst decodes a line-based client-first message. Segments
// without a '=' and unknown keys are ignored; later duplicates of a key
// overwrite earlier values. Both u (username) and n (client nonce) must
// end up non-empty.
func ParseClientFirst(data string) (*ClientFirst, error) {
	msg := &ClientFirst{}
	for _, segment := range strings.Split(data, ",") {
		key, value, ok := strings.Cut(segment, "=")
		if !ok {
			continue
		}
		switch key {
		case "u":
			msg.Username = value
		case "n":
			msg.ClientNonce = value
		}
	}
	if msg.Username == "" || msg.ClientNonce == "" {
		return nil, fmt.Errorf("missing required fields")
	}
	return msg, nil
}
// Marshal renders the challenge in the line-based "k=v,..." wire form
// used by the TCP protocol. The Salt string is written verbatim.
func (sf *ServerFirst) Marshal() string {
	return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d",
		sf.FullNonce, sf.Salt, sf.ArgonTime, sf.ArgonMemory, sf.ArgonThreads)
}
// ParseServerFirst decodes a line-based server-first challenge.
// Segments without '=' and unknown keys are ignored. Unlike the previous
// implementation, malformed numeric values for t/m/p are rejected instead
// of being silently scanned to zero (the Sscanf error was discarded), and
// the nonce and salt fields are required — mirroring the validation that
// ParseClientFirst already performs.
func ParseServerFirst(data string) (*ServerFirst, error) {
	parts := strings.Split(data, ",")
	msg := &ServerFirst{}
	for _, part := range parts {
		kv := strings.SplitN(part, "=", 2)
		if len(kv) != 2 {
			continue
		}
		switch kv[0] {
		case "r":
			msg.FullNonce = kv[1]
		case "s":
			msg.Salt = kv[1]
		case "t":
			if _, err := fmt.Sscanf(kv[1], "%d", &msg.ArgonTime); err != nil {
				return nil, fmt.Errorf("invalid t value %q: %w", kv[1], err)
			}
		case "m":
			if _, err := fmt.Sscanf(kv[1], "%d", &msg.ArgonMemory); err != nil {
				return nil, fmt.Errorf("invalid m value %q: %w", kv[1], err)
			}
		case "p":
			if _, err := fmt.Sscanf(kv[1], "%d", &msg.ArgonThreads); err != nil {
				return nil, fmt.Errorf("invalid p value %q: %w", kv[1], err)
			}
		}
	}
	if msg.FullNonce == "" || msg.Salt == "" {
		return nil, fmt.Errorf("missing required fields")
	}
	return msg, nil
}
// JSON variants for HTTP transport.

// MarshalJSON encodes the message by marshaling a value copy, so the
// default struct encoding with the json tags applies.
// NOTE(review): this appears to match encoding/json's default output for
// the struct, so the method may be redundant — confirm before removing.
func (cf *ClientFirst) MarshalJSON() ([]byte, error) {
	return json.Marshal(*cf)
}
// MarshalJSON encodes the challenge for the HTTP transport, emitting the
// Salt field base64-encoded.
//
// The encoding is applied to a local copy: the previous implementation
// rewrote sf.Salt in place, so marshaling the same message twice
// base64-encoded the salt twice, corrupting the receiver for any later
// use (Marshal, retries, logging).
func (sf *ServerFirst) MarshalJSON() ([]byte, error) {
	out := *sf
	out.Salt = base64.StdEncoding.EncodeToString([]byte(sf.Salt))
	// Marshaling the value (not the pointer) does not see this
	// pointer-receiver method, so the default tag-driven struct encoding
	// is used — no recursion. This matches the original json.Marshal(*sf).
	return json.Marshal(out)
}

View File

@ -1,228 +0,0 @@
// FILE: src/internal/scram/scram_test.go
package scram
import (
"crypto/rand"
"encoding/base64"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/argon2"
)
// TestCredentialDerivation verifies that DeriveCredential echoes the
// requested username and Argon2 parameters and yields 32-byte
// stored/server keys.
func TestCredentialDerivation(t *testing.T) {
	seed := make([]byte, 16)
	_, err := rand.Read(seed)
	require.NoError(t, err)

	cred, err := DeriveCredential("testuser", "testpass123", seed, 3, 64*1024, 4)
	require.NoError(t, err)

	// Parameters must round-trip into the credential unchanged
	assert.Equal(t, "testuser", cred.Username)
	assert.Equal(t, uint32(3), cred.ArgonTime)
	assert.Equal(t, uint32(64*1024), cred.ArgonMemory)
	assert.Equal(t, uint8(4), cred.ArgonThreads)
	// Both derived keys are 32 bytes
	assert.Len(t, cred.StoredKey, 32)
	assert.Len(t, cred.ServerKey, 32)
}
// TestFullHandshake drives the complete four-message SCRAM exchange
// between NewClient and NewServer and verifies mutual authentication:
// the server accepts the client proof, issues a session ID, and the
// client accepts the server signature.
func TestFullHandshake(t *testing.T) {
	// Setup server with credential
	server := NewServer()
	salt := make([]byte, 16)
	_, err := rand.Read(salt)
	require.NoError(t, err)
	cred, err := DeriveCredential("alice", "secret123", salt, 3, 64*1024, 4)
	require.NoError(t, err)
	server.AddCredential(cred)
	// Setup client
	client := NewClient("alice", "secret123")
	// Step 1: Client starts
	clientFirst, err := client.StartAuthentication()
	require.NoError(t, err)
	assert.Equal(t, "alice", clientFirst.Username)
	assert.NotEmpty(t, clientFirst.ClientNonce)
	// Step 2: Server responds — the full nonce must extend (contain) the
	// client's nonce, and the salt travels base64-encoded
	serverFirst, err := server.HandleClientFirst(clientFirst)
	require.NoError(t, err)
	assert.Contains(t, serverFirst.FullNonce, clientFirst.ClientNonce)
	assert.Equal(t, base64.StdEncoding.EncodeToString(salt), serverFirst.Salt)
	// Step 3: Client proves — echoes the full nonce with its proof
	clientFinal, err := client.ProcessServerFirst(serverFirst)
	require.NoError(t, err)
	assert.Equal(t, serverFirst.FullNonce, clientFinal.FullNonce)
	assert.NotEmpty(t, clientFinal.ClientProof)
	// Step 4: Server verifies and signs
	serverFinal, err := server.HandleClientFinal(clientFinal)
	require.NoError(t, err)
	assert.NotEmpty(t, serverFinal.ServerSignature)
	assert.NotEmpty(t, serverFinal.SessionID)
	// Step 5: Client verifies server (mutual authentication)
	err = client.VerifyServerFinal(serverFinal)
	assert.NoError(t, err)
}
// TestInvalidPassword verifies that a client holding the wrong password
// is rejected by the server's final-message check with an
// "authentication failed" error.
//
// Fix: the previous version called client.ProcessServerFirst twice in a
// row, discarding the first result; the redundant call is removed.
func TestInvalidPassword(t *testing.T) {
	server := NewServer()
	salt := make([]byte, 16)
	rand.Read(salt)
	cred, _ := DeriveCredential("bob", "correct", salt, 3, 64*1024, 4)
	server.AddCredential(cred)
	// Client with wrong password
	client := NewClient("bob", "wrong")
	clientFirst, _ := client.StartAuthentication()
	serverFirst, _ := server.HandleClientFirst(clientFirst)
	clientFinal, err := client.ProcessServerFirst(serverFirst)
	require.NoError(t, err) // Check error to prevent panic on nil clientFinal
	// Server should reject the proof derived from the wrong password
	_, err = server.HandleClientFinal(clientFinal)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "authentication failed")
}
// TestHandshakeTimeout verifies that a client-final message arriving
// after the handshake window is rejected with a "timeout" error. Rather
// than sleeping, the test backdates the server-side handshake state
// directly (backdating by 61s suggests a 60-second window — confirm
// against the server implementation).
func TestHandshakeTimeout(t *testing.T) {
	server := NewServer()
	salt := make([]byte, 16)
	rand.Read(salt)
	cred, _ := DeriveCredential("charlie", "pass", salt, 3, 64*1024, 4)
	server.AddCredential(cred)
	client := NewClient("charlie", "pass")
	clientFirst, _ := client.StartAuthentication()
	serverFirst, _ := server.HandleClientFirst(clientFirst)
	// Manipulate handshake timestamp: reach into server internals and
	// backdate the pending handshake beyond the allowed window
	server.mu.Lock()
	if state, exists := server.handshakes[serverFirst.FullNonce]; exists {
		state.CreatedAt = time.Now().Add(-61 * time.Second)
	}
	server.mu.Unlock()
	clientFinal, _ := client.ProcessServerFirst(serverFirst)
	_, err := server.HandleClientFinal(clientFinal)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "timeout")
}
// TestMessageParsing checks line-based ClientFirst decoding plus a
// Marshal/Parse round trip.
func TestMessageParsing(t *testing.T) {
	cf, err := ParseClientFirst("u=alice,n=abc123")
	require.NoError(t, err)
	assert.Equal(t, "alice", cf.Username)
	assert.Equal(t, "abc123", cf.ClientNonce)

	// The wire round trip must be lossless
	roundTripped, err := ParseClientFirst(cf.Marshal())
	require.NoError(t, err)
	assert.Equal(t, cf.Username, roundTripped.Username)
	assert.Equal(t, cf.ClientNonce, roundTripped.ClientNonce)
}
// TestPHCMigration verifies that MigrateFromPHC converts a PHC-formatted
// argon2id hash (as produced by the pre-existing auth system) into a
// SCRAM credential, preserving the username and salt.
//
// Fix: the previous version first built a dummy phcHash with a fake
// digest and then immediately overwrote it — the dead assignment is
// removed and the valid hash is built once.
func TestPHCMigration(t *testing.T) {
	password := "testpass"
	salt := []byte("saltsaltsaltsalt")
	saltB64 := base64.RawStdEncoding.EncodeToString(salt)
	// Generate a valid argon2id digest the way the legacy system would
	hash := argon2.IDKey([]byte(password), salt, 3, 65536, 4, 32)
	hashB64 := base64.RawStdEncoding.EncodeToString(hash)
	phcHash := "$argon2id$v=19$m=65536,t=3,p=4$" + saltB64 + "$" + hashB64
	cred, err := MigrateFromPHC("alice", password, phcHash)
	require.NoError(t, err)
	assert.Equal(t, "alice", cred.Username)
	assert.Equal(t, salt, cred.Salt)
}
// TestRateLimiting verifies the per-IP limiter permits its configured
// burst of 3 immediate attempts and rejects the 4th.
func TestRateLimiting(t *testing.T) {
	manager := NewScramManager()
	limiter := manager.GetRateLimiter("192.168.1.1")
	// The burst budget covers the first three attempts
	for attempt := 0; attempt < 3; attempt++ {
		assert.True(t, limiter.Allow())
	}
	// The next immediate attempt must be throttled
	assert.False(t, limiter.Allow())
}
// TestNonceUniqueness generates 1000 nonces and asserts that none repeat.
func TestNonceUniqueness(t *testing.T) {
	seen := make(map[string]bool, 1000)
	for i := 0; i < 1000; i++ {
		n := generateNonce()
		assert.False(t, seen[n], "Duplicate nonce generated")
		seen[n] = true
	}
}
// TestConstantTimeComparison checks the anti-enumeration behavior of
// HandleClientFirst: an unknown username still yields a non-nil
// ServerFirst challenge (alongside an internal error), so the response
// shape cannot reveal whether a user exists.
// NOTE(review): despite the name, this asserts message shape only — it
// does not measure timing.
func TestConstantTimeComparison(t *testing.T) {
	server := NewServer()
	salt := make([]byte, 16)
	rand.Read(salt)
	// Add real user
	cred, _ := DeriveCredential("realuser", "realpass", salt, 3, 64*1024, 4)
	server.AddCredential(cred)
	// Test timing independence for invalid user
	fakeClient := NewClient("fakeuser", "fakepass")
	clientFirst, _ := fakeClient.StartAuthentication()
	// Should still return ServerFirst (no user enumeration)
	serverFirst, err := server.HandleClientFirst(clientFirst)
	assert.NotNil(t, serverFirst)
	assert.Error(t, err) // Internal error, not exposed to client
}
// BenchmarkArgon2Derivation measures the cost of a single credential
// derivation with parameters t=3, m=64*1024, p=4. Errors are
// deliberately ignored inside the timed loop.
func BenchmarkArgon2Derivation(b *testing.B) {
	salt := make([]byte, 16)
	rand.Read(salt)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		DeriveCredential("user", "password", salt, 3, 64*1024, 4)
	}
}
// BenchmarkFullHandshake measures one complete client/server SCRAM
// exchange (all four messages plus server-signature verification) per
// iteration against a pre-registered credential. Errors are deliberately
// ignored inside the timed loop.
func BenchmarkFullHandshake(b *testing.B) {
	server := NewServer()
	salt := make([]byte, 16)
	rand.Read(salt)
	cred, _ := DeriveCredential("bench", "pass", salt, 3, 64*1024, 4)
	server.AddCredential(cred)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		client := NewClient("bench", "pass")
		clientFirst, _ := client.StartAuthentication()
		serverFirst, _ := server.HandleClientFirst(clientFirst)
		clientFinal, _ := client.ProcessServerFirst(serverFirst)
		serverFinal, _ := server.HandleClientFinal(clientFinal)
		client.VerifyServerFinal(serverFinal)
	}
}

View File

@ -3,12 +3,14 @@ package service
import ( import (
"context" "context"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"logwisp/src/internal/config" "logwisp/src/internal/config"
"logwisp/src/internal/filter" "logwisp/src/internal/filter"
"logwisp/src/internal/format"
"logwisp/src/internal/limit" "logwisp/src/internal/limit"
"logwisp/src/internal/sink" "logwisp/src/internal/sink"
"logwisp/src/internal/source" "logwisp/src/internal/source"
@ -18,8 +20,7 @@ import (
// Manages the flow of data from sources through filters to sinks // Manages the flow of data from sources through filters to sinks
type Pipeline struct { type Pipeline struct {
Name string Config *config.PipelineConfig
Config config.PipelineConfig
Sources []source.Source Sources []source.Source
RateLimiter *limit.RateLimiter RateLimiter *limit.RateLimiter
FilterChain *filter.Chain FilterChain *filter.Chain
@ -43,11 +44,116 @@ type PipelineStats struct {
FilterStats map[string]any FilterStats map[string]any
} }
// Creates and starts a new pipeline.
//
// Assembly order: sources, optional rate limiter, optional filter chain,
// formatter, sinks — then sources and sinks are started and wired. On
// any construction failure the pipeline context is cancelled; once
// components are running, failures call Shutdown instead, so a
// partially-built pipeline is never registered in s.pipelines.
func (s *Service) NewPipeline(cfg *config.PipelineConfig) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Pipeline names are unique service-wide
	if _, exists := s.pipelines[cfg.Name]; exists {
		err := fmt.Errorf("pipeline '%s' already exists", cfg.Name)
		s.logger.Error("msg", "Failed to create pipeline - duplicate name",
			"component", "service",
			"pipeline", cfg.Name,
			"error", err)
		return err
	}
	s.logger.Debug("msg", "Creating pipeline", "pipeline", cfg.Name)
	// Create pipeline context (child of the service context, so service
	// shutdown cascades to every pipeline)
	pipelineCtx, pipelineCancel := context.WithCancel(s.ctx)
	// Create pipeline instance
	pipeline := &Pipeline{
		Config: cfg,
		Stats: &PipelineStats{
			StartTime: time.Now(),
		},
		ctx:    pipelineCtx,
		cancel: pipelineCancel,
		logger: s.logger,
	}
	// Create sources
	for i, srcCfg := range cfg.Sources {
		src, err := s.createSource(&srcCfg)
		if err != nil {
			pipelineCancel()
			return fmt.Errorf("failed to create source[%d]: %w", i, err)
		}
		pipeline.Sources = append(pipeline.Sources, src)
	}
	// Create pipeline rate limiter (optional)
	if cfg.RateLimit != nil {
		limiter, err := limit.NewRateLimiter(*cfg.RateLimit, s.logger)
		if err != nil {
			pipelineCancel()
			return fmt.Errorf("failed to create pipeline rate limiter: %w", err)
		}
		pipeline.RateLimiter = limiter
	}
	// Create filter chain (optional)
	if len(cfg.Filters) > 0 {
		chain, err := filter.NewChain(cfg.Filters, s.logger)
		if err != nil {
			pipelineCancel()
			return fmt.Errorf("failed to create filter chain: %w", err)
		}
		pipeline.FilterChain = chain
	}
	// Create formatter for the pipeline (shared by all its sinks)
	formatter, err := format.NewFormatter(cfg.Format, s.logger)
	if err != nil {
		pipelineCancel()
		return fmt.Errorf("failed to create formatter: %w", err)
	}
	// Create sinks
	for i, sinkCfg := range cfg.Sinks {
		sinkInst, err := s.createSink(sinkCfg, formatter)
		if err != nil {
			pipelineCancel()
			return fmt.Errorf("failed to create sink[%d]: %w", i, err)
		}
		pipeline.Sinks = append(pipeline.Sinks, sinkInst)
	}
	// Start all sources; from here on failures must tear down anything
	// already running, hence Shutdown rather than a bare cancel
	for i, src := range pipeline.Sources {
		if err := src.Start(); err != nil {
			pipeline.Shutdown()
			return fmt.Errorf("failed to start source[%d]: %w", i, err)
		}
	}
	// Start all sinks
	for i, sinkInst := range pipeline.Sinks {
		if err := sinkInst.Start(pipelineCtx); err != nil {
			pipeline.Shutdown()
			return fmt.Errorf("failed to start sink[%d]: %w", i, err)
		}
	}
	// Wire sources to sinks through filters
	s.wirePipeline(pipeline)
	// Start stats updater
	pipeline.startStatsUpdater(pipelineCtx)
	s.pipelines[cfg.Name] = pipeline
	s.logger.Info("msg", "Pipeline created successfully",
		"pipeline", cfg.Name)
	return nil
}
// Gracefully stops the pipeline // Gracefully stops the pipeline
func (p *Pipeline) Shutdown() { func (p *Pipeline) Shutdown() {
p.logger.Info("msg", "Shutting down pipeline", p.logger.Info("msg", "Shutting down pipeline",
"component", "pipeline", "component", "pipeline",
"pipeline", p.Name) "pipeline", p.Config.Name)
// Cancel context to stop processing // Cancel context to stop processing
p.cancel() p.cancel()
@ -78,7 +184,7 @@ func (p *Pipeline) Shutdown() {
p.logger.Info("msg", "Pipeline shutdown complete", p.logger.Info("msg", "Pipeline shutdown complete",
"component", "pipeline", "component", "pipeline",
"pipeline", p.Name) "pipeline", p.Config.Name)
} }
// Returns pipeline statistics // Returns pipeline statistics
@ -88,7 +194,7 @@ func (p *Pipeline) GetStats() map[string]any {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
p.logger.Error("msg", "Panic getting pipeline stats", p.logger.Error("msg", "Panic getting pipeline stats",
"pipeline", p.Name, "pipeline", p.Config.Name,
"panic", r) "panic", r)
} }
}() }()
@ -142,7 +248,7 @@ func (p *Pipeline) GetStats() map[string]any {
} }
return map[string]any{ return map[string]any{
"name": p.Name, "name": p.Config.Name,
"uptime_seconds": int(time.Since(p.Stats.StartTime).Seconds()), "uptime_seconds": int(time.Since(p.Stats.StartTime).Seconds()),
"total_processed": p.Stats.TotalEntriesProcessed.Load(), "total_processed": p.Stats.TotalEntriesProcessed.Load(),
"total_dropped_rate_limit": p.Stats.TotalEntriesDroppedByRateLimit.Load(), "total_dropped_rate_limit": p.Stats.TotalEntriesDroppedByRateLimit.Load(),

View File

@ -5,13 +5,10 @@ import (
"context" "context"
"fmt" "fmt"
"sync" "sync"
"time"
"logwisp/src/internal/config" "logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"logwisp/src/internal/filter"
"logwisp/src/internal/format" "logwisp/src/internal/format"
"logwisp/src/internal/limit"
"logwisp/src/internal/sink" "logwisp/src/internal/sink"
"logwisp/src/internal/source" "logwisp/src/internal/source"
@ -39,127 +36,6 @@ func NewService(ctx context.Context, logger *log.Logger) *Service {
} }
} }
// Creates and starts a new pipeline.
//
// Assembly order: sources, optional rate limiter, optional filter chain,
// optional formatter, sinks. Authentication from cfg.Auth is pushed into
// every source and sink via SetAuth before they are started. On a
// construction failure the pipeline context is cancelled; once components
// are running, failures call Shutdown instead — so a partially-built
// pipeline is never registered in s.pipelines.
func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Pipeline names are unique service-wide
	if _, exists := s.pipelines[cfg.Name]; exists {
		err := fmt.Errorf("pipeline '%s' already exists", cfg.Name)
		s.logger.Error("msg", "Failed to create pipeline - duplicate name",
			"component", "service",
			"pipeline", cfg.Name,
			"error", err)
		return err
	}
	s.logger.Debug("msg", "Creating pipeline", "pipeline", cfg.Name)
	// Create pipeline context
	pipelineCtx, pipelineCancel := context.WithCancel(s.ctx)
	// Create pipeline instance
	pipeline := &Pipeline{
		Name:   cfg.Name,
		Config: cfg,
		Stats: &PipelineStats{
			StartTime: time.Now(),
		},
		ctx:    pipelineCtx,
		cancel: pipelineCancel,
		logger: s.logger,
	}
	// Create sources
	for i, srcCfg := range cfg.Sources {
		src, err := s.createSource(srcCfg)
		if err != nil {
			pipelineCancel()
			return fmt.Errorf("failed to create source[%d]: %w", i, err)
		}
		pipeline.Sources = append(pipeline.Sources, src)
	}
	// Create pipeline rate limiter
	if cfg.RateLimit != nil {
		limiter, err := limit.NewRateLimiter(*cfg.RateLimit, s.logger)
		if err != nil {
			pipelineCancel()
			return fmt.Errorf("failed to create pipeline rate limiter: %w", err)
		}
		pipeline.RateLimiter = limiter
	}
	// Create filter chain
	if len(cfg.Filters) > 0 {
		chain, err := filter.NewChain(cfg.Filters, s.logger)
		if err != nil {
			pipelineCancel()
			return fmt.Errorf("failed to create filter chain: %w", err)
		}
		pipeline.FilterChain = chain
	}
	// Create formatter for the pipeline; a nil formatter lets createSink
	// fall back to per-sink-type defaults
	var formatter format.Formatter
	var err error
	if cfg.Format != "" || len(cfg.FormatOptions) > 0 {
		formatter, err = format.NewFormatter(cfg.Format, cfg.FormatOptions, s.logger)
		if err != nil {
			pipelineCancel()
			return fmt.Errorf("failed to create formatter: %w", err)
		}
	}
	// Create sinks
	for i, sinkCfg := range cfg.Sinks {
		sinkInst, err := s.createSink(sinkCfg, formatter)
		if err != nil {
			pipelineCancel()
			return fmt.Errorf("failed to create sink[%d]: %w", i, err)
		}
		pipeline.Sinks = append(pipeline.Sinks, sinkInst)
	}
	// Configure authentication for sources that support it before starting them
	for _, sourceInst := range pipeline.Sources {
		sourceInst.SetAuth(cfg.Auth)
	}
	// Start all sources
	for i, src := range pipeline.Sources {
		if err := src.Start(); err != nil {
			pipeline.Shutdown()
			return fmt.Errorf("failed to start source[%d]: %w", i, err)
		}
	}
	// Configure authentication for sinks that support it before starting them
	for _, sinkInst := range pipeline.Sinks {
		sinkInst.SetAuth(cfg.Auth)
	}
	// Start all sinks
	for i, sinkInst := range pipeline.Sinks {
		if err := sinkInst.Start(pipelineCtx); err != nil {
			pipeline.Shutdown()
			return fmt.Errorf("failed to start sink[%d]: %w", i, err)
		}
	}
	// Wire sources to sinks through filters
	s.wirePipeline(pipeline)
	// Start stats updater
	pipeline.startStatsUpdater(pipelineCtx)
	s.pipelines[cfg.Name] = pipeline
	s.logger.Info("msg", "Pipeline created successfully",
		"pipeline", cfg.Name,
		"auth_enabled", cfg.Auth != nil && cfg.Auth.Type != "none")
	return nil
}
// Connects sources to sinks through filters // Connects sources to sinks through filters
func (s *Service) wirePipeline(p *Pipeline) { func (s *Service) wirePipeline(p *Pipeline) {
// For each source, subscribe and process entries // For each source, subscribe and process entries
@ -175,17 +51,17 @@ func (s *Service) wirePipeline(p *Pipeline) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
s.logger.Error("msg", "Panic in pipeline processing", s.logger.Error("msg", "Panic in pipeline processing",
"pipeline", p.Name, "pipeline", p.Config.Name,
"source", source.GetStats().Type, "source", source.GetStats().Type,
"panic", r) "panic", r)
// Ensure failed pipelines don't leave resources hanging // Ensure failed pipelines don't leave resources hanging
go func() { go func() {
s.logger.Warn("msg", "Shutting down pipeline due to panic", s.logger.Warn("msg", "Shutting down pipeline due to panic",
"pipeline", p.Name) "pipeline", p.Config.Name)
if err := s.RemovePipeline(p.Name); err != nil { if err := s.RemovePipeline(p.Config.Name); err != nil {
s.logger.Error("msg", "Failed to remove panicked pipeline", s.logger.Error("msg", "Failed to remove panicked pipeline",
"pipeline", p.Name, "pipeline", p.Config.Name,
"error", err) "error", err)
} }
}() }()
@ -228,7 +104,7 @@ func (s *Service) wirePipeline(p *Pipeline) {
default: default:
// Drop if sink buffer is full, may flood logging for slow client // Drop if sink buffer is full, may flood logging for slow client
s.logger.Debug("msg", "Dropped log entry - sink buffer full", s.logger.Debug("msg", "Dropped log entry - sink buffer full",
"pipeline", p.Name) "pipeline", p.Config.Name)
} }
} }
} }
@ -238,16 +114,16 @@ func (s *Service) wirePipeline(p *Pipeline) {
} }
// Creates a source instance based on configuration // Creates a source instance based on configuration
func (s *Service) createSource(cfg config.SourceConfig) (source.Source, error) { func (s *Service) createSource(cfg *config.SourceConfig) (source.Source, error) {
switch cfg.Type { switch cfg.Type {
case "directory": case "directory":
return source.NewDirectorySource(cfg.Options, s.logger) return source.NewDirectorySource(cfg.Directory, s.logger)
case "stdin": case "stdin":
return source.NewStdinSource(cfg.Options, s.logger) return source.NewStdinSource(cfg.Stdin, s.logger)
case "http": case "http":
return source.NewHTTPSource(cfg.Options, s.logger) return source.NewHTTPSource(cfg.HTTP, s.logger)
case "tcp": case "tcp":
return source.NewTCPSource(cfg.Options, s.logger) return source.NewTCPSource(cfg.TCP, s.logger)
default: default:
return nil, fmt.Errorf("unknown source type: %s", cfg.Type) return nil, fmt.Errorf("unknown source type: %s", cfg.Type)
} }
@ -255,34 +131,28 @@ func (s *Service) createSource(cfg config.SourceConfig) (source.Source, error) {
// Creates a sink instance based on configuration // Creates a sink instance based on configuration
func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter) (sink.Sink, error) { func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter) (sink.Sink, error) {
if formatter == nil {
// Default formatters for different sink types
defaultFormat := "raw"
switch cfg.Type {
case "http", "tcp", "http_client", "tcp_client":
defaultFormat = "json"
}
var err error
formatter, err = format.NewFormatter(defaultFormat, nil, s.logger)
if err != nil {
return nil, fmt.Errorf("failed to create default formatter: %w", err)
}
}
switch cfg.Type { switch cfg.Type {
case "http": case "http":
return sink.NewHTTPSink(cfg.Options, s.logger, formatter) if cfg.HTTP == nil {
return nil, fmt.Errorf("HTTP sink configuration missing")
}
return sink.NewHTTPSink(cfg.HTTP, s.logger, formatter)
case "tcp": case "tcp":
return sink.NewTCPSink(cfg.Options, s.logger, formatter) if cfg.TCP == nil {
return nil, fmt.Errorf("TCP sink configuration missing")
}
return sink.NewTCPSink(cfg.TCP, s.logger, formatter)
case "http_client": case "http_client":
return sink.NewHTTPClientSink(cfg.Options, s.logger, formatter) return sink.NewHTTPClientSink(cfg.HTTPClient, s.logger, formatter)
case "tcp_client": case "tcp_client":
return sink.NewTCPClientSink(cfg.Options, s.logger, formatter) return sink.NewTCPClientSink(cfg.TCPClient, s.logger, formatter)
case "file": case "file":
return sink.NewFileSink(cfg.Options, s.logger, formatter) return sink.NewFileSink(cfg.File, s.logger, formatter)
case "console": case "console":
return sink.NewConsoleSink(cfg.Options, s.logger, formatter) return sink.NewConsoleSink(cfg.Console, s.logger, formatter)
default: default:
return nil, fmt.Errorf("unknown sink type: %s", cfg.Type) return nil, fmt.Errorf("unknown sink type: %s", cfg.Type)
} }

View File

@ -18,6 +18,7 @@ import (
// ConsoleSink writes log entries to the console (stdout/stderr) using an dedicated logger instance // ConsoleSink writes log entries to the console (stdout/stderr) using an dedicated logger instance
type ConsoleSink struct { type ConsoleSink struct {
config *config.ConsoleSinkOptions
input chan core.LogEntry input chan core.LogEntry
writer *log.Logger // Dedicated internal logger instance for console writing writer *log.Logger // Dedicated internal logger instance for console writing
done chan struct{} done chan struct{}
@ -31,22 +32,24 @@ type ConsoleSink struct {
} }
// Creates a new console sink // Creates a new console sink
func NewConsoleSink(options map[string]any, appLogger *log.Logger, formatter format.Formatter) (*ConsoleSink, error) { func NewConsoleSink(opts *config.ConsoleSinkOptions, appLogger *log.Logger, formatter format.Formatter) (*ConsoleSink, error) {
target := "stdout" if opts == nil {
if t, ok := options["target"].(string); ok { return nil, fmt.Errorf("console sink options cannot be nil")
target = t
} }
bufferSize := int64(1000) // Set defaults if not configured
if buf, ok := options["buffer_size"].(int64); ok && buf > 0 { if opts.Target == "" {
bufferSize = buf opts.Target = "stdout"
}
if opts.BufferSize <= 0 {
opts.BufferSize = 1000
} }
// Dedicated logger instance as console writer // Dedicated logger instance as console writer
writer, err := log.NewBuilder(). writer, err := log.NewBuilder().
EnableFile(false). EnableFile(false).
EnableConsole(true). EnableConsole(true).
ConsoleTarget(target). ConsoleTarget(opts.Target).
Format("raw"). // Passthrough pre-formatted messages Format("raw"). // Passthrough pre-formatted messages
ShowTimestamp(false). // Disable writer's own timestamp ShowTimestamp(false). // Disable writer's own timestamp
ShowLevel(false). // Disable writer's own level prefix ShowLevel(false). // Disable writer's own level prefix
@ -57,7 +60,8 @@ func NewConsoleSink(options map[string]any, appLogger *log.Logger, formatter for
} }
s := &ConsoleSink{ s := &ConsoleSink{
input: make(chan core.LogEntry, bufferSize), config: opts,
input: make(chan core.LogEntry, opts.BufferSize),
writer: writer, writer: writer,
done: make(chan struct{}), done: make(chan struct{}),
startTime: time.Now(), startTime: time.Now(),
@ -157,7 +161,3 @@ func (s *ConsoleSink) processLoop(ctx context.Context) {
} }
} }
} }
func (s *ConsoleSink) SetAuth(auth *config.AuthConfig) {
// Authentication does not apply to the console sink.
}

View File

@ -5,10 +5,10 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"logwisp/src/internal/config"
"sync/atomic" "sync/atomic"
"time" "time"
"logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"logwisp/src/internal/format" "logwisp/src/internal/format"
@ -17,6 +17,7 @@ import (
// Writes log entries to files with rotation // Writes log entries to files with rotation
type FileSink struct { type FileSink struct {
config *config.FileSinkOptions
input chan core.LogEntry input chan core.LogEntry
writer *log.Logger // Internal logger instance for file writing writer *log.Logger // Internal logger instance for file writing
done chan struct{} done chan struct{}
@ -30,64 +31,27 @@ type FileSink struct {
} }
// Creates a new file sink // Creates a new file sink
func NewFileSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*FileSink, error) { func NewFileSink(opts *config.FileSinkOptions, logger *log.Logger, formatter format.Formatter) (*FileSink, error) {
directory, ok := options["directory"].(string) if opts == nil {
if !ok || directory == "" { return nil, fmt.Errorf("file sink options cannot be nil")
directory = "./"
logger.Warn("No directory or invalid directory provided, current directory will be used")
}
name, ok := options["name"].(string)
if !ok || name == "" {
name = "logwisp.output"
logger.Warn(fmt.Sprintf("No filename provided, %s will be used", name))
} }
// Create configuration for the internal log writer // Create configuration for the internal log writer
writerConfig := log.DefaultConfig() writerConfig := log.DefaultConfig()
writerConfig.Directory = directory writerConfig.Directory = opts.Directory
writerConfig.Name = name writerConfig.Name = opts.Name
writerConfig.EnableConsole = false // File only writerConfig.EnableConsole = false // File only
writerConfig.ShowTimestamp = false // We already have timestamps in entries writerConfig.ShowTimestamp = false // We already have timestamps in entries
writerConfig.ShowLevel = false // We already have levels in entries writerConfig.ShowLevel = false // We already have levels in entries
// Add optional configurations
if maxSize, ok := options["max_size_mb"].(int64); ok && maxSize > 0 {
writerConfig.MaxSizeKB = maxSize * 1000
}
if maxTotalSize, ok := options["max_total_size_mb"].(int64); ok && maxTotalSize >= 0 {
writerConfig.MaxTotalSizeKB = maxTotalSize * 1000
}
if retention, ok := options["retention_hours"].(int64); ok && retention > 0 {
writerConfig.RetentionPeriodHrs = float64(retention)
}
if minDiskFree, ok := options["min_disk_free_mb"].(int64); ok && minDiskFree > 0 {
writerConfig.MinDiskFreeKB = minDiskFree * 1000
}
// Create internal logger for file writing // Create internal logger for file writing
writer := log.NewLogger() writer := log.NewLogger()
if err := writer.ApplyConfig(writerConfig); err != nil { if err := writer.ApplyConfig(writerConfig); err != nil {
return nil, fmt.Errorf("failed to initialize file writer: %w", err) return nil, fmt.Errorf("failed to initialize file writer: %w", err)
} }
// Start the internal file writer
if err := writer.Start(); err != nil {
return nil, fmt.Errorf("failed to start file writer: %w", err)
}
// Buffer size for input channel
// TODO: Centralized constant file in core package
bufferSize := int64(1000)
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
bufferSize = bufSize
}
fs := &FileSink{ fs := &FileSink{
input: make(chan core.LogEntry, bufferSize), input: make(chan core.LogEntry, opts.BufferSize),
writer: writer, writer: writer,
done: make(chan struct{}), done: make(chan struct{}),
startTime: time.Now(), startTime: time.Now(),
@ -104,6 +68,11 @@ func (fs *FileSink) Input() chan<- core.LogEntry {
} }
func (fs *FileSink) Start(ctx context.Context) error { func (fs *FileSink) Start(ctx context.Context) error {
// Start the internal file writer
if err := fs.writer.Start(); err != nil {
return fmt.Errorf("failed to start sink file writer: %w", err)
}
go fs.processLoop(ctx) go fs.processLoop(ctx)
fs.logger.Info("msg", "File sink started", "component", "file_sink") fs.logger.Info("msg", "File sink started", "component", "file_sink")
return nil return nil
@ -167,7 +136,3 @@ func (fs *FileSink) processLoop(ctx context.Context) {
} }
} }
} }
func (fs *FileSink) SetAuth(auth *config.AuthConfig) {
// Authentication does not apply to file sink
}

View File

@ -26,8 +26,11 @@ import (
// Streams log entries via Server-Sent Events // Streams log entries via Server-Sent Events
type HTTPSink struct { type HTTPSink struct {
// Configuration reference (NOT a copy)
config *config.HTTPSinkOptions
// Runtime
input chan core.LogEntry input chan core.LogEntry
config HTTPConfig
server *fasthttp.Server server *fasthttp.Server
activeClients atomic.Int64 activeClients atomic.Int64
mu sync.RWMutex mu sync.RWMutex
@ -46,11 +49,7 @@ type HTTPSink struct {
// Security components // Security components
authenticator *auth.Authenticator authenticator *auth.Authenticator
tlsManager *tls.Manager tlsManager *tls.Manager
authConfig *config.AuthConfig authConfig *config.ServerAuthConfig
// Path configuration
streamPath string
statusPath string
// Net limiting // Net limiting
netLimiter *limit.NetLimiter netLimiter *limit.NetLimiter
@ -62,151 +61,58 @@ type HTTPSink struct {
authSuccesses atomic.Uint64 authSuccesses atomic.Uint64
} }
// Holds HTTP sink configuration
type HTTPConfig struct {
Host string
Port int64
BufferSize int64
StreamPath string
StatusPath string
Heartbeat *config.HeartbeatConfig
TLS *config.TLSConfig
NetLimit *config.NetLimitConfig
}
// Creates a new HTTP streaming sink // Creates a new HTTP streaming sink
func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*HTTPSink, error) { func NewHTTPSink(opts *config.HTTPSinkOptions, logger *log.Logger, formatter format.Formatter) (*HTTPSink, error) {
cfg := HTTPConfig{ if opts == nil {
Host: "0.0.0.0", return nil, fmt.Errorf("HTTP sink options cannot be nil")
Port: 8080,
BufferSize: 1000,
StreamPath: "/stream",
StatusPath: "/status",
}
// Extract configuration from options
if host, ok := options["host"].(string); ok && host != "" {
cfg.Host = host
}
if port, ok := options["port"].(int64); ok {
cfg.Port = port
}
if bufSize, ok := options["buffer_size"].(int64); ok {
cfg.BufferSize = bufSize
}
if path, ok := options["stream_path"].(string); ok {
cfg.StreamPath = path
}
if path, ok := options["status_path"].(string); ok {
cfg.StatusPath = path
}
// Extract heartbeat config
if hb, ok := options["heartbeat"].(map[string]any); ok {
cfg.Heartbeat = &config.HeartbeatConfig{}
cfg.Heartbeat.Enabled, _ = hb["enabled"].(bool)
if interval, ok := hb["interval_seconds"].(int64); ok {
cfg.Heartbeat.IntervalSeconds = interval
}
cfg.Heartbeat.IncludeTimestamp, _ = hb["include_timestamp"].(bool)
cfg.Heartbeat.IncludeStats, _ = hb["include_stats"].(bool)
if hbFormat, ok := hb["format"].(string); ok {
cfg.Heartbeat.Format = hbFormat
}
}
// Extract TLS config
if tc, ok := options["tls"].(map[string]any); ok {
cfg.TLS = &config.TLSConfig{}
cfg.TLS.Enabled, _ = tc["enabled"].(bool)
if certFile, ok := tc["cert_file"].(string); ok {
cfg.TLS.CertFile = certFile
}
if keyFile, ok := tc["key_file"].(string); ok {
cfg.TLS.KeyFile = keyFile
}
cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool)
if caFile, ok := tc["client_ca_file"].(string); ok {
cfg.TLS.ClientCAFile = caFile
}
cfg.TLS.VerifyClientCert, _ = tc["verify_client_cert"].(bool)
if minVer, ok := tc["min_version"].(string); ok {
cfg.TLS.MinVersion = minVer
}
if maxVer, ok := tc["max_version"].(string); ok {
cfg.TLS.MaxVersion = maxVer
}
if ciphers, ok := tc["cipher_suites"].(string); ok {
cfg.TLS.CipherSuites = ciphers
}
}
// Extract net limit config
if nl, ok := options["net_limit"].(map[string]any); ok {
cfg.NetLimit = &config.NetLimitConfig{}
cfg.NetLimit.Enabled, _ = nl["enabled"].(bool)
if rps, ok := nl["requests_per_second"].(float64); ok {
cfg.NetLimit.RequestsPerSecond = rps
}
if burst, ok := nl["burst_size"].(int64); ok {
cfg.NetLimit.BurstSize = burst
}
if respCode, ok := nl["response_code"].(int64); ok {
cfg.NetLimit.ResponseCode = respCode
}
if msg, ok := nl["response_message"].(string); ok {
cfg.NetLimit.ResponseMessage = msg
}
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
}
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
cfg.NetLimit.MaxConnectionsTotal = maxTotal
}
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
for _, entry := range ipWhitelist {
if str, ok := entry.(string); ok {
cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str)
}
}
}
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
for _, entry := range ipBlacklist {
if str, ok := entry.(string); ok {
cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str)
}
}
}
} }
h := &HTTPSink{ h := &HTTPSink{
input: make(chan core.LogEntry, cfg.BufferSize), config: opts, // Direct reference to config struct
config: cfg, input: make(chan core.LogEntry, opts.BufferSize),
startTime: time.Now(), startTime: time.Now(),
done: make(chan struct{}), done: make(chan struct{}),
streamPath: cfg.StreamPath, logger: logger,
statusPath: cfg.StatusPath, formatter: formatter,
logger: logger, clients: make(map[uint64]chan core.LogEntry),
formatter: formatter,
clients: make(map[uint64]chan core.LogEntry),
unregister: make(chan uint64, 10), // Buffered for non-blocking
} }
h.lastProcessed.Store(time.Time{}) h.lastProcessed.Store(time.Time{})
// Initialize TLS manager // Initialize TLS manager if configured
if cfg.TLS != nil && cfg.TLS.Enabled { if opts.TLS != nil && opts.TLS.Enabled {
tlsManager, err := tls.NewManager(cfg.TLS, logger) tlsManager, err := tls.NewManager(opts.TLS, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create TLS manager: %w", err) return nil, fmt.Errorf("failed to create TLS manager: %w", err)
} }
h.tlsManager = tlsManager h.tlsManager = tlsManager
logger.Info("msg", "TLS enabled",
"component", "http_sink")
} }
// Initialize net limiter if configured // Initialize net limiter if configured
if cfg.NetLimit != nil && cfg.NetLimit.Enabled { if opts.NetLimit != nil && (opts.NetLimit.Enabled ||
h.netLimiter = limit.NewNetLimiter(*cfg.NetLimit, logger) len(opts.NetLimit.IPWhitelist) > 0 ||
len(opts.NetLimit.IPBlacklist) > 0) {
h.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger)
}
// Initialize authenticator if auth is not "none"
if opts.Auth != nil && opts.Auth.Type != "none" {
// Only "basic" and "token" are valid for HTTP sink
if opts.Auth.Type != "basic" && opts.Auth.Type != "token" {
return nil, fmt.Errorf("invalid auth type '%s' for HTTP sink (valid: none, basic, token)", opts.Auth.Type)
}
authenticator, err := auth.NewAuthenticator(opts.Auth, logger)
if err != nil {
return nil, fmt.Errorf("failed to create authenticator: %w", err)
}
h.authenticator = authenticator
h.authConfig = opts.Auth
logger.Info("msg", "Authentication enabled",
"component", "http_sink",
"type", opts.Auth.Type)
} }
return h, nil return h, nil
@ -230,6 +136,9 @@ func (h *HTTPSink) Start(ctx context.Context) error {
DisableKeepalive: false, DisableKeepalive: false,
StreamRequestBody: true, StreamRequestBody: true,
Logger: fasthttpLogger, Logger: fasthttpLogger,
// ReadTimeout: time.Duration(h.config.ReadTimeout) * time.Millisecond,
WriteTimeout: time.Duration(h.config.WriteTimeout) * time.Millisecond,
// MaxRequestBodySize: int(h.config.MaxBodySize),
} }
// Configure TLS if enabled // Configure TLS if enabled
@ -250,8 +159,8 @@ func (h *HTTPSink) Start(ctx context.Context) error {
"component", "http_sink", "component", "http_sink",
"host", h.config.Host, "host", h.config.Host,
"port", h.config.Port, "port", h.config.Port,
"stream_path", h.streamPath, "stream_path", h.config.StreamPath,
"status_path", h.statusPath, "status_path", h.config.StatusPath,
"tls_enabled", h.tlsManager != nil) "tls_enabled", h.tlsManager != nil)
var err error var err error
@ -296,7 +205,7 @@ func (h *HTTPSink) brokerLoop(ctx context.Context) {
var tickerChan <-chan time.Time var tickerChan <-chan time.Time
if h.config.Heartbeat != nil && h.config.Heartbeat.Enabled { if h.config.Heartbeat != nil && h.config.Heartbeat.Enabled {
ticker = time.NewTicker(time.Duration(h.config.Heartbeat.IntervalSeconds) * time.Second) ticker = time.NewTicker(time.Duration(h.config.Heartbeat.Interval) * time.Second)
tickerChan = ticker.C tickerChan = ticker.C
defer ticker.Stop() defer ticker.Stop()
} }
@ -441,8 +350,8 @@ func (h *HTTPSink) GetStats() SinkStats {
"port": h.config.Port, "port": h.config.Port,
"buffer_size": h.config.BufferSize, "buffer_size": h.config.BufferSize,
"endpoints": map[string]string{ "endpoints": map[string]string{
"stream": h.streamPath, "stream": h.config.StreamPath,
"status": h.statusPath, "status": h.config.StatusPath,
}, },
"net_limit": netLimitStats, "net_limit": netLimitStats,
"auth": authStats, "auth": authStats,
@ -489,7 +398,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path()) path := string(ctx.Path())
// Status endpoint doesn't require auth // Status endpoint doesn't require auth
if path == h.statusPath { if path == h.config.StatusPath {
h.handleStatus(ctx) h.handleStatus(ctx)
return return
} }
@ -509,14 +418,14 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
// Return 401 with WWW-Authenticate header // Return 401 with WWW-Authenticate header
ctx.SetStatusCode(fasthttp.StatusUnauthorized) ctx.SetStatusCode(fasthttp.StatusUnauthorized)
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil { if h.authConfig.Type == "basic" && h.authConfig.Basic != nil {
realm := h.authConfig.BasicAuth.Realm realm := h.authConfig.Basic.Realm
if realm == "" { if realm == "" {
realm = "Restricted" realm = "Restricted"
} }
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm)) ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm))
} else if h.authConfig.Type == "bearer" { } else if h.authConfig.Type == "token" {
ctx.Response.Header.Set("WWW-Authenticate", "Bearer") ctx.Response.Header.Set("WWW-Authenticate", "Token")
} }
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
@ -538,7 +447,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
} }
switch path { switch path {
case h.streamPath: case h.config.StreamPath:
h.handleStream(ctx, session) h.handleStream(ctx, session)
default: default:
ctx.SetStatusCode(fasthttp.StatusNotFound) ctx.SetStatusCode(fasthttp.StatusNotFound)
@ -547,6 +456,15 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
"error": "Not Found", "error": "Not Found",
}) })
} }
// Handle stream endpoint
// if path == h.config.StreamPath {
// h.handleStream(ctx, session)
// return
// }
//
// // Unknown path
// ctx.SetStatusCode(fasthttp.StatusNotFound)
// ctx.SetBody([]byte("Not Found"))
} }
func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) { func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) {
@ -611,8 +529,8 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session)
"client_id": fmt.Sprintf("%d", clientID), "client_id": fmt.Sprintf("%d", clientID),
"username": session.Username, "username": session.Username,
"auth_method": session.Method, "auth_method": session.Method,
"stream_path": h.streamPath, "stream_path": h.config.StreamPath,
"status_path": h.statusPath, "status_path": h.config.StatusPath,
"buffer_size": h.config.BufferSize, "buffer_size": h.config.BufferSize,
"tls": h.tlsManager != nil, "tls": h.tlsManager != nil,
} }
@ -627,7 +545,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session)
var tickerChan <-chan time.Time var tickerChan <-chan time.Time
if h.config.Heartbeat != nil && h.config.Heartbeat.Enabled { if h.config.Heartbeat != nil && h.config.Heartbeat.Enabled {
ticker = time.NewTicker(time.Duration(h.config.Heartbeat.IntervalSeconds) * time.Second) ticker = time.NewTicker(time.Duration(h.config.Heartbeat.Interval) * time.Second)
tickerChan = ticker.C tickerChan = ticker.C
defer ticker.Stop() defer ticker.Stop()
} }
@ -716,7 +634,7 @@ func (h *HTTPSink) createHeartbeatEntry() core.LogEntry {
fields := make(map[string]any) fields := make(map[string]any)
fields["type"] = "heartbeat" fields["type"] = "heartbeat"
if h.config.Heartbeat.IncludeStats { if h.config.Heartbeat.Enabled {
fields["active_clients"] = h.activeClients.Load() fields["active_clients"] = h.activeClients.Load()
fields["uptime_seconds"] = int(time.Since(h.startTime).Seconds()) fields["uptime_seconds"] = int(time.Since(h.startTime).Seconds())
} }
@ -775,13 +693,13 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
"uptime_seconds": int(time.Since(h.startTime).Seconds()), "uptime_seconds": int(time.Since(h.startTime).Seconds()),
}, },
"endpoints": map[string]string{ "endpoints": map[string]string{
"transport": h.streamPath, "transport": h.config.StreamPath,
"status": h.statusPath, "status": h.config.StatusPath,
}, },
"features": map[string]any{ "features": map[string]any{
"heartbeat": map[string]any{ "heartbeat": map[string]any{
"enabled": h.config.Heartbeat.Enabled, "enabled": h.config.Heartbeat.Enabled,
"interval": h.config.Heartbeat.IntervalSeconds, "interval": h.config.Heartbeat.Interval,
"format": h.config.Heartbeat.Format, "format": h.config.Heartbeat.Format,
}, },
"tls": tlsStats, "tls": tlsStats,
@ -806,37 +724,15 @@ func (h *HTTPSink) GetActiveConnections() int64 {
// Returns the configured transport endpoint path // Returns the configured transport endpoint path
func (h *HTTPSink) GetStreamPath() string { func (h *HTTPSink) GetStreamPath() string {
return h.streamPath return h.config.StreamPath
} }
// Returns the configured status endpoint path // Returns the configured status endpoint path
func (h *HTTPSink) GetStatusPath() string { func (h *HTTPSink) GetStatusPath() string {
return h.statusPath return h.config.StatusPath
} }
// Returns the configured host // Returns the configured host
func (h *HTTPSink) GetHost() string { func (h *HTTPSink) GetHost() string {
return h.config.Host return h.config.Host
} }
// Configures http sink auth
func (h *HTTPSink) SetAuth(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" {
return
}
h.authConfig = authCfg
authenticator, err := auth.NewAuthenticator(authCfg, h.logger)
if err != nil {
h.logger.Error("msg", "Failed to initialize authenticator for HTTP sink",
"component", "http_sink",
"error", err)
// Continue without auth
return
}
h.authenticator = authenticator
h.logger.Info("msg", "Authentication configured for HTTP sink",
"component", "http_sink",
"auth_type", authCfg.Type)
}

View File

@ -8,7 +8,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/url"
"os" "os"
"strings" "strings"
"sync" "sync"
@ -28,7 +27,7 @@ import (
// Forwards log entries to a remote HTTP endpoint // Forwards log entries to a remote HTTP endpoint
type HTTPClientSink struct { type HTTPClientSink struct {
input chan core.LogEntry input chan core.LogEntry
config HTTPClientConfig config *config.HTTPClientSinkOptions
client *fasthttp.Client client *fasthttp.Client
batch []core.LogEntry batch []core.LogEntry
batchMu sync.Mutex batchMu sync.Mutex
@ -48,195 +47,16 @@ type HTTPClientSink struct {
activeConnections atomic.Int64 activeConnections atomic.Int64
} }
// Holds HTTP client sink configuration
// TODO: missing toml tags
type HTTPClientConfig struct {
// Config
URL string `toml:"url"`
BufferSize int64 `toml:"buffer_size"`
BatchSize int64 `toml:"batch_size"`
BatchDelay time.Duration `toml:"batch_delay_ms"`
Timeout time.Duration `toml:"timeout_seconds"`
Headers map[string]string `toml:"headers"`
// Retry configuration
MaxRetries int64 `toml:"max_retries"`
RetryDelay time.Duration `toml:"retry_delay"`
RetryBackoff float64 `toml:"retry_backoff"` // Multiplier for exponential backoff
// Security
AuthType string `toml:"auth_type"` // "none", "basic", "bearer", "mtls"
Username string `toml:"username"` // For basic auth
Password string `toml:"password"` // For basic auth
BearerToken string `toml:"bearer_token"` // For bearer auth
// TLS configuration
InsecureSkipVerify bool `toml:"insecure_skip_verify"`
CAFile string `toml:"ca_file"`
CertFile string `toml:"cert_file"`
KeyFile string `toml:"key_file"`
}
// Creates a new HTTP client sink // Creates a new HTTP client sink
func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*HTTPClientSink, error) { func NewHTTPClientSink(opts *config.HTTPClientSinkOptions, logger *log.Logger, formatter format.Formatter) (*HTTPClientSink, error) {
cfg := HTTPClientConfig{ if opts == nil {
BufferSize: int64(1000), return nil, fmt.Errorf("HTTP client sink options cannot be nil")
BatchSize: int64(100),
BatchDelay: time.Second,
Timeout: 30 * time.Second,
MaxRetries: int64(3),
RetryDelay: time.Second,
RetryBackoff: float64(2.0),
Headers: make(map[string]string),
}
// Extract URL
urlStr, ok := options["url"].(string)
if !ok || urlStr == "" {
return nil, fmt.Errorf("http_client sink requires 'url' option")
}
// Validate URL
parsedURL, err := url.Parse(urlStr)
if err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return nil, fmt.Errorf("URL must use http or https scheme")
}
cfg.URL = urlStr
// Extract other options
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
cfg.BufferSize = bufSize
}
if batchSize, ok := options["batch_size"].(int64); ok && batchSize > 0 {
cfg.BatchSize = batchSize
}
if delayMs, ok := options["batch_delay_ms"].(int64); ok && delayMs > 0 {
cfg.BatchDelay = time.Duration(delayMs) * time.Millisecond
}
if timeoutSec, ok := options["timeout_seconds"].(int64); ok && timeoutSec > 0 {
cfg.Timeout = time.Duration(timeoutSec) * time.Second
}
if maxRetries, ok := options["max_retries"].(int64); ok && maxRetries >= 0 {
cfg.MaxRetries = maxRetries
}
if retryDelayMs, ok := options["retry_delay_ms"].(int64); ok && retryDelayMs > 0 {
cfg.RetryDelay = time.Duration(retryDelayMs) * time.Millisecond
}
if backoff, ok := options["retry_backoff"].(float64); ok && backoff >= 1.0 {
cfg.RetryBackoff = backoff
}
if insecure, ok := options["insecure_skip_verify"].(bool); ok {
cfg.InsecureSkipVerify = insecure
}
if authType, ok := options["auth_type"].(string); ok {
switch authType {
case "none", "basic", "bearer", "mtls":
cfg.AuthType = authType
default:
return nil, fmt.Errorf("http_client sink: invalid auth_type '%s'", authType)
}
} else {
cfg.AuthType = "none"
}
if username, ok := options["username"].(string); ok {
cfg.Username = username
}
if password, ok := options["password"].(string); ok {
cfg.Password = password // TODO: change to Argon2 hashed password
}
if token, ok := options["bearer_token"].(string); ok {
cfg.BearerToken = token
}
// Validate auth configuration and TLS enforcement
isHTTPS := strings.HasPrefix(cfg.URL, "https://")
switch cfg.AuthType {
case "basic":
if cfg.Username == "" || cfg.Password == "" {
return nil, fmt.Errorf("http_client sink: username and password required for basic auth")
}
if !isHTTPS {
return nil, fmt.Errorf("http_client sink: basic auth requires HTTPS (security: credentials would be sent in plaintext)")
}
case "bearer":
if cfg.BearerToken == "" {
return nil, fmt.Errorf("http_client sink: bearer_token required for bearer auth")
}
if !isHTTPS {
return nil, fmt.Errorf("http_client sink: bearer auth requires HTTPS (security: token would be sent in plaintext)")
}
case "mtls":
if !isHTTPS {
return nil, fmt.Errorf("http_client sink: mTLS requires HTTPS")
}
if cfg.CertFile == "" || cfg.KeyFile == "" {
return nil, fmt.Errorf("http_client sink: cert_file and key_file required for mTLS")
}
case "none":
// Clear any credentials if auth is "none"
if cfg.Username != "" || cfg.Password != "" || cfg.BearerToken != "" {
logger.Warn("msg", "Credentials provided but auth_type is 'none', ignoring",
"component", "http_client_sink")
cfg.Username = ""
cfg.Password = ""
cfg.BearerToken = ""
}
}
// Extract headers
if headers, ok := options["headers"].(map[string]any); ok {
for k, v := range headers {
if strVal, ok := v.(string); ok {
cfg.Headers[k] = strVal
}
}
}
// Set default Content-Type if not specified
if _, exists := cfg.Headers["Content-Type"]; !exists {
cfg.Headers["Content-Type"] = "application/json"
}
// Extract TLS options
if caFile, ok := options["ca_file"].(string); ok && caFile != "" {
cfg.CAFile = caFile
}
// Extract client certificate options from TLS config
if tc, ok := options["tls"].(map[string]any); ok {
if enabled, _ := tc["enabled"].(bool); enabled {
// Extract client certificate files for mTLS
if certFile, ok := tc["cert_file"].(string); ok && certFile != "" {
if keyFile, ok := tc["key_file"].(string); ok && keyFile != "" {
// These will be used below when configuring TLS
cfg.CertFile = certFile // Need to add these fields to HTTPClientConfig
cfg.KeyFile = keyFile
}
}
// Extract CA file from TLS config if not already set
if cfg.CAFile == "" {
if caFile, ok := tc["ca_file"].(string); ok {
cfg.CAFile = caFile
}
}
// Extract insecure skip verify from TLS config
if insecure, ok := tc["insecure_skip_verify"].(bool); ok {
cfg.InsecureSkipVerify = insecure
}
}
} }
h := &HTTPClientSink{ h := &HTTPClientSink{
input: make(chan core.LogEntry, cfg.BufferSize), config: opts,
config: cfg, input: make(chan core.LogEntry, opts.BufferSize),
batch: make([]core.LogEntry, 0, cfg.BatchSize), batch: make([]core.LogEntry, 0, opts.BatchSize),
done: make(chan struct{}), done: make(chan struct{}),
startTime: time.Now(), startTime: time.Now(),
logger: logger, logger: logger,
@ -249,46 +69,48 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
h.client = &fasthttp.Client{ h.client = &fasthttp.Client{
MaxConnsPerHost: 10, MaxConnsPerHost: 10,
MaxIdleConnDuration: 10 * time.Second, MaxIdleConnDuration: 10 * time.Second,
ReadTimeout: cfg.Timeout, ReadTimeout: time.Duration(opts.Timeout) * time.Second,
WriteTimeout: cfg.Timeout, WriteTimeout: time.Duration(opts.Timeout) * time.Second,
DisableHeaderNamesNormalizing: true, DisableHeaderNamesNormalizing: true,
} }
// Configure TLS if using HTTPS // Configure TLS if using HTTPS
if strings.HasPrefix(cfg.URL, "https://") { if strings.HasPrefix(opts.URL, "https://") {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: cfg.InsecureSkipVerify, InsecureSkipVerify: opts.InsecureSkipVerify,
} }
// Load custom CA for server verification if provided // Use TLS config if provided
if cfg.CAFile != "" { if opts.TLS != nil {
caCert, err := os.ReadFile(cfg.CAFile) // Load custom CA for server verification
if err != nil { if opts.TLS.CAFile != "" {
return nil, fmt.Errorf("failed to read CA file '%s': %w", cfg.CAFile, err) caCert, err := os.ReadFile(opts.TLS.CAFile)
if err != nil {
return nil, fmt.Errorf("failed to read CA file '%s': %w", opts.TLS.CAFile, err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("failed to parse CA certificate from '%s'", opts.TLS.CAFile)
}
tlsConfig.RootCAs = caCertPool
logger.Debug("msg", "Custom CA loaded for server verification",
"component", "http_client_sink",
"ca_file", opts.TLS.CAFile)
} }
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) { // Load client certificate for mTLS if provided
return nil, fmt.Errorf("failed to parse CA certificate from '%s'", cfg.CAFile) if opts.TLS.CertFile != "" && opts.TLS.KeyFile != "" {
cert, err := tls.LoadX509KeyPair(opts.TLS.CertFile, opts.TLS.KeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
logger.Info("msg", "Client certificate loaded for mTLS",
"component", "http_client_sink",
"cert_file", opts.TLS.CertFile)
} }
tlsConfig.RootCAs = caCertPool
logger.Debug("msg", "Custom CA loaded for server verification",
"component", "http_client_sink",
"ca_file", cfg.CAFile)
} }
// Load client certificate for mTLS if provided
if cfg.CertFile != "" && cfg.KeyFile != "" {
cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
logger.Info("msg", "Client certificate loaded for mTLS",
"component", "http_client_sink",
"cert_file", cfg.CertFile)
}
// Set TLS config directly on the client
h.client.TLSConfig = tlsConfig h.client.TLSConfig = tlsConfig
} }
@ -308,7 +130,7 @@ func (h *HTTPClientSink) Start(ctx context.Context) error {
"component", "http_client_sink", "component", "http_client_sink",
"url", h.config.URL, "url", h.config.URL,
"batch_size", h.config.BatchSize, "batch_size", h.config.BatchSize,
"batch_delay", h.config.BatchDelay) "batch_delay_ms", h.config.BatchDelayMS)
return nil return nil
} }
@ -399,7 +221,7 @@ func (h *HTTPClientSink) processLoop(ctx context.Context) {
func (h *HTTPClientSink) batchTimer(ctx context.Context) { func (h *HTTPClientSink) batchTimer(ctx context.Context) {
defer h.wg.Done() defer h.wg.Done()
ticker := time.NewTicker(h.config.BatchDelay) ticker := time.NewTicker(time.Duration(h.config.BatchDelayMS) * time.Millisecond)
defer ticker.Stop() defer ticker.Stop()
for { for {
@ -468,7 +290,7 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
// Retry logic // Retry logic
var lastErr error var lastErr error
retryDelay := h.config.RetryDelay retryDelay := time.Duration(h.config.RetryDelayMS) * time.Millisecond
// TODO: verify retry loop placement is correct or should it be after acquiring resources (req :=....) // TODO: verify retry loop placement is correct or should it be after acquiring resources (req :=....)
for attempt := int64(0); attempt <= h.config.MaxRetries; attempt++ { for attempt := int64(0); attempt <= h.config.MaxRetries; attempt++ {
@ -480,9 +302,10 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
newDelay := time.Duration(float64(retryDelay) * h.config.RetryBackoff) newDelay := time.Duration(float64(retryDelay) * h.config.RetryBackoff)
// Cap at maximum to prevent integer overflow // Cap at maximum to prevent integer overflow
if newDelay > h.config.Timeout || newDelay < retryDelay { timeout := time.Duration(h.config.Timeout) * time.Second
if newDelay > timeout || newDelay < retryDelay {
// Either exceeded max or overflowed (negative/wrapped) // Either exceeded max or overflowed (negative/wrapped)
retryDelay = h.config.Timeout retryDelay = timeout
} else { } else {
retryDelay = newDelay retryDelay = newDelay
} }
@ -500,14 +323,14 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
req.Header.Set("User-Agent", fmt.Sprintf("LogWisp/%s", version.Short())) req.Header.Set("User-Agent", fmt.Sprintf("LogWisp/%s", version.Short()))
// Add authentication based on auth type // Add authentication based on auth type
switch h.config.AuthType { switch h.config.Auth.Type {
case "basic": case "basic":
creds := h.config.Username + ":" + h.config.Password creds := h.config.Auth.Username + ":" + h.config.Auth.Password
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds)) encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
req.Header.Set("Authorization", "Basic "+encodedCreds) req.Header.Set("Authorization", "Basic "+encodedCreds)
case "bearer": case "token":
req.Header.Set("Authorization", "Bearer "+h.config.BearerToken) req.Header.Set("Authorization", "Token "+h.config.Auth.Token)
case "mtls": case "mtls":
// mTLS auth is handled at TLS layer via client certificates // mTLS auth is handled at TLS layer via client certificates
@ -523,7 +346,7 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
} }
// Send request // Send request
err := h.client.DoTimeout(req, resp, h.config.Timeout) err := h.client.DoTimeout(req, resp, time.Duration(h.config.Timeout)*time.Second)
// Capture response before releasing // Capture response before releasing
statusCode := resp.StatusCode() statusCode := resp.StatusCode()
@ -588,9 +411,3 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
"last_error", lastErr) "last_error", lastErr)
h.failedBatches.Add(1) h.failedBatches.Add(1)
} }
// Not applicable, Clients authenticate to remote servers using Username/Password in config
func (h *HTTPClientSink) SetAuth(authCfg *config.AuthConfig) {
// No-op: client sinks don't validate incoming connections
// They authenticate to remote servers using Username/Password fields
}

View File

@ -5,7 +5,6 @@ import (
"context" "context"
"time" "time"
"logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
) )
@ -22,9 +21,6 @@ type Sink interface {
// Returns sink statistics // Returns sink statistics
GetStats() SinkStats GetStats() SinkStats
// Configure authentication
SetAuth(auth *config.AuthConfig)
} }
// Contains statistics about a sink // Contains statistics about a sink

View File

@ -7,7 +7,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -25,26 +24,22 @@ import (
// Streams log entries via TCP // Streams log entries via TCP
type TCPSink struct { type TCPSink struct {
// C input chan core.LogEntry
input chan core.LogEntry config *config.TCPSinkOptions
config TCPConfig server *tcpServer
server *tcpServer done chan struct{}
done chan struct{} activeConns atomic.Int64
activeConns atomic.Int64 startTime time.Time
startTime time.Time engine *gnet.Engine
engine *gnet.Engine engineMu sync.Mutex
engineMu sync.Mutex wg sync.WaitGroup
wg sync.WaitGroup netLimiter *limit.NetLimiter
netLimiter *limit.NetLimiter logger *log.Logger
logger *log.Logger formatter format.Formatter
formatter format.Formatter
authenticator *auth.Authenticator
// Statistics // Statistics
totalProcessed atomic.Uint64 totalProcessed atomic.Uint64
lastProcessed atomic.Value // time.Time lastProcessed atomic.Value // time.Time
authFailures atomic.Uint64
authSuccesses atomic.Uint64
// Write error tracking // Write error tracking
writeErrors atomic.Uint64 writeErrors atomic.Uint64
@ -62,87 +57,14 @@ type TCPConfig struct {
} }
// Creates a new TCP streaming sink // Creates a new TCP streaming sink
func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*TCPSink, error) { func NewTCPSink(opts *config.TCPSinkOptions, logger *log.Logger, formatter format.Formatter) (*TCPSink, error) {
cfg := TCPConfig{ if opts == nil {
Host: "0.0.0.0", return nil, fmt.Errorf("TCP sink options cannot be nil")
Port: int64(9090),
BufferSize: int64(1000),
}
// Extract configuration from options
if host, ok := options["host"].(string); ok && host != "" {
cfg.Host = host
}
if port, ok := options["port"].(int64); ok {
cfg.Port = port
}
if bufSize, ok := options["buffer_size"].(int64); ok {
cfg.BufferSize = bufSize
}
// Extract heartbeat config
if hb, ok := options["heartbeat"].(map[string]any); ok {
cfg.Heartbeat = &config.HeartbeatConfig{}
cfg.Heartbeat.Enabled, _ = hb["enabled"].(bool)
if interval, ok := hb["interval_seconds"].(int64); ok {
cfg.Heartbeat.IntervalSeconds = interval
}
cfg.Heartbeat.IncludeTimestamp, _ = hb["include_timestamp"].(bool)
cfg.Heartbeat.IncludeStats, _ = hb["include_stats"].(bool)
if hbFormat, ok := hb["format"].(string); ok {
cfg.Heartbeat.Format = hbFormat
}
}
// Extract net limit config
if nl, ok := options["net_limit"].(map[string]any); ok {
cfg.NetLimit = &config.NetLimitConfig{}
cfg.NetLimit.Enabled, _ = nl["enabled"].(bool)
if rps, ok := nl["requests_per_second"].(float64); ok {
cfg.NetLimit.RequestsPerSecond = rps
}
if burst, ok := nl["burst_size"].(int64); ok {
cfg.NetLimit.BurstSize = burst
}
if respCode, ok := nl["response_code"].(int64); ok {
cfg.NetLimit.ResponseCode = respCode
}
if msg, ok := nl["response_message"].(string); ok {
cfg.NetLimit.ResponseMessage = msg
}
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
}
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
cfg.NetLimit.MaxConnectionsPerUser = maxPerUser
}
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
cfg.NetLimit.MaxConnectionsPerToken = maxPerToken
}
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
cfg.NetLimit.MaxConnectionsTotal = maxTotal
}
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
for _, entry := range ipWhitelist {
if str, ok := entry.(string); ok {
cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str)
}
}
}
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
for _, entry := range ipBlacklist {
if str, ok := entry.(string); ok {
cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str)
}
}
}
} }
t := &TCPSink{ t := &TCPSink{
input: make(chan core.LogEntry, cfg.BufferSize), config: opts, // Direct reference to config
config: cfg, input: make(chan core.LogEntry, opts.BufferSize),
done: make(chan struct{}), done: make(chan struct{}),
startTime: time.Now(), startTime: time.Now(),
logger: logger, logger: logger,
@ -150,9 +72,11 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
} }
t.lastProcessed.Store(time.Time{}) t.lastProcessed.Store(time.Time{})
// Initialize net limiter // Initialize net limiter with pointer
if cfg.NetLimit != nil && cfg.NetLimit.Enabled { if opts.NetLimit != nil && (opts.NetLimit.Enabled ||
t.netLimiter = limit.NewNetLimiter(*cfg.NetLimit, logger) len(opts.NetLimit.IPWhitelist) > 0 ||
len(opts.NetLimit.IPBlacklist) > 0) {
t.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger)
} }
return t, nil return t, nil
@ -193,8 +117,7 @@ func (t *TCPSink) Start(ctx context.Context) error {
go func() { go func() {
t.logger.Info("msg", "Starting TCP server", t.logger.Info("msg", "Starting TCP server",
"component", "tcp_sink", "component", "tcp_sink",
"port", t.config.Port, "port", t.config.Port)
"auth", t.authenticator != nil)
err := gnet.Run(t.server, addr, opts...) err := gnet.Run(t.server, addr, opts...)
if err != nil { if err != nil {
@ -282,7 +205,7 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
var tickerChan <-chan time.Time var tickerChan <-chan time.Time
if t.config.Heartbeat != nil && t.config.Heartbeat.Enabled { if t.config.Heartbeat != nil && t.config.Heartbeat.Enabled {
ticker = time.NewTicker(time.Duration(t.config.Heartbeat.IntervalSeconds) * time.Second) ticker = time.NewTicker(time.Duration(t.config.Heartbeat.Interval) * time.Second)
tickerChan = ticker.C tickerChan = ticker.C
defer ticker.Stop() defer ticker.Stop()
} }
@ -329,21 +252,19 @@ func (t *TCPSink) broadcastData(data []byte) {
t.server.mu.RLock() t.server.mu.RLock()
defer t.server.mu.RUnlock() defer t.server.mu.RUnlock()
for conn, client := range t.server.clients { for conn, _ := range t.server.clients {
if client.authenticated { conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
conn.AsyncWrite(data, func(c gnet.Conn, err error) error { if err != nil {
if err != nil { t.writeErrors.Add(1)
t.writeErrors.Add(1) t.handleWriteError(c, err)
t.handleWriteError(c, err) } else {
} else { // Reset consecutive error count on success
// Reset consecutive error count on success t.errorMu.Lock()
t.errorMu.Lock() delete(t.consecutiveWriteErrors, c)
delete(t.consecutiveWriteErrors, c) t.errorMu.Unlock()
t.errorMu.Unlock() }
} return nil
return nil })
})
}
} }
} }
@ -408,11 +329,10 @@ func (t *TCPSink) GetActiveConnections() int64 {
// Represents a connected TCP client with auth state // Represents a connected TCP client with auth state
type tcpClient struct { type tcpClient struct {
conn gnet.Conn conn gnet.Conn
buffer bytes.Buffer buffer bytes.Buffer
authenticated bool authTimeout time.Time
authTimeout time.Time session *auth.Session
session *auth.Session
} }
// Handles gnet events with authentication // Handles gnet events with authentication
@ -439,7 +359,7 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
remoteAddr := c.RemoteAddr() remoteAddr := c.RemoteAddr()
s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr) s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr)
// Reject IPv6 connections immediately // Reject IPv6 connections
if tcpAddr, ok := remoteAddr.(*net.TCPAddr); ok { if tcpAddr, ok := remoteAddr.(*net.TCPAddr); ok {
if tcpAddr.IP.To4() == nil { if tcpAddr.IP.To4() == nil {
return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close
@ -467,14 +387,10 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
s.sink.netLimiter.AddConnection(remoteStr) s.sink.netLimiter.AddConnection(remoteStr)
} }
// Create client state without auth timeout initially // TCP Sink accepts all connections without authentication
client := &tcpClient{ client := &tcpClient{
conn: c, conn: c,
authenticated: s.sink.authenticator == nil, buffer: bytes.Buffer{},
}
if s.sink.authenticator != nil {
client.authTimeout = time.Now().Add(30 * time.Second)
} }
s.mu.Lock() s.mu.Lock()
@ -484,13 +400,7 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
newCount := s.sink.activeConns.Add(1) newCount := s.sink.activeConns.Add(1)
s.sink.logger.Debug("msg", "TCP connection opened", s.sink.logger.Debug("msg", "TCP connection opened",
"remote_addr", remoteAddr, "remote_addr", remoteAddr,
"active_connections", newCount, "active_connections", newCount)
"auth_enabled", s.sink.authenticator != nil)
// Send auth prompt if authentication is required
if s.sink.authenticator != nil {
return []byte("AUTH_REQUIRED\n"), gnet.None
}
return nil, gnet.None return nil, gnet.None
} }
@ -522,96 +432,7 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
} }
func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action { func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
s.mu.RLock() // TCP Sink doesn't expect any data from clients, discard all
client, exists := s.clients[c]
s.mu.RUnlock()
if !exists {
return gnet.Close
}
// Authentication phase
if !client.authenticated {
// Check auth timeout
if time.Now().After(client.authTimeout) {
s.sink.logger.Warn("msg", "Authentication timeout",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String())
return gnet.Close
}
// Read auth data
data, _ := c.Next(-1)
if len(data) == 0 {
return gnet.None
}
client.buffer.Write(data)
// Look for complete auth line
if idx := bytes.IndexByte(client.buffer.Bytes(), '\n'); idx >= 0 {
line := client.buffer.Bytes()[:idx]
client.buffer.Next(idx + 1)
// Parse AUTH command: AUTH <method> <credentials>
parts := strings.SplitN(string(line), " ", 3)
if len(parts) != 3 || parts[0] != "AUTH" {
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
return gnet.Close
}
// Authenticate
session, err := s.sink.authenticator.AuthenticateTCP(parts[1], parts[2], c.RemoteAddr().String())
if err != nil {
s.sink.authFailures.Add(1)
s.sink.logger.Warn("msg", "TCP authentication failed",
"remote_addr", c.RemoteAddr().String(),
"method", parts[1],
"error", err)
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
return gnet.Close
}
// Authentication successful
s.sink.authSuccesses.Add(1)
s.mu.Lock()
client.authenticated = true
client.session = session
s.mu.Unlock()
s.sink.logger.Info("msg", "TCP client authenticated",
"component", "tcp_sink",
"remote_addr", c.RemoteAddr().String(),
"username", session.Username,
"method", session.Method)
c.AsyncWrite([]byte("AUTH_OK\n"), nil)
client.buffer.Reset()
}
return gnet.None
}
// Clients shouldn't send data, just discard
c.Discard(-1) c.Discard(-1)
return gnet.None return gnet.None
} }
// Configures tcp sink auth
func (t *TCPSink) SetAuth(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" {
return
}
authenticator, err := auth.NewAuthenticator(authCfg, t.logger)
if err != nil {
t.logger.Error("msg", "Failed to initialize authenticator for TCP sink",
"component", "tcp_sink",
"error", err)
return
}
t.authenticator = authenticator
t.logger.Info("msg", "Authentication configured for TCP sink",
"component", "tcp_sink",
"auth_type", authCfg.Type)
}

View File

@ -7,7 +7,9 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"logwisp/src/internal/auth"
"net" "net"
"strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -16,7 +18,6 @@ import (
"logwisp/src/internal/config" "logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"logwisp/src/internal/format" "logwisp/src/internal/format"
"logwisp/src/internal/scram"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
) )
@ -24,7 +25,8 @@ import (
// Forwards log entries to a remote TCP endpoint // Forwards log entries to a remote TCP endpoint
type TCPClientSink struct { type TCPClientSink struct {
input chan core.LogEntry input chan core.LogEntry
config TCPClientConfig config *config.TCPClientSinkOptions
address string
conn net.Conn conn net.Conn
connMu sync.RWMutex connMu sync.RWMutex
done chan struct{} done chan struct{}
@ -46,101 +48,17 @@ type TCPClientSink struct {
connectionUptime atomic.Value // time.Duration connectionUptime atomic.Value // time.Duration
} }
// Holds TCP client sink configuration
type TCPClientConfig struct {
Address string `toml:"address"`
BufferSize int64 `toml:"buffer_size"`
DialTimeout time.Duration `toml:"dial_timeout_seconds"`
WriteTimeout time.Duration `toml:"write_timeout_seconds"`
ReadTimeout time.Duration `toml:"read_timeout_seconds"`
KeepAlive time.Duration `toml:"keep_alive_seconds"`
// Security
AuthType string `toml:"auth_type"`
Username string `toml:"username"`
Password string `toml:"password"`
// Reconnection settings
ReconnectDelay time.Duration `toml:"reconnect_delay_ms"`
MaxReconnectDelay time.Duration `toml:"max_reconnect_delay_seconds"`
ReconnectBackoff float64 `toml:"reconnect_backoff"`
}
// Creates a new TCP client sink // Creates a new TCP client sink
func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*TCPClientSink, error) { func NewTCPClientSink(opts *config.TCPClientSinkOptions, logger *log.Logger, formatter format.Formatter) (*TCPClientSink, error) {
cfg := TCPClientConfig{ // Validation and defaults are handled in config package
BufferSize: int64(1000), if opts == nil {
DialTimeout: 10 * time.Second, return nil, fmt.Errorf("TCP client sink options cannot be nil")
WriteTimeout: 30 * time.Second,
ReadTimeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
ReconnectDelay: time.Second,
MaxReconnectDelay: 30 * time.Second,
ReconnectBackoff: float64(1.5),
}
// Extract address
address, ok := options["address"].(string)
if !ok || address == "" {
return nil, fmt.Errorf("tcp_client sink requires 'address' option")
}
// Validate address format
_, _, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("invalid address format (expected host:port): %w", err)
}
cfg.Address = address
// Extract other options
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
cfg.BufferSize = bufSize
}
if dialTimeout, ok := options["dial_timeout_seconds"].(int64); ok && dialTimeout > 0 {
cfg.DialTimeout = time.Duration(dialTimeout) * time.Second
}
if writeTimeout, ok := options["write_timeout_seconds"].(int64); ok && writeTimeout > 0 {
cfg.WriteTimeout = time.Duration(writeTimeout) * time.Second
}
if readTimeout, ok := options["read_timeout_seconds"].(int64); ok && readTimeout > 0 {
cfg.ReadTimeout = time.Duration(readTimeout) * time.Second
}
if keepAlive, ok := options["keep_alive_seconds"].(int64); ok && keepAlive > 0 {
cfg.KeepAlive = time.Duration(keepAlive) * time.Second
}
if reconnectDelay, ok := options["reconnect_delay_ms"].(int64); ok && reconnectDelay > 0 {
cfg.ReconnectDelay = time.Duration(reconnectDelay) * time.Millisecond
}
if maxReconnectDelay, ok := options["max_reconnect_delay_seconds"].(int64); ok && maxReconnectDelay > 0 {
cfg.MaxReconnectDelay = time.Duration(maxReconnectDelay) * time.Second
}
if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 {
cfg.ReconnectBackoff = backoff
}
if authType, ok := options["auth_type"].(string); ok {
switch authType {
case "none":
cfg.AuthType = authType
case "scram":
cfg.AuthType = authType
if username, ok := options["username"].(string); ok && username != "" {
cfg.Username = username
} else {
return nil, fmt.Errorf("invalid scram username")
}
if password, ok := options["password"].(string); ok && password != "" {
cfg.Password = password
} else {
return nil, fmt.Errorf("invalid scram password")
}
default:
return nil, fmt.Errorf("tcp_client sink: invalid auth_type '%s' (must be 'none' or 'scram')", authType)
}
} }
t := &TCPClientSink{ t := &TCPClientSink{
input: make(chan core.LogEntry, cfg.BufferSize), config: opts,
config: cfg, address: opts.Host + ":" + strconv.Itoa(int(opts.Port)),
input: make(chan core.LogEntry, opts.BufferSize),
done: make(chan struct{}), done: make(chan struct{}),
startTime: time.Now(), startTime: time.Now(),
logger: logger, logger: logger,
@ -167,7 +85,8 @@ func (t *TCPClientSink) Start(ctx context.Context) error {
t.logger.Info("msg", "TCP client sink started", t.logger.Info("msg", "TCP client sink started",
"component", "tcp_client_sink", "component", "tcp_client_sink",
"address", t.config.Address) "host", t.config.Host,
"port", t.config.Port)
return nil return nil
} }
@ -209,7 +128,7 @@ func (t *TCPClientSink) GetStats() SinkStats {
StartTime: t.startTime, StartTime: t.startTime,
LastProcessed: lastProc, LastProcessed: lastProc,
Details: map[string]any{ Details: map[string]any{
"address": t.config.Address, "address": t.address,
"connected": connected, "connected": connected,
"reconnecting": t.reconnecting.Load(), "reconnecting": t.reconnecting.Load(),
"total_failed": t.totalFailed.Load(), "total_failed": t.totalFailed.Load(),
@ -223,7 +142,7 @@ func (t *TCPClientSink) GetStats() SinkStats {
func (t *TCPClientSink) connectionManager(ctx context.Context) { func (t *TCPClientSink) connectionManager(ctx context.Context) {
defer t.wg.Done() defer t.wg.Done()
reconnectDelay := t.config.ReconnectDelay reconnectDelay := time.Duration(t.config.ReconnectDelayMS) * time.Millisecond
for { for {
select { select {
@ -243,9 +162,9 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
t.lastConnectErr = err t.lastConnectErr = err
t.logger.Warn("msg", "Failed to connect to TCP server", t.logger.Warn("msg", "Failed to connect to TCP server",
"component", "tcp_client_sink", "component", "tcp_client_sink",
"address", t.config.Address, "address", t.address,
"error", err, "error", err,
"retry_delay", reconnectDelay) "retry_delay_ms", reconnectDelay)
// Wait before retry // Wait before retry
select { select {
@ -258,15 +177,15 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
// Exponential backoff // Exponential backoff
reconnectDelay = time.Duration(float64(reconnectDelay) * t.config.ReconnectBackoff) reconnectDelay = time.Duration(float64(reconnectDelay) * t.config.ReconnectBackoff)
if reconnectDelay > t.config.MaxReconnectDelay { if reconnectDelay > time.Duration(t.config.MaxReconnectDelayMS)*time.Millisecond {
reconnectDelay = t.config.MaxReconnectDelay reconnectDelay = time.Duration(t.config.MaxReconnectDelayMS)
} }
continue continue
} }
// Connection successful // Connection successful
t.lastConnectErr = nil t.lastConnectErr = nil
reconnectDelay = t.config.ReconnectDelay // Reset backoff reconnectDelay = time.Duration(t.config.ReconnectDelayMS) * time.Millisecond // Reset backoff
t.connectTime = time.Now() t.connectTime = time.Now()
t.totalReconnects.Add(1) t.totalReconnects.Add(1)
@ -276,7 +195,7 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
t.logger.Info("msg", "Connected to TCP server", t.logger.Info("msg", "Connected to TCP server",
"component", "tcp_client_sink", "component", "tcp_client_sink",
"address", t.config.Address, "address", t.address,
"local_addr", conn.LocalAddr()) "local_addr", conn.LocalAddr())
// Monitor connection // Monitor connection
@ -293,18 +212,18 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
t.logger.Warn("msg", "Lost connection to TCP server", t.logger.Warn("msg", "Lost connection to TCP server",
"component", "tcp_client_sink", "component", "tcp_client_sink",
"address", t.config.Address, "address", t.address,
"uptime", uptime) "uptime", uptime)
} }
} }
func (t *TCPClientSink) connect() (net.Conn, error) { func (t *TCPClientSink) connect() (net.Conn, error) {
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: t.config.DialTimeout, Timeout: time.Duration(t.config.DialTimeout) * time.Second,
KeepAlive: t.config.KeepAlive, KeepAlive: time.Duration(t.config.KeepAlive) * time.Second,
} }
conn, err := dialer.Dial("tcp", t.config.Address) conn, err := dialer.Dial("tcp", t.address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -312,18 +231,18 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
// Set TCP keep-alive // Set TCP keep-alive
if tcpConn, ok := conn.(*net.TCPConn); ok { if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.SetKeepAlive(true) tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive) tcpConn.SetKeepAlivePeriod(time.Duration(t.config.KeepAlive) * time.Second)
} }
// SCRAM authentication if credentials configured // SCRAM authentication if credentials configured
if t.config.AuthType == "scram" { if t.config.Auth != nil && t.config.Auth.Type == "scram" {
if err := t.performSCRAMAuth(conn); err != nil { if err := t.performSCRAMAuth(conn); err != nil {
conn.Close() conn.Close()
return nil, fmt.Errorf("SCRAM authentication failed: %w", err) return nil, fmt.Errorf("SCRAM authentication failed: %w", err)
} }
t.logger.Debug("msg", "SCRAM authentication completed", t.logger.Debug("msg", "SCRAM authentication completed",
"component", "tcp_client_sink", "component", "tcp_client_sink",
"address", t.config.Address) "address", t.address)
} }
return conn, nil return conn, nil
@ -333,7 +252,17 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
// Create SCRAM client // Create SCRAM client
scramClient := scram.NewClient(t.config.Username, t.config.Password) scramClient := auth.NewScramClient(t.config.Auth.Username, t.config.Auth.Password)
// Wait for AUTH_REQUIRED from server
authPrompt, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read auth prompt: %w", err)
}
if strings.TrimSpace(authPrompt) != "AUTH_REQUIRED" {
return fmt.Errorf("unexpected server greeting: %s", authPrompt)
}
// Step 1: Send ClientFirst // Step 1: Send ClientFirst
clientFirst, err := scramClient.StartAuthentication() clientFirst, err := scramClient.StartAuthentication()
@ -341,8 +270,10 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
return fmt.Errorf("failed to start SCRAM: %w", err) return fmt.Errorf("failed to start SCRAM: %w", err)
} }
clientFirstJSON, _ := json.Marshal(clientFirst) msg, err := auth.FormatSCRAMRequest("SCRAM-FIRST", clientFirst)
msg := fmt.Sprintf("SCRAM-FIRST %s\n", clientFirstJSON) if err != nil {
return err
}
if _, err := conn.Write([]byte(msg)); err != nil { if _, err := conn.Write([]byte(msg)); err != nil {
return fmt.Errorf("failed to send SCRAM-FIRST: %w", err) return fmt.Errorf("failed to send SCRAM-FIRST: %w", err)
@ -354,13 +285,17 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
return fmt.Errorf("failed to read SCRAM challenge: %w", err) return fmt.Errorf("failed to read SCRAM challenge: %w", err)
} }
parts := strings.Fields(strings.TrimSpace(response)) command, data, err := auth.ParseSCRAMResponse(response)
if len(parts) != 2 || parts[0] != "SCRAM-CHALLENGE" { if err != nil {
return fmt.Errorf("unexpected server response: %s", response) return err
} }
var serverFirst scram.ServerFirst if command != "SCRAM-CHALLENGE" {
if err := json.Unmarshal([]byte(parts[1]), &serverFirst); err != nil { return fmt.Errorf("unexpected server response: %s", command)
}
var serverFirst auth.ServerFirst
if err := json.Unmarshal([]byte(data), &serverFirst); err != nil {
return fmt.Errorf("failed to parse server challenge: %w", err) return fmt.Errorf("failed to parse server challenge: %w", err)
} }
@ -370,8 +305,10 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
return fmt.Errorf("failed to process challenge: %w", err) return fmt.Errorf("failed to process challenge: %w", err)
} }
clientFinalJSON, _ := json.Marshal(clientFinal) msg, err = auth.FormatSCRAMRequest("SCRAM-PROOF", clientFinal)
msg = fmt.Sprintf("SCRAM-PROOF %s\n", clientFinalJSON) if err != nil {
return err
}
if _, err := conn.Write([]byte(msg)); err != nil { if _, err := conn.Write([]byte(msg)); err != nil {
return fmt.Errorf("failed to send SCRAM-PROOF: %w", err) return fmt.Errorf("failed to send SCRAM-PROOF: %w", err)
@ -383,19 +320,15 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
return fmt.Errorf("failed to read SCRAM result: %w", err) return fmt.Errorf("failed to read SCRAM result: %w", err)
} }
parts = strings.Fields(strings.TrimSpace(response)) command, data, err = auth.ParseSCRAMResponse(response)
if len(parts) < 1 { if err != nil {
return fmt.Errorf("empty server response") return err
} }
switch parts[0] { switch command {
case "SCRAM-OK": case "SCRAM-OK":
if len(parts) != 2 { var serverFinal auth.ServerFinal
return fmt.Errorf("invalid SCRAM-OK response") if err := json.Unmarshal([]byte(data), &serverFinal); err != nil {
}
var serverFinal scram.ServerFinal
if err := json.Unmarshal([]byte(parts[1]), &serverFinal); err != nil {
return fmt.Errorf("failed to parse server signature: %w", err) return fmt.Errorf("failed to parse server signature: %w", err)
} }
@ -406,21 +339,21 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
t.logger.Info("msg", "SCRAM authentication successful", t.logger.Info("msg", "SCRAM authentication successful",
"component", "tcp_client_sink", "component", "tcp_client_sink",
"address", t.config.Address, "address", t.address,
"username", t.config.Username, "username", t.config.Auth.Username,
"session_id", serverFinal.SessionID) "session_id", serverFinal.SessionID)
return nil return nil
case "SCRAM-FAIL": case "SCRAM-FAIL":
reason := "unknown" reason := data
if len(parts) > 1 { if reason == "" {
reason = strings.Join(parts[1:], " ") reason = "unknown"
} }
return fmt.Errorf("authentication failed: %s", reason) return fmt.Errorf("authentication failed: %s", reason)
default: default:
return fmt.Errorf("unexpected response: %s", response) return fmt.Errorf("unexpected response: %s", command)
} }
} }
@ -436,7 +369,7 @@ func (t *TCPClientSink) monitorConnection(conn net.Conn) {
return return
case <-ticker.C: case <-ticker.C:
// Set read deadline // Set read deadline
if err := conn.SetReadDeadline(time.Now().Add(t.config.ReadTimeout)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(time.Duration(t.config.ReadTimeout) * time.Second)); err != nil {
t.logger.Debug("msg", "Failed to set read deadline", "error", err) t.logger.Debug("msg", "Failed to set read deadline", "error", err)
return return
} }
@ -502,7 +435,7 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
} }
// Set write deadline // Set write deadline
if err := conn.SetWriteDeadline(time.Now().Add(t.config.WriteTimeout)); err != nil { if err := conn.SetWriteDeadline(time.Now().Add(time.Duration(t.config.WriteTimeout) * time.Second)); err != nil {
return fmt.Errorf("failed to set write deadline: %w", err) return fmt.Errorf("failed to set write deadline: %w", err)
} }
@ -519,9 +452,3 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
return nil return nil
} }
// Not applicable, Clients authenticate to remote servers using Username/Password in config
func (h *TCPClientSink) SetAuth(authCfg *config.AuthConfig) {
// No-op: client sinks don't validate incoming connections
// They authenticate to remote servers using Username/Password fields
}

View File

@ -21,9 +21,7 @@ import (
// Monitors a directory for log files // Monitors a directory for log files
type DirectorySource struct { type DirectorySource struct {
path string config *config.DirectorySourceOptions
pattern string
checkInterval time.Duration
subscribers []chan core.LogEntry subscribers []chan core.LogEntry
watchers map[string]*fileWatcher watchers map[string]*fileWatcher
mu sync.RWMutex mu sync.RWMutex
@ -38,34 +36,16 @@ type DirectorySource struct {
} }
// Creates a new directory monitoring source // Creates a new directory monitoring source
func NewDirectorySource(options map[string]any, logger *log.Logger) (*DirectorySource, error) { func NewDirectorySource(opts *config.DirectorySourceOptions, logger *log.Logger) (*DirectorySource, error) {
path, ok := options["path"].(string) if opts == nil {
if !ok { return nil, fmt.Errorf("directory source options cannot be nil")
return nil, fmt.Errorf("directory source requires 'path' option")
}
pattern, _ := options["pattern"].(string)
if pattern == "" {
pattern = "*"
}
checkInterval := 100 * time.Millisecond
if ms, ok := options["check_interval_ms"].(int64); ok && ms > 0 {
checkInterval = time.Duration(ms) * time.Millisecond
}
absPath, err := filepath.Abs(path)
if err != nil {
return nil, fmt.Errorf("invalid path %s: %w", path, err)
} }
ds := &DirectorySource{ ds := &DirectorySource{
path: absPath, config: opts,
pattern: pattern, watchers: make(map[string]*fileWatcher),
checkInterval: checkInterval, startTime: time.Now(),
watchers: make(map[string]*fileWatcher), logger: logger,
startTime: time.Now(),
logger: logger,
} }
ds.lastEntryTime.Store(time.Time{}) ds.lastEntryTime.Store(time.Time{})
@ -88,9 +68,9 @@ func (ds *DirectorySource) Start() error {
ds.logger.Info("msg", "Directory source started", ds.logger.Info("msg", "Directory source started",
"component", "directory_source", "component", "directory_source",
"path", ds.path, "path", ds.config.Path,
"pattern", ds.pattern, "pattern", ds.config.Pattern,
"check_interval_ms", ds.checkInterval.Milliseconds()) "check_interval_ms", ds.config.CheckIntervalMS)
return nil return nil
} }
@ -111,7 +91,7 @@ func (ds *DirectorySource) Stop() {
ds.logger.Info("msg", "Directory source stopped", ds.logger.Info("msg", "Directory source stopped",
"component", "directory_source", "component", "directory_source",
"path", ds.path) "path", ds.config.Path)
} }
func (ds *DirectorySource) GetStats() SourceStats { func (ds *DirectorySource) GetStats() SourceStats {
@ -171,7 +151,7 @@ func (ds *DirectorySource) monitorLoop() {
ds.checkTargets() ds.checkTargets()
ticker := time.NewTicker(ds.checkInterval) ticker := time.NewTicker(time.Duration(ds.config.CheckIntervalMS) * time.Millisecond)
defer ticker.Stop() defer ticker.Stop()
for { for {
@ -189,8 +169,8 @@ func (ds *DirectorySource) checkTargets() {
if err != nil { if err != nil {
ds.logger.Warn("msg", "Failed to scan directory", ds.logger.Warn("msg", "Failed to scan directory",
"component", "directory_source", "component", "directory_source",
"path", ds.path, "path", ds.config.Path,
"pattern", ds.pattern, "pattern", ds.config.Pattern,
"error", err) "error", err)
return return
} }
@ -203,13 +183,13 @@ func (ds *DirectorySource) checkTargets() {
} }
func (ds *DirectorySource) scanDirectory() ([]string, error) { func (ds *DirectorySource) scanDirectory() ([]string, error) {
entries, err := os.ReadDir(ds.path) entries, err := os.ReadDir(ds.config.Path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Convert glob pattern to regex // Convert glob pattern to regex
regexPattern := globToRegex(ds.pattern) regexPattern := globToRegex(ds.config.Pattern)
re, err := regexp.Compile(regexPattern) re, err := regexp.Compile(regexPattern)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid pattern regex: %w", err) return nil, fmt.Errorf("invalid pattern regex: %w", err)
@ -223,7 +203,7 @@ func (ds *DirectorySource) scanDirectory() ([]string, error) {
name := entry.Name() name := entry.Name()
if re.MatchString(name) { if re.MatchString(name) {
files = append(files, filepath.Join(ds.path, name)) files = append(files, filepath.Join(ds.config.Path, name))
} }
} }
@ -288,7 +268,3 @@ func globToRegex(glob string) string {
regex = strings.ReplaceAll(regex, `\?`, `.`) regex = strings.ReplaceAll(regex, `\?`, `.`)
return "^" + regex + "$" return "^" + regex + "$"
} }
func (ds *DirectorySource) SetAuth(auth *config.AuthConfig) {
// Authentication does not apply to directory source
}

View File

@ -14,7 +14,6 @@ import (
"logwisp/src/internal/core" "logwisp/src/internal/core"
"logwisp/src/internal/limit" "logwisp/src/internal/limit"
"logwisp/src/internal/tls" "logwisp/src/internal/tls"
"logwisp/src/internal/version"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
@ -22,12 +21,7 @@ import (
// Receives log entries via HTTP POST requests // Receives log entries via HTTP POST requests
type HTTPSource struct { type HTTPSource struct {
// Config config *config.HTTPSourceOptions
host string
port int64
path string
bufferSize int64
maxRequestBodySize int64
// Application // Application
server *fasthttp.Server server *fasthttp.Server
@ -42,11 +36,9 @@ type HTTPSource struct {
// Security // Security
authenticator *auth.Authenticator authenticator *auth.Authenticator
authConfig *config.AuthConfig
authFailures atomic.Uint64 authFailures atomic.Uint64
authSuccesses atomic.Uint64 authSuccesses atomic.Uint64
tlsManager *tls.Manager tlsManager *tls.Manager
tlsConfig *config.TLSConfig
// Statistics // Statistics
totalEntries atomic.Uint64 totalEntries atomic.Uint64
@ -57,108 +49,52 @@ type HTTPSource struct {
} }
// Creates a new HTTP server source // Creates a new HTTP server source
func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, error) { func NewHTTPSource(opts *config.HTTPSourceOptions, logger *log.Logger) (*HTTPSource, error) {
host := "0.0.0.0" // Validation done in config package
if h, ok := options["host"].(string); ok && h != "" { if opts == nil {
host = h return nil, fmt.Errorf("HTTP source options cannot be nil")
}
port, ok := options["port"].(int64)
if !ok || port < 1 || port > 65535 {
return nil, fmt.Errorf("http source requires valid 'port' option")
}
ingestPath := "/ingest"
if path, ok := options["path"].(string); ok && path != "" {
ingestPath = path
}
bufferSize := int64(1000)
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
bufferSize = bufSize
}
maxRequestBodySize := int64(10 * 1024 * 1024) // fasthttp default 10MB
if maxBodySize, ok := options["max_body_size"].(int64); ok && maxBodySize > 0 && maxBodySize < maxRequestBodySize {
maxRequestBodySize = maxBodySize
} }
h := &HTTPSource{ h := &HTTPSource{
host: host, config: opts,
port: port, done: make(chan struct{}),
path: ingestPath, startTime: time.Now(),
bufferSize: bufferSize, logger: logger,
maxRequestBodySize: maxRequestBodySize,
done: make(chan struct{}),
startTime: time.Now(),
logger: logger,
} }
h.lastEntryTime.Store(time.Time{}) h.lastEntryTime.Store(time.Time{})
// Initialize net limiter if configured // Initialize net limiter if configured
if nl, ok := options["net_limit"].(map[string]any); ok { if opts.NetLimit != nil && (opts.NetLimit.Enabled ||
if enabled, _ := nl["enabled"].(bool); enabled { len(opts.NetLimit.IPWhitelist) > 0 ||
cfg := config.NetLimitConfig{ len(opts.NetLimit.IPBlacklist) > 0) {
Enabled: true, h.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger)
}
if rps, ok := nl["requests_per_second"].(float64); ok {
cfg.RequestsPerSecond = rps
}
if burst, ok := nl["burst_size"].(int64); ok {
cfg.BurstSize = burst
}
if respCode, ok := nl["response_code"].(int64); ok {
cfg.ResponseCode = respCode
}
if msg, ok := nl["response_message"].(string); ok {
cfg.ResponseMessage = msg
}
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
cfg.MaxConnectionsPerIP = maxPerIP
}
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
cfg.MaxConnectionsTotal = maxTotal
}
h.netLimiter = limit.NewNetLimiter(cfg, logger)
}
} }
// Extract TLS config after existing options // Initialize TLS manager if configured
if tc, ok := options["tls"].(map[string]any); ok { if opts.TLS != nil && opts.TLS.Enabled {
h.tlsConfig = &config.TLSConfig{} tlsManager, err := tls.NewManager(opts.TLS, logger)
h.tlsConfig.Enabled, _ = tc["enabled"].(bool) if err != nil {
if certFile, ok := tc["cert_file"].(string); ok { return nil, fmt.Errorf("failed to create TLS manager: %w", err)
h.tlsConfig.CertFile = certFile
} }
if keyFile, ok := tc["key_file"].(string); ok { h.tlsManager = tlsManager
h.tlsConfig.KeyFile = keyFile }
}
h.tlsConfig.ClientAuth, _ = tc["client_auth"].(bool) // Initialize authenticator if configured
if caFile, ok := tc["client_ca_file"].(string); ok { if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" {
h.tlsConfig.ClientCAFile = caFile // Verify TLS is enabled for auth (validation should have caught this)
} if h.tlsManager == nil {
h.tlsConfig.VerifyClientCert, _ = tc["verify_client_cert"].(bool) return nil, fmt.Errorf("authentication requires TLS to be enabled")
h.tlsConfig.InsecureSkipVerify, _ = tc["insecure_skip_verify"].(bool)
if minVer, ok := tc["min_version"].(string); ok {
h.tlsConfig.MinVersion = minVer
}
if maxVer, ok := tc["max_version"].(string); ok {
h.tlsConfig.MaxVersion = maxVer
}
if ciphers, ok := tc["cipher_suites"].(string); ok {
h.tlsConfig.CipherSuites = ciphers
} }
// Create TLS manager authenticator, err := auth.NewAuthenticator(opts.Auth, logger)
if h.tlsConfig.Enabled { if err != nil {
tlsManager, err := tls.NewManager(h.tlsConfig, logger) return nil, fmt.Errorf("failed to create authenticator: %w", err)
if err != nil {
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
}
h.tlsManager = tlsManager
} }
h.authenticator = authenticator
logger.Info("msg", "Authentication configured for HTTP source",
"component", "http_source",
"auth_type", opts.Auth.Type)
} }
return h, nil return h, nil
@ -168,23 +104,24 @@ func (h *HTTPSource) Subscribe() <-chan core.LogEntry {
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
ch := make(chan core.LogEntry, h.bufferSize) ch := make(chan core.LogEntry, h.config.BufferSize)
h.subscribers = append(h.subscribers, ch) h.subscribers = append(h.subscribers, ch)
return ch return ch
} }
func (h *HTTPSource) Start() error { func (h *HTTPSource) Start() error {
h.server = &fasthttp.Server{ h.server = &fasthttp.Server{
Name: fmt.Sprintf("LogWisp/%s", version.Short()),
Handler: h.requestHandler, Handler: h.requestHandler,
DisableKeepalive: false, DisableKeepalive: false,
StreamRequestBody: true, StreamRequestBody: true,
CloseOnShutdown: true, CloseOnShutdown: true,
MaxRequestBodySize: int(h.maxRequestBodySize), ReadTimeout: time.Duration(h.config.ReadTimeout) * time.Millisecond,
WriteTimeout: time.Duration(h.config.WriteTimeout) * time.Millisecond,
MaxRequestBodySize: int(h.config.MaxRequestBodySize),
} }
// Use configured host and port // Use configured host and port
addr := fmt.Sprintf("%s:%d", h.host, h.port) addr := fmt.Sprintf("%s:%d", h.config.Host, h.config.Port)
// Start server in background // Start server in background
h.wg.Add(1) h.wg.Add(1)
@ -193,35 +130,35 @@ func (h *HTTPSource) Start() error {
defer h.wg.Done() defer h.wg.Done()
h.logger.Info("msg", "HTTP source server starting", h.logger.Info("msg", "HTTP source server starting",
"component", "http_source", "component", "http_source",
"port", h.port, "port", h.config.Port,
"path", h.path, "ingest_path", h.config.IngestPath,
"tls_enabled", h.tlsManager != nil) "tls_enabled", h.tlsManager != nil,
"auth_enabled", h.authenticator != nil)
var err error var err error
// Check for TLS manager and start the appropriate server type
if h.tlsManager != nil { if h.tlsManager != nil {
// HTTPS server
h.server.TLSConfig = h.tlsManager.GetHTTPConfig() h.server.TLSConfig = h.tlsManager.GetHTTPConfig()
err = h.server.ListenAndServeTLS(addr, h.tlsConfig.CertFile, h.tlsConfig.KeyFile) err = h.server.ListenAndServeTLS(addr, h.config.TLS.CertFile, h.config.TLS.KeyFile)
} else { } else {
// HTTP server
err = h.server.ListenAndServe(addr) err = h.server.ListenAndServe(addr)
} }
if err != nil { if err != nil {
h.logger.Error("msg", "HTTP source server failed", h.logger.Error("msg", "HTTP source server failed",
"component", "http_source", "component", "http_source",
"port", h.port, "port", h.config.Port,
"error", err) "error", err)
errChan <- err errChan <- err
} }
}() }()
// Robust server startup check with timeout // Wait briefly for server startup
select { select {
case err := <-errChan: case err := <-errChan:
// Server failed to start
return fmt.Errorf("HTTP server failed to start: %w", err) return fmt.Errorf("HTTP server failed to start: %w", err)
case <-time.After(250 * time.Millisecond): case <-time.After(250 * time.Millisecond):
// Server started successfully (no immediate error)
return nil return nil
} }
} }
@ -263,6 +200,21 @@ func (h *HTTPSource) GetStats() SourceStats {
netLimitStats = h.netLimiter.GetStats() netLimitStats = h.netLimiter.GetStats()
} }
var authStats map[string]any
if h.authenticator != nil {
authStats = map[string]any{
"enabled": true,
"type": h.config.Auth.Type,
"failures": h.authFailures.Load(),
"successes": h.authSuccesses.Load(),
}
}
var tlsStats map[string]any
if h.tlsManager != nil {
tlsStats = h.tlsManager.GetStats()
}
return SourceStats{ return SourceStats{
Type: "http", Type: "http",
TotalEntries: h.totalEntries.Load(), TotalEntries: h.totalEntries.Load(),
@ -270,10 +222,13 @@ func (h *HTTPSource) GetStats() SourceStats {
StartTime: h.startTime, StartTime: h.startTime,
LastEntryTime: lastEntry, LastEntryTime: lastEntry,
Details: map[string]any{ Details: map[string]any{
"port": h.port, "host": h.config.Host,
"path": h.path, "port": h.config.Port,
"path": h.config.IngestPath,
"invalid_entries": h.invalidEntries.Load(), "invalid_entries": h.invalidEntries.Load(),
"net_limit": netLimitStats, "net_limit": netLimitStats,
"auth": authStats,
"tls": tlsStats,
}, },
} }
} }
@ -307,17 +262,10 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
} }
} }
// 2.5. Check TLS requirement for auth (early reject) // 3. Check TLS requirement for auth
if h.authenticator != nil && h.authConfig.Type != "none" { if h.authenticator != nil {
// Check if connection is TLS
isTLS := ctx.IsTLS() || h.tlsManager != nil isTLS := ctx.IsTLS() || h.tlsManager != nil
if !isTLS { if !isTLS {
h.logger.Error("msg", "Authentication configured but connection is not TLS",
"component", "http_source",
"remote_addr", remoteAddr,
"auth_type", h.authConfig.Type)
ctx.SetStatusCode(fasthttp.StatusForbidden) ctx.SetStatusCode(fasthttp.StatusForbidden)
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{ json.NewEncoder(ctx).Encode(map[string]string{
@ -326,21 +274,45 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
}) })
return return
} }
// Authenticate request
authHeader := string(ctx.Request.Header.Peek("Authorization"))
session, err := h.authenticator.AuthenticateHTTP(authHeader, remoteAddr)
if err != nil {
h.authFailures.Add(1)
h.logger.Warn("msg", "Authentication failed",
"component", "http_source",
"remote_addr", remoteAddr,
"error", err)
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
if h.config.Auth.Type == "basic" && h.config.Auth.Basic != nil && h.config.Auth.Basic.Realm != "" {
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, h.config.Auth.Basic.Realm))
}
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{
"error": "Authentication failed",
})
return
}
h.authSuccesses.Add(1)
_ = session // Session can be used for audit logging
} }
// 3. Path check (only process ingest path) // 4. Path check
path := string(ctx.Path()) path := string(ctx.Path())
if path != h.path { if path != h.config.IngestPath {
ctx.SetStatusCode(fasthttp.StatusNotFound) ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{ json.NewEncoder(ctx).Encode(map[string]string{
"error": "Not Found", "error": "Not Found",
"hint": fmt.Sprintf("POST logs to %s", h.path), "hint": fmt.Sprintf("POST logs to %s", h.config.IngestPath),
}) })
return return
} }
// 4. Method check (only accept POST) // 5. Method check (only accepts POST)
if string(ctx.Method()) != "POST" { if string(ctx.Method()) != "POST" {
ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed) ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
@ -352,43 +324,10 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
return return
} }
// 5. Authentication check (if configured) // 6. Process log entry
if h.authenticator != nil {
authHeader := string(ctx.Request.Header.Peek("Authorization"))
session, err := h.authenticator.AuthenticateHTTP(authHeader, remoteAddr)
if err != nil {
h.authFailures.Add(1)
h.logger.Warn("msg", "Authentication failed",
"component", "http_source",
"remote_addr", remoteAddr,
"error", err)
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
realm := h.authConfig.BasicAuth.Realm
if realm == "" {
realm = "Restricted"
}
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm))
} else if h.authConfig.Type == "bearer" {
ctx.Response.Header.Set("WWW-Authenticate", "Bearer")
}
ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{
"error": "Unauthorized",
})
return
}
h.authSuccesses.Add(1)
h.logger.Debug("msg", "Request authenticated",
"component", "http_source",
"remote_addr", remoteAddr,
"username", session.Username)
}
// 6. Process request body
body := ctx.PostBody() body := ctx.PostBody()
if len(body) == 0 { if len(body) == 0 {
h.invalidEntries.Add(1)
ctx.SetStatusCode(fasthttp.StatusBadRequest) ctx.SetStatusCode(fasthttp.StatusBadRequest)
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{ json.NewEncoder(ctx).Encode(map[string]string{
@ -397,32 +336,34 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
return return
} }
// 7. Parse log entries var entry core.LogEntry
entries, err := h.parseEntries(body) if err := json.Unmarshal(body, &entry); err != nil {
if err != nil {
h.invalidEntries.Add(1) h.invalidEntries.Add(1)
ctx.SetStatusCode(fasthttp.StatusBadRequest) ctx.SetStatusCode(fasthttp.StatusBadRequest)
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]string{ json.NewEncoder(ctx).Encode(map[string]string{
"error": fmt.Sprintf("Invalid log format: %v", err), "error": fmt.Sprintf("Invalid JSON: %v", err),
}) })
return return
} }
// 8. Publish entries to subscribers // Set defaults
accepted := 0 if entry.Time.IsZero() {
for _, entry := range entries { entry.Time = time.Now()
if h.publish(entry) {
accepted++
}
} }
if entry.Source == "" {
entry.Source = "http"
}
entry.RawSize = int64(len(body))
// 9. Return success response // Publish to subscribers
h.publish(entry)
// Success response
ctx.SetStatusCode(fasthttp.StatusAccepted) ctx.SetStatusCode(fasthttp.StatusAccepted)
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
json.NewEncoder(ctx).Encode(map[string]any{ json.NewEncoder(ctx).Encode(map[string]string{
"accepted": accepted, "status": "accepted",
"total": len(entries),
}) })
} }
@ -501,29 +442,22 @@ func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) {
return entries, nil return entries, nil
} }
func (h *HTTPSource) publish(entry core.LogEntry) bool { func (h *HTTPSource) publish(entry core.LogEntry) {
h.mu.RLock() h.mu.RLock()
defer h.mu.RUnlock() defer h.mu.RUnlock()
h.totalEntries.Add(1) h.totalEntries.Add(1)
h.lastEntryTime.Store(entry.Time) h.lastEntryTime.Store(entry.Time)
dropped := false
for _, ch := range h.subscribers { for _, ch := range h.subscribers {
select { select {
case ch <- entry: case ch <- entry:
default: default:
dropped = true
h.droppedEntries.Add(1) h.droppedEntries.Add(1)
h.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
"component", "http_source")
} }
} }
if dropped {
h.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
"component", "http_source")
}
return true
} }
// Splits bytes into lines, handling both \n and \r\n // Splits bytes into lines, handling both \n and \r\n
@ -550,24 +484,3 @@ func splitLines(data []byte) [][]byte {
return lines return lines
} }
// Configure HTTP source auth
func (h *HTTPSource) SetAuth(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" {
return
}
h.authConfig = authCfg
authenticator, err := auth.NewAuthenticator(authCfg, h.logger)
if err != nil {
h.logger.Error("msg", "Failed to initialize authenticator for HTTP source",
"component", "http_source",
"error", err)
return
}
h.authenticator = authenticator
h.logger.Info("msg", "Authentication configured for HTTP source",
"component", "http_source",
"auth_type", authCfg.Type)
}

View File

@ -4,7 +4,6 @@ package source
import ( import (
"time" "time"
"logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
) )
@ -21,9 +20,6 @@ type Source interface {
// Returns source statistics // Returns source statistics
GetStats() SourceStats GetStats() SourceStats
// Configure authentication
SetAuth(auth *config.AuthConfig)
} }
// Contains statistics about a source // Contains statistics about a source

View File

@ -15,24 +15,25 @@ import (
// Reads log entries from standard input // Reads log entries from standard input
type StdinSource struct { type StdinSource struct {
config *config.StdinSourceOptions
subscribers []chan core.LogEntry subscribers []chan core.LogEntry
done chan struct{} done chan struct{}
totalEntries atomic.Uint64 totalEntries atomic.Uint64
droppedEntries atomic.Uint64 droppedEntries atomic.Uint64
bufferSize int64
startTime time.Time startTime time.Time
lastEntryTime atomic.Value // time.Time lastEntryTime atomic.Value // time.Time
logger *log.Logger logger *log.Logger
} }
func NewStdinSource(options map[string]any, logger *log.Logger) (*StdinSource, error) { func NewStdinSource(opts *config.StdinSourceOptions, logger *log.Logger) (*StdinSource, error) {
bufferSize := int64(1000) // default if opts == nil {
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 { opts = &config.StdinSourceOptions{
bufferSize = bufSize BufferSize: 1000, // Default
}
} }
source := &StdinSource{ source := &StdinSource{
bufferSize: bufferSize, config: opts,
subscribers: make([]chan core.LogEntry, 0), subscribers: make([]chan core.LogEntry, 0),
done: make(chan struct{}), done: make(chan struct{}),
logger: logger, logger: logger,
@ -43,7 +44,7 @@ func NewStdinSource(options map[string]any, logger *log.Logger) (*StdinSource, e
} }
func (s *StdinSource) Subscribe() <-chan core.LogEntry { func (s *StdinSource) Subscribe() <-chan core.LogEntry {
ch := make(chan core.LogEntry, s.bufferSize) ch := make(chan core.LogEntry, s.config.BufferSize)
s.subscribers = append(s.subscribers, ch) s.subscribers = append(s.subscribers, ch)
return ch return ch
} }
@ -120,7 +121,3 @@ func (s *StdinSource) publish(entry core.LogEntry) {
} }
} }
} }
func (s *StdinSource) SetAuth(auth *config.AuthConfig) {
// Authentication does not apply to stdin source
}

View File

@ -4,7 +4,6 @@ package source
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
@ -17,7 +16,6 @@ import (
"logwisp/src/internal/config" "logwisp/src/internal/config"
"logwisp/src/internal/core" "logwisp/src/internal/core"
"logwisp/src/internal/limit" "logwisp/src/internal/limit"
"logwisp/src/internal/scram"
"github.com/lixenwraith/log" "github.com/lixenwraith/log"
"github.com/lixenwraith/log/compat" "github.com/lixenwraith/log/compat"
@ -31,19 +29,19 @@ const (
// Receives log entries via TCP connections // Receives log entries via TCP connections
type TCPSource struct { type TCPSource struct {
host string config *config.TCPSourceOptions
port int64 server *tcpSourceServer
bufferSize int64 subscribers []chan core.LogEntry
server *tcpSourceServer mu sync.RWMutex
subscribers []chan core.LogEntry done chan struct{}
mu sync.RWMutex engine *gnet.Engine
done chan struct{} engineMu sync.Mutex
engine *gnet.Engine wg sync.WaitGroup
engineMu sync.Mutex authenticator *auth.Authenticator
wg sync.WaitGroup netLimiter *limit.NetLimiter
netLimiter *limit.NetLimiter logger *log.Logger
logger *log.Logger scramManager *auth.ScramManager
scramManager *scram.ScramManager scramProtocolHandler *auth.ScramProtocolHandler
// Statistics // Statistics
totalEntries atomic.Uint64 totalEntries atomic.Uint64
@ -57,60 +55,36 @@ type TCPSource struct {
} }
// Creates a new TCP server source // Creates a new TCP server source
func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error) { func NewTCPSource(opts *config.TCPSourceOptions, logger *log.Logger) (*TCPSource, error) {
host := "0.0.0.0" // Accept typed config - validation done in config package
if h, ok := options["host"].(string); ok && h != "" { if opts == nil {
host = h return nil, fmt.Errorf("TCP source options cannot be nil")
}
port, ok := options["port"].(int64)
if !ok || port < 1 || port > 65535 {
return nil, fmt.Errorf("tcp source requires valid 'port' option")
}
bufferSize := int64(1000)
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
bufferSize = bufSize
} }
t := &TCPSource{ t := &TCPSource{
host: host, config: opts,
port: port, done: make(chan struct{}),
bufferSize: bufferSize, startTime: time.Now(),
done: make(chan struct{}), logger: logger,
startTime: time.Now(),
logger: logger,
} }
t.lastEntryTime.Store(time.Time{}) t.lastEntryTime.Store(time.Time{})
// Initialize net limiter if configured // Initialize net limiter if configured
if nl, ok := options["net_limit"].(map[string]any); ok { if opts.NetLimit != nil && (opts.NetLimit.Enabled ||
if enabled, _ := nl["enabled"].(bool); enabled { len(opts.NetLimit.IPWhitelist) > 0 ||
cfg := config.NetLimitConfig{ len(opts.NetLimit.IPBlacklist) > 0) {
Enabled: true, t.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger)
} }
if rps, ok := nl["requests_per_second"].(float64); ok { // Initialize SCRAM
cfg.RequestsPerSecond = rps if opts.Auth != nil && opts.Auth.Type == "scram" && opts.Auth.Scram != nil {
} t.scramManager = auth.NewScramManager(opts.Auth.Scram)
if burst, ok := nl["burst_size"].(int64); ok { t.scramProtocolHandler = auth.NewScramProtocolHandler(t.scramManager, logger)
cfg.BurstSize = burst logger.Info("msg", "SCRAM authentication configured for TCP source",
} "component", "tcp_source",
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok { "users", len(opts.Auth.Scram.Users))
cfg.MaxConnectionsPerIP = maxPerIP } else if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" {
} return nil, fmt.Errorf("TCP source only supports 'none' or 'scram' auth")
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
cfg.MaxConnectionsPerUser = maxPerUser
}
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
cfg.MaxConnectionsPerToken = maxPerToken
}
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
cfg.MaxConnectionsTotal = maxTotal
}
t.netLimiter = limit.NewNetLimiter(cfg, logger)
}
} }
return t, nil return t, nil
@ -120,7 +94,7 @@ func (t *TCPSource) Subscribe() <-chan core.LogEntry {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
ch := make(chan core.LogEntry, t.bufferSize) ch := make(chan core.LogEntry, t.config.BufferSize)
t.subscribers = append(t.subscribers, ch) t.subscribers = append(t.subscribers, ch)
return ch return ch
} }
@ -132,7 +106,7 @@ func (t *TCPSource) Start() error {
} }
// Use configured host and port // Use configured host and port
addr := fmt.Sprintf("tcp://%s:%d", t.host, t.port) addr := fmt.Sprintf("tcp://%s:%d", t.config.Host, t.config.Port)
// Create a gnet adapter using the existing logger instance // Create a gnet adapter using the existing logger instance
gnetLogger := compat.NewGnetAdapter(t.logger) gnetLogger := compat.NewGnetAdapter(t.logger)
@ -144,17 +118,19 @@ func (t *TCPSource) Start() error {
defer t.wg.Done() defer t.wg.Done()
t.logger.Info("msg", "TCP source server starting", t.logger.Info("msg", "TCP source server starting",
"component", "tcp_source", "component", "tcp_source",
"port", t.port) "port", t.config.Port,
"auth_enabled", t.authenticator != nil)
err := gnet.Run(t.server, addr, err := gnet.Run(t.server, addr,
gnet.WithLogger(gnetLogger), gnet.WithLogger(gnetLogger),
gnet.WithMulticore(true), gnet.WithMulticore(true),
gnet.WithReusePort(true), gnet.WithReusePort(true),
gnet.WithTCPKeepAlive(time.Duration(t.config.KeepAlivePeriod)*time.Millisecond),
) )
if err != nil { if err != nil {
t.logger.Error("msg", "TCP source server failed", t.logger.Error("msg", "TCP source server failed",
"component", "tcp_source", "component", "tcp_source",
"port", t.port, "port", t.config.Port,
"error", err) "error", err)
} }
errChan <- err errChan <- err
@ -169,7 +145,7 @@ func (t *TCPSource) Start() error {
return err return err
case <-time.After(100 * time.Millisecond): case <-time.After(100 * time.Millisecond):
// Server started successfully // Server started successfully
t.logger.Info("msg", "TCP server started", "port", t.port) t.logger.Info("msg", "TCP server started", "port", t.config.Port)
return nil return nil
} }
} }
@ -214,6 +190,16 @@ func (t *TCPSource) GetStats() SourceStats {
netLimitStats = t.netLimiter.GetStats() netLimitStats = t.netLimiter.GetStats()
} }
var authStats map[string]any
if t.authenticator != nil {
authStats = map[string]any{
"enabled": true,
"type": t.config.Auth.Type,
"failures": t.authFailures.Load(),
"successes": t.authSuccesses.Load(),
}
}
return SourceStats{ return SourceStats{
Type: "tcp", Type: "tcp",
TotalEntries: t.totalEntries.Load(), TotalEntries: t.totalEntries.Load(),
@ -221,49 +207,41 @@ func (t *TCPSource) GetStats() SourceStats {
StartTime: t.startTime, StartTime: t.startTime,
LastEntryTime: lastEntry, LastEntryTime: lastEntry,
Details: map[string]any{ Details: map[string]any{
"port": t.port, "port": t.config.Port,
"active_connections": t.activeConns.Load(), "active_connections": t.activeConns.Load(),
"invalid_entries": t.invalidEntries.Load(), "invalid_entries": t.invalidEntries.Load(),
"net_limit": netLimitStats, "net_limit": netLimitStats,
"auth": authStats,
}, },
} }
} }
func (t *TCPSource) publish(entry core.LogEntry) bool { func (t *TCPSource) publish(entry core.LogEntry) {
t.mu.RLock() t.mu.RLock()
defer t.mu.RUnlock() defer t.mu.RUnlock()
t.totalEntries.Add(1) t.totalEntries.Add(1)
t.lastEntryTime.Store(entry.Time) t.lastEntryTime.Store(entry.Time)
dropped := false
for _, ch := range t.subscribers { for _, ch := range t.subscribers {
select { select {
case ch <- entry: case ch <- entry:
default: default:
dropped = true
t.droppedEntries.Add(1) t.droppedEntries.Add(1)
t.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
"component", "tcp_source")
} }
} }
if dropped {
t.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
"component", "tcp_source")
}
return true
} }
// Represents a connected TCP client // Represents a connected TCP client
type tcpClient struct { type tcpClient struct {
conn gnet.Conn conn gnet.Conn
buffer *bytes.Buffer buffer *bytes.Buffer
authenticated bool authenticated bool
authTimeout time.Time authTimeout time.Time
session *auth.Session session *auth.Session
maxBufferSeen int maxBufferSeen int
cumulativeEncrypted int64
scramState *scram.HandshakeState
} }
// Handles gnet events // Handles gnet events
@ -282,7 +260,7 @@ func (s *tcpSourceServer) OnBoot(eng gnet.Engine) gnet.Action {
s.source.logger.Debug("msg", "TCP source server booted", s.source.logger.Debug("msg", "TCP source server booted",
"component", "tcp_source", "component", "tcp_source",
"port", s.source.port) "port", s.source.config.Port)
return gnet.None return gnet.None
} }
@ -303,6 +281,16 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
return nil, gnet.Close return nil, gnet.Close
} }
// Check if connection is allowed
ip := tcpAddr.IP
if ip.To4() == nil {
// Reject IPv6
s.source.logger.Warn("msg", "IPv6 connection rejected",
"component", "tcp_source",
"remote_addr", remoteAddr)
return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close
}
if !s.source.netLimiter.CheckTCP(tcpAddr) { if !s.source.netLimiter.CheckTCP(tcpAddr) {
s.source.logger.Warn("msg", "TCP connection net limited", s.source.logger.Warn("msg", "TCP connection net limited",
"component", "tcp_source", "component", "tcp_source",
@ -311,49 +299,66 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
} }
// Track connection // Track connection
s.source.netLimiter.AddConnection(remoteAddr) // s.source.netLimiter.AddConnection(remoteAddr)
if !s.source.netLimiter.TrackConnection(ip.String(), "", "") {
s.source.logger.Warn("msg", "TCP connection limit exceeded",
"component", "tcp_source",
"remote_addr", remoteAddr)
return nil, gnet.Close
}
} }
// Create client state // Create client state
client := &tcpClient{ client := &tcpClient{
conn: c, conn: c,
buffer: bytes.NewBuffer(nil), buffer: bytes.NewBuffer(nil),
authTimeout: time.Now().Add(30 * time.Second), authenticated: s.source.authenticator == nil, // No auth = auto authenticated
authenticated: s.source.scramManager == nil, }
if s.source.authenticator != nil {
// Set auth timeout
client.authTimeout = time.Now().Add(10 * time.Second)
// Send auth challenge for SCRAM
if s.source.config.Auth.Type == "scram" {
out = []byte("AUTH_REQUIRED\n")
}
} }
s.mu.Lock() s.mu.Lock()
s.clients[c] = client s.clients[c] = client
s.mu.Unlock() s.mu.Unlock()
newCount := s.source.activeConns.Add(1) s.source.activeConns.Add(1)
s.source.logger.Debug("msg", "TCP connection opened", s.source.logger.Debug("msg", "TCP connection opened",
"component", "tcp_source", "component", "tcp_source",
"remote_addr", remoteAddr, "remote_addr", remoteAddr,
"active_connections", newCount, "auth_enabled", s.source.authenticator != nil)
"requires_auth", s.source.scramManager != nil)
return nil, gnet.None return out, gnet.None
} }
func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action { func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action {
remoteAddr := c.RemoteAddr().String() remoteAddr := c.RemoteAddr().String()
// Untrack connection
if s.source.netLimiter != nil {
if tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr); err == nil {
s.source.netLimiter.ReleaseConnection(tcpAddr.IP.String(), "", "")
// s.source.netLimiter.RemoveConnection(remoteAddr)
}
}
// Remove client state // Remove client state
s.mu.Lock() s.mu.Lock()
delete(s.clients, c) delete(s.clients, c)
s.mu.Unlock() s.mu.Unlock()
// Remove connection tracking newConnectionCount := s.source.activeConns.Add(-1)
if s.source.netLimiter != nil {
s.source.netLimiter.RemoveConnection(remoteAddr)
}
newCount := s.source.activeConns.Add(-1)
s.source.logger.Debug("msg", "TCP connection closed", s.source.logger.Debug("msg", "TCP connection closed",
"component", "tcp_source", "component", "tcp_source",
"remote_addr", remoteAddr, "remote_addr", remoteAddr,
"active_connections", newCount, "active_connections", newConnectionCount,
"error", err) "error", err)
return gnet.None return gnet.None
} }
@ -383,6 +388,8 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
s.source.logger.Warn("msg", "Authentication timeout", s.source.logger.Warn("msg", "Authentication timeout",
"component", "tcp_source", "component", "tcp_source",
"remote_addr", c.RemoteAddr().String()) "remote_addr", c.RemoteAddr().String())
s.source.authFailures.Add(1)
c.AsyncWrite([]byte("AUTH_TIMEOUT\n"), nil)
return gnet.Close return gnet.Close
} }
@ -392,7 +399,12 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
client.buffer.Write(data) client.buffer.Write(data)
// Look for complete line // Use centralized SCRAM protocol handler
if s.source.scramProtocolHandler == nil {
s.source.scramProtocolHandler = auth.NewScramProtocolHandler(s.source.scramManager, s.source.logger)
}
// Look for complete auth line
for { for {
idx := bytes.IndexByte(client.buffer.Bytes(), '\n') idx := bytes.IndexByte(client.buffer.Bytes(), '\n')
if idx < 0 { if idx < 0 {
@ -402,85 +414,44 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
line := client.buffer.Bytes()[:idx] line := client.buffer.Bytes()[:idx]
client.buffer.Next(idx + 1) client.buffer.Next(idx + 1)
// Parse SCRAM messages // Process auth message through handler
parts := strings.Fields(string(line)) authenticated, session, err := s.source.scramProtocolHandler.HandleAuthMessage(line, c)
if len(parts) < 2 { if err != nil {
c.AsyncWrite([]byte("SCRAM-FAIL Invalid message format\n"), nil) s.source.logger.Warn("msg", "SCRAM authentication failed",
return gnet.Close "component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"error", err)
if strings.Contains(err.Error(), "unknown command") {
return gnet.Close
}
// Continue for other errors (might be multi-step auth)
} }
switch parts[0] { if authenticated && session != nil {
case "SCRAM-FIRST":
// Parse ClientFirst JSON
var clientFirst scram.ClientFirst
if err := json.Unmarshal([]byte(parts[1]), &clientFirst); err != nil {
c.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
return gnet.Close
}
// Process with SCRAM server
serverFirst, err := s.source.scramManager.HandleClientFirst(&clientFirst)
if err != nil {
// Still send challenge to prevent user enumeration
response, _ := json.Marshal(serverFirst)
c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
return gnet.Close
}
// Send ServerFirst challenge
response, _ := json.Marshal(serverFirst)
c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
case "SCRAM-PROOF":
// Parse ClientFinal JSON
var clientFinal scram.ClientFinal
if err := json.Unmarshal([]byte(parts[1]), &clientFinal); err != nil {
c.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
return gnet.Close
}
// Verify proof
serverFinal, err := s.source.scramManager.HandleClientFinal(&clientFinal)
if err != nil {
s.source.logger.Warn("msg", "SCRAM authentication failed",
"component", "tcp_source",
"remote_addr", c.RemoteAddr().String(),
"error", err)
c.AsyncWrite([]byte("SCRAM-FAIL Authentication failed\n"), nil)
return gnet.Close
}
// Authentication successful // Authentication successful
s.mu.Lock() s.mu.Lock()
client.authenticated = true client.authenticated = true
client.session = &auth.Session{ client.session = session
ID: serverFinal.SessionID,
Method: "scram-sha-256",
RemoteAddr: c.RemoteAddr().String(),
CreatedAt: time.Now(),
}
s.mu.Unlock() s.mu.Unlock()
// Send ServerFinal with signature
response, _ := json.Marshal(serverFinal)
c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-OK %s\n", response)), nil)
s.source.logger.Info("msg", "Client authenticated via SCRAM", s.source.logger.Info("msg", "Client authenticated via SCRAM",
"component", "tcp_source", "component", "tcp_source",
"remote_addr", c.RemoteAddr().String(), "remote_addr", c.RemoteAddr().String(),
"session_id", serverFinal.SessionID) "session_id", session.ID)
// Clear auth buffer // Clear auth buffer
client.buffer.Reset() client.buffer.Reset()
break
default:
c.AsyncWrite([]byte("SCRAM-FAIL Unknown command\n"), nil)
return gnet.Close
} }
} }
return gnet.None return gnet.None
} }
return s.processLogData(c, client, data)
}
func (s *tcpSourceServer) processLogData(c gnet.Conn, client *tcpClient, data []byte) gnet.Action {
// Check if appending the new data would exceed the client buffer limit. // Check if appending the new data would exceed the client buffer limit.
if client.buffer.Len()+len(data) > maxClientBufferSize { if client.buffer.Len()+len(data) > maxClientBufferSize {
s.source.logger.Warn("msg", "Client buffer limit exceeded, closing connection.", s.source.logger.Warn("msg", "Client buffer limit exceeded, closing connection.",
@ -572,47 +543,3 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
return gnet.None return gnet.None
} }
func (t *TCPSource) InitSCRAMManager(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type != "scram" || authCfg.ScramAuth == nil {
return
}
t.scramManager = scram.NewScramManager()
// Load users from SCRAM config
for _, user := range authCfg.ScramAuth.Users {
storedKey, _ := base64.StdEncoding.DecodeString(user.StoredKey)
serverKey, _ := base64.StdEncoding.DecodeString(user.ServerKey)
salt, _ := base64.StdEncoding.DecodeString(user.Salt)
cred := &scram.Credential{
Username: user.Username,
StoredKey: storedKey,
ServerKey: serverKey,
Salt: salt,
ArgonTime: user.ArgonTime,
ArgonMemory: user.ArgonMemory,
ArgonThreads: user.ArgonThreads,
}
t.scramManager.AddCredential(cred)
}
t.logger.Info("msg", "SCRAM authentication configured",
"component", "tcp_source",
"users", len(authCfg.ScramAuth.Users))
}
// Configure TCP source auth
func (t *TCPSource) SetAuth(authCfg *config.AuthConfig) {
if authCfg == nil || authCfg.Type == "none" {
return
}
// Initialize SCRAM manager
if authCfg.Type == "scram" {
t.InitSCRAMManager(authCfg)
t.logger.Info("msg", "SCRAM authentication configured for TCP source",
"component", "tcp_source")
}
}