v0.7.0 major configuration and sub-command restructuring, not tested, docs and default config outdated
This commit is contained in:
13
go.mod
13
go.mod
@ -3,14 +3,12 @@ module logwisp
|
||||
go 1.25.1
|
||||
|
||||
require (
|
||||
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3
|
||||
github.com/lixenwraith/config v0.0.0-20251003140149-580459b815f6
|
||||
github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2
|
||||
github.com/panjf2000/gnet/v2 v2.9.4
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/valyala/fasthttp v1.66.0
|
||||
golang.org/x/crypto v0.42.0
|
||||
golang.org/x/term v0.35.0
|
||||
golang.org/x/time v0.13.0
|
||||
github.com/valyala/fasthttp v1.67.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/term v0.36.0
|
||||
)
|
||||
|
||||
require (
|
||||
@ -20,12 +18,11 @@ require (
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/panjf2000/ants/v2 v2.11.3 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.uber.org/zap v1.27.0 // indirect
|
||||
golang.org/x/sync v0.17.0 // indirect
|
||||
golang.org/x/sys v0.36.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
22
go.sum
22
go.sum
@ -8,8 +8,8 @@ github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGL
|
||||
github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3 h1:+RwUb7dUz9mGdUSW+E0WuqJgTVg1yFnPb94Wyf5ma/0=
|
||||
github.com/lixenwraith/config v0.0.0-20250908085506-537a4d49d2c3/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0=
|
||||
github.com/lixenwraith/config v0.0.0-20251003140149-580459b815f6 h1:G9qP8biXBT6bwBOjEe1tZwjA0gPuB5DC+fLBRXDNXqo=
|
||||
github.com/lixenwraith/config v0.0.0-20251003140149-580459b815f6/go.mod h1:I7ddNPT8MouXXz/ae4DQfBKMq5EisxdDLRX0C7Dv4O0=
|
||||
github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2 h1:9Qf+BR83sKjok2E1Nct+3Sfzoj2dLGwC/zyQDVNmmqs=
|
||||
github.com/lixenwraith/log v0.0.0-20250929145347-45cc8a5099c2/go.mod h1:E7REMCVTr6DerzDtd2tpEEaZ9R9nduyAIKQFOqHqKr0=
|
||||
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
|
||||
@ -22,8 +22,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.66.0 h1:M87A0Z7EayeyNaV6pfO3tUTUiYO0dZfEJnRGXTVNuyU=
|
||||
github.com/valyala/fasthttp v1.66.0/go.mod h1:Y4eC+zwoocmXSVCB1JmhNbYtS7tZPRI2ztPB72EVObs=
|
||||
github.com/valyala/fasthttp v1.67.0 h1:tqKlJMUP6iuNG8hGjK/s9J4kadH7HLV4ijEcPGsezac=
|
||||
github.com/valyala/fasthttp v1.67.0/go.mod h1:qYSIpqt/0XNmShgo/8Aq8E3UYWVVwNS2QYmzd8WIEPM=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
@ -32,16 +32,14 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
|
||||
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
|
||||
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
|
||||
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
|
||||
@ -24,7 +24,7 @@ func bootstrapService(ctx context.Context, cfg *config.Config) (*service.Service
|
||||
logger.Info("msg", "Initializing pipeline", "pipeline", pipelineCfg.Name)
|
||||
|
||||
// Create the pipeline
|
||||
if err := svc.NewPipeline(pipelineCfg); err != nil {
|
||||
if err := svc.NewPipeline(&pipelineCfg); err != nil {
|
||||
logger.Error("msg", "Failed to create pipeline",
|
||||
"pipeline", pipelineCfg.Name,
|
||||
"error", err)
|
||||
|
||||
@ -1,138 +0,0 @@
|
||||
// FILE: src/cmd/logwisp/commands.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/tls"
|
||||
"logwisp/src/internal/version"
|
||||
)
|
||||
|
||||
// Handles subcommand routing before main app initialization
|
||||
type CommandRouter struct {
|
||||
commands map[string]CommandHandler
|
||||
}
|
||||
|
||||
// Defines the interface for subcommands
|
||||
type CommandHandler interface {
|
||||
Execute(args []string) error
|
||||
Description() string
|
||||
}
|
||||
|
||||
// Creates and initializes the command router
|
||||
func NewCommandRouter() *CommandRouter {
|
||||
router := &CommandRouter{
|
||||
commands: make(map[string]CommandHandler),
|
||||
}
|
||||
|
||||
// Register available commands
|
||||
router.commands["auth"] = &authCommand{}
|
||||
router.commands["version"] = &versionCommand{}
|
||||
router.commands["help"] = &helpCommand{}
|
||||
router.commands["tls"] = &tlsCommand{}
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// Checks for and executes subcommands
|
||||
func (r *CommandRouter) Route(args []string) error {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for help flags anywhere in args
|
||||
for _, arg := range args[1:] { // Skip program name
|
||||
if arg == "-h" || arg == "--help" || arg == "help" {
|
||||
// Show main help and exit regardless of other flags
|
||||
r.commands["help"].Execute(nil)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for commands
|
||||
if len(args) > 1 {
|
||||
cmdName := args[1]
|
||||
|
||||
if handler, exists := r.commands[cmdName]; exists {
|
||||
if err := handler.Execute(args[2:]); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Check if it looks like a mistyped command (not a flag)
|
||||
if cmdName[0] != '-' {
|
||||
fmt.Fprintf(os.Stderr, "Unknown command: %s\n", cmdName)
|
||||
fmt.Fprintln(os.Stderr, "\nAvailable commands:")
|
||||
r.ShowCommands()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Displays available subcommands
|
||||
func (r *CommandRouter) ShowCommands() {
|
||||
fmt.Fprintln(os.Stderr, " auth Generate authentication credentials")
|
||||
fmt.Fprintln(os.Stderr, " tls Generate TLS certificates")
|
||||
fmt.Fprintln(os.Stderr, " version Show version information")
|
||||
fmt.Fprintln(os.Stderr, " help Display help information")
|
||||
fmt.Fprintln(os.Stderr, "\nUse 'logwisp <command> --help' for command-specific help")
|
||||
}
|
||||
|
||||
// TODO: Future: refactor with a new command interface
|
||||
type helpCommand struct{}
|
||||
|
||||
func (c *helpCommand) Execute(args []string) error {
|
||||
// Check if help is requested for a specific command
|
||||
if len(args) > 0 {
|
||||
// TODO: Future: show command-specific help
|
||||
// For now, just show general help
|
||||
}
|
||||
fmt.Print(helpText)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *helpCommand) Description() string {
|
||||
return "Display help information"
|
||||
}
|
||||
|
||||
// authCommand wrapper
|
||||
type authCommand struct{}
|
||||
|
||||
func (c *authCommand) Execute(args []string) error {
|
||||
gen := auth.NewAuthGeneratorCommand()
|
||||
return gen.Execute(args)
|
||||
}
|
||||
|
||||
func (c *authCommand) Description() string {
|
||||
return "Generate authentication credentials (passwords, tokens)"
|
||||
}
|
||||
|
||||
// versionCommand wrapper
|
||||
type versionCommand struct{}
|
||||
|
||||
func (c *versionCommand) Execute(args []string) error {
|
||||
fmt.Println(version.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *versionCommand) Description() string {
|
||||
return "Show version information"
|
||||
}
|
||||
|
||||
// tlsCommand wrapper
|
||||
type tlsCommand struct{}
|
||||
|
||||
func (c *tlsCommand) Execute(args []string) error {
|
||||
gen := tls.NewCertGeneratorCommand()
|
||||
return gen.Execute(args)
|
||||
}
|
||||
|
||||
func (c *tlsCommand) Description() string {
|
||||
return "Generate TLS certificates (CA, server, client)"
|
||||
}
|
||||
355
src/cmd/logwisp/commands/auth.go
Normal file
355
src/cmd/logwisp/commands/auth.go
Normal file
@ -0,0 +1,355 @@
|
||||
// FILE: src/cmd/logwisp/commands/auth.go
|
||||
package commands
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"logwisp/src/internal/auth"
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
type AuthCommand struct {
|
||||
output io.Writer
|
||||
errOut io.Writer
|
||||
}
|
||||
|
||||
func NewAuthCommand() *AuthCommand {
|
||||
return &AuthCommand{
|
||||
output: os.Stdout,
|
||||
errOut: os.Stderr,
|
||||
}
|
||||
}
|
||||
|
||||
func (ac *AuthCommand) Execute(args []string) error {
|
||||
cmd := flag.NewFlagSet("auth", flag.ContinueOnError)
|
||||
cmd.SetOutput(ac.errOut)
|
||||
|
||||
var (
|
||||
// User credentials
|
||||
username = cmd.String("u", "", "Username")
|
||||
usernameLong = cmd.String("user", "", "Username")
|
||||
password = cmd.String("p", "", "Password (will prompt if not provided)")
|
||||
passwordLong = cmd.String("password", "", "Password (will prompt if not provided)")
|
||||
|
||||
// Auth type selection (multiple ways to specify)
|
||||
authType = cmd.String("t", "", "Auth type: basic, scram, or token")
|
||||
authTypeLong = cmd.String("type", "", "Auth type: basic, scram, or token")
|
||||
useScram = cmd.Bool("s", false, "Generate SCRAM credentials (TCP)")
|
||||
useScramLong = cmd.Bool("scram", false, "Generate SCRAM credentials (TCP)")
|
||||
useBasic = cmd.Bool("b", false, "Generate basic auth credentials (HTTP)")
|
||||
useBasicLong = cmd.Bool("basic", false, "Generate basic auth credentials (HTTP)")
|
||||
|
||||
// Token generation
|
||||
genToken = cmd.Bool("k", false, "Generate random bearer token")
|
||||
genTokenLong = cmd.Bool("token", false, "Generate random bearer token")
|
||||
tokenLen = cmd.Int("l", 32, "Token length in bytes")
|
||||
tokenLenLong = cmd.Int("length", 32, "Token length in bytes")
|
||||
|
||||
// Migration option
|
||||
migrate = cmd.Bool("m", false, "Convert basic auth PHC to SCRAM")
|
||||
migrateLong = cmd.Bool("migrate", false, "Convert basic auth PHC to SCRAM")
|
||||
phcHash = cmd.String("phc", "", "PHC hash to migrate (required with --migrate)")
|
||||
)
|
||||
|
||||
cmd.Usage = func() {
|
||||
fmt.Fprintln(ac.errOut, "Generate authentication credentials for LogWisp")
|
||||
fmt.Fprintln(ac.errOut, "\nUsage: logwisp auth [options]")
|
||||
fmt.Fprintln(ac.errOut, "\nExamples:")
|
||||
fmt.Fprintln(ac.errOut, " # Generate basic auth hash for HTTP sources/sinks")
|
||||
fmt.Fprintln(ac.errOut, " logwisp auth -u admin -b")
|
||||
fmt.Fprintln(ac.errOut, " logwisp auth --user=admin --basic")
|
||||
fmt.Fprintln(ac.errOut, " ")
|
||||
fmt.Fprintln(ac.errOut, " # Generate SCRAM credentials for TCP")
|
||||
fmt.Fprintln(ac.errOut, " logwisp auth -u tcpuser -s")
|
||||
fmt.Fprintln(ac.errOut, " logwisp auth --user=tcpuser --scram")
|
||||
fmt.Fprintln(ac.errOut, " ")
|
||||
fmt.Fprintln(ac.errOut, " # Generate bearer token")
|
||||
fmt.Fprintln(ac.errOut, " logwisp auth -k -l 64")
|
||||
fmt.Fprintln(ac.errOut, " logwisp auth --token --length=64")
|
||||
fmt.Fprintln(ac.errOut, "\nOptions:")
|
||||
cmd.PrintDefaults()
|
||||
}
|
||||
|
||||
if err := cmd.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for unparsed arguments
|
||||
if cmd.NArg() > 0 {
|
||||
return fmt.Errorf("unexpected argument(s): %s", strings.Join(cmd.Args(), " "))
|
||||
}
|
||||
|
||||
// Merge short and long form values
|
||||
finalUsername := coalesceString(*username, *usernameLong)
|
||||
finalPassword := coalesceString(*password, *passwordLong)
|
||||
finalAuthType := coalesceString(*authType, *authTypeLong)
|
||||
finalGenToken := coalesceBool(*genToken, *genTokenLong)
|
||||
finalTokenLen := coalesceInt(*tokenLen, *tokenLenLong, core.DefaultTokenLength)
|
||||
finalUseScram := coalesceBool(*useScram, *useScramLong)
|
||||
finalUseBasic := coalesceBool(*useBasic, *useBasicLong)
|
||||
finalMigrate := coalesceBool(*migrate, *migrateLong)
|
||||
|
||||
// Handle migration mode
|
||||
if finalMigrate {
|
||||
if *phcHash == "" || finalUsername == "" || finalPassword == "" {
|
||||
return fmt.Errorf("--migrate requires --user, --password, and --phc flags")
|
||||
}
|
||||
return ac.migrateToScram(finalUsername, finalPassword, *phcHash)
|
||||
}
|
||||
|
||||
// Determine auth type from flags
|
||||
if finalGenToken || finalAuthType == "token" {
|
||||
return ac.generateToken(finalTokenLen)
|
||||
}
|
||||
|
||||
// Determine credential type
|
||||
credType := "basic" // default
|
||||
|
||||
// Check explicit type flags
|
||||
if finalUseScram || finalAuthType == "scram" {
|
||||
credType = "scram"
|
||||
} else if finalUseBasic || finalAuthType == "basic" {
|
||||
credType = "basic"
|
||||
} else if finalAuthType != "" {
|
||||
return fmt.Errorf("invalid auth type: %s (valid: basic, scram, token)", finalAuthType)
|
||||
}
|
||||
|
||||
// Username required for password-based auth
|
||||
if finalUsername == "" {
|
||||
cmd.Usage()
|
||||
return fmt.Errorf("username required for %s auth generation", credType)
|
||||
}
|
||||
|
||||
return ac.generatePasswordHash(finalUsername, finalPassword, credType)
|
||||
}
|
||||
|
||||
func (ac *AuthCommand) Description() string {
|
||||
return "Generate authentication credentials (passwords, tokens, SCRAM)"
|
||||
}
|
||||
|
||||
func (ac *AuthCommand) Help() string {
|
||||
return `Auth Command - Generate authentication credentials for LogWisp
|
||||
|
||||
Usage:
|
||||
logwisp auth [options]
|
||||
|
||||
Authentication Types:
|
||||
HTTP/HTTPS Sources & Sinks (TLS required):
|
||||
- Basic Auth: Username/password with Argon2id hashing
|
||||
- Bearer Token: Random cryptographic tokens
|
||||
|
||||
TCP Sources & Sinks (No TLS):
|
||||
- SCRAM: Argon2-SCRAM-SHA256 for plaintext connections
|
||||
|
||||
Options:
|
||||
-u, --user <name> Username for credential generation
|
||||
-p, --password <pass> Password (will prompt if not provided)
|
||||
-t, --type <type> Auth type: "basic", "scram", or "token"
|
||||
-b, --basic Generate basic auth credentials (HTTP/HTTPS)
|
||||
-s, --scram Generate SCRAM credentials (TCP)
|
||||
-k, --token Generate random bearer token
|
||||
-l, --length <bytes> Token length in bytes (default: 32)
|
||||
|
||||
Examples:
|
||||
Examples:
|
||||
# Generate basic auth hash for HTTP/HTTPS (with TLS)
|
||||
logwisp auth -u admin -b
|
||||
logwisp auth --user=admin --basic
|
||||
|
||||
# Generate SCRAM credentials for TCP (without TLS)
|
||||
logwisp auth -u tcpuser -s
|
||||
logwisp auth --user=tcpuser --type=scram
|
||||
|
||||
# Generate 64-byte bearer token
|
||||
logwisp auth -k -l 64
|
||||
logwisp auth --token --length=64
|
||||
|
||||
# Convert existing basic auth to SCRAM (HTTPS to TCP conversion)
|
||||
logwisp auth -u admin -m --phc='$argon2id$v=19$m=65536...' --password='secret'
|
||||
|
||||
Output:
|
||||
The command outputs configuration snippets ready to paste into logwisp.toml
|
||||
and the raw credential values for external auth files.
|
||||
|
||||
Security Notes:
|
||||
- Basic auth and tokens require TLS encryption for HTTP connections
|
||||
- SCRAM provides authentication but NOT encryption for TCP connections
|
||||
- Use strong passwords (12+ characters with mixed case, numbers, symbols)
|
||||
- Store credentials securely and never commit them to version control
|
||||
`
|
||||
}
|
||||
|
||||
func (ac *AuthCommand) generatePasswordHash(username, password, credType string) error {
|
||||
// Get password if not provided
|
||||
if password == "" {
|
||||
var err error
|
||||
password, err = ac.promptForPassword()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
switch credType {
|
||||
case "basic":
|
||||
return ac.generateBasicAuth(username, password)
|
||||
case "scram":
|
||||
return ac.generateScramAuth(username, password)
|
||||
default:
|
||||
return fmt.Errorf("invalid credential type: %s", credType)
|
||||
}
|
||||
}
|
||||
|
||||
// promptForPassword handles password prompting with confirmation
|
||||
func (ac *AuthCommand) promptForPassword() (string, error) {
|
||||
pass1 := ac.promptPassword("Enter password: ")
|
||||
pass2 := ac.promptPassword("Confirm password: ")
|
||||
if pass1 != pass2 {
|
||||
return "", fmt.Errorf("passwords don't match")
|
||||
}
|
||||
return pass1, nil
|
||||
}
|
||||
|
||||
func (ac *AuthCommand) promptPassword(prompt string) string {
|
||||
fmt.Fprint(ac.errOut, prompt)
|
||||
password, err := term.ReadPassword(syscall.Stdin)
|
||||
fmt.Fprintln(ac.errOut)
|
||||
if err != nil {
|
||||
fmt.Fprintf(ac.errOut, "Failed to read password: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return string(password)
|
||||
}
|
||||
|
||||
// generateBasicAuth creates Argon2id hash for HTTP basic auth
|
||||
func (ac *AuthCommand) generateBasicAuth(username, password string) error {
|
||||
// Generate salt
|
||||
salt := make([]byte, core.Argon2SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Generate Argon2id hash
|
||||
cred, err := auth.DeriveCredential(username, password, salt,
|
||||
core.Argon2Time, core.Argon2Memory, core.Argon2Threads)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to derive credential: %w", err)
|
||||
}
|
||||
|
||||
// Output configuration snippets
|
||||
fmt.Fprintln(ac.output, "\n# Basic Auth Configuration (HTTP sources/sinks)")
|
||||
fmt.Fprintln(ac.output, "# REQUIRES HTTPS/TLS for security")
|
||||
fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:")
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[pipelines.auth]")
|
||||
fmt.Fprintln(ac.output, `type = "basic"`)
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[[pipelines.auth.basic_auth.users]]")
|
||||
fmt.Fprintf(ac.output, "username = %q\n", username)
|
||||
fmt.Fprintf(ac.output, "password_hash = %q\n\n", cred.PHCHash)
|
||||
|
||||
fmt.Fprintln(ac.output, "# For external users file:")
|
||||
fmt.Fprintf(ac.output, "%s:%s\n", username, cred.PHCHash)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateScramAuth creates Argon2id-SCRAM-SHA256 credentials for TCP
|
||||
func (ac *AuthCommand) generateScramAuth(username, password string) error {
|
||||
// Generate salt
|
||||
salt := make([]byte, core.Argon2SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Use internal auth package to derive SCRAM credentials
|
||||
cred, err := auth.DeriveCredential(username, password, salt,
|
||||
core.Argon2Time, core.Argon2Memory, core.Argon2Threads)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to derive SCRAM credential: %w", err)
|
||||
}
|
||||
|
||||
// Output SCRAM configuration
|
||||
fmt.Fprintln(ac.output, "\n# SCRAM Auth Configuration (TCP sources/sinks)")
|
||||
fmt.Fprintln(ac.output, "# Provides authentication but NOT encryption")
|
||||
fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:")
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[pipelines.auth]")
|
||||
fmt.Fprintln(ac.output, `type = "scram"`)
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[[pipelines.auth.scram_auth.users]]")
|
||||
fmt.Fprintf(ac.output, "username = %q\n", username)
|
||||
fmt.Fprintf(ac.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey))
|
||||
fmt.Fprintf(ac.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey))
|
||||
fmt.Fprintf(ac.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt))
|
||||
fmt.Fprintf(ac.output, "argon_time = %d\n", cred.ArgonTime)
|
||||
fmt.Fprintf(ac.output, "argon_memory = %d\n", cred.ArgonMemory)
|
||||
fmt.Fprintf(ac.output, "argon_threads = %d\n\n", cred.ArgonThreads)
|
||||
|
||||
fmt.Fprintln(ac.output, "# Note: SCRAM provides authentication only.")
|
||||
fmt.Fprintln(ac.output, "# Use TLS/mTLS for encryption if needed.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ac *AuthCommand) generateToken(length int) error {
|
||||
if length < 16 {
|
||||
fmt.Fprintln(ac.errOut, "Warning: tokens < 16 bytes are cryptographically weak")
|
||||
}
|
||||
if length > 512 {
|
||||
return fmt.Errorf("token length exceeds maximum (512 bytes)")
|
||||
}
|
||||
|
||||
token := make([]byte, length)
|
||||
if _, err := rand.Read(token); err != nil {
|
||||
return fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token)
|
||||
hex := fmt.Sprintf("%x", token)
|
||||
|
||||
fmt.Fprintln(ac.output, "\n# Token Configuration")
|
||||
fmt.Fprintln(ac.output, "# Add to logwisp.toml:")
|
||||
fmt.Fprintf(ac.output, "tokens = [%q]\n\n", b64)
|
||||
|
||||
fmt.Fprintln(ac.output, "# Generated Token:")
|
||||
fmt.Fprintf(ac.output, "Base64: %s\n", b64)
|
||||
fmt.Fprintf(ac.output, "Hex: %s\n", hex)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateToScram converts basic auth PHC hash to SCRAM credentials
|
||||
func (ac *AuthCommand) migrateToScram(username, password, phcHash string) error {
|
||||
// CHANGED: Moved from internal/auth to CLI command layer
|
||||
cred, err := auth.MigrateFromPHC(username, password, phcHash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
|
||||
// Output SCRAM configuration (reuse format from generateScramAuth)
|
||||
fmt.Fprintln(ac.output, "\n# Migrated SCRAM Credentials")
|
||||
fmt.Fprintln(ac.output, "# Add to logwisp.toml under [[pipelines]]:")
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[pipelines.auth]")
|
||||
fmt.Fprintln(ac.output, `type = "scram"`)
|
||||
fmt.Fprintln(ac.output, "")
|
||||
fmt.Fprintln(ac.output, "[[pipelines.auth.scram_auth.users]]")
|
||||
fmt.Fprintf(ac.output, "username = %q\n", username)
|
||||
fmt.Fprintf(ac.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey))
|
||||
fmt.Fprintf(ac.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey))
|
||||
fmt.Fprintf(ac.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt))
|
||||
fmt.Fprintf(ac.output, "argon_time = %d\n", cred.ArgonTime)
|
||||
fmt.Fprintf(ac.output, "argon_memory = %d\n", cred.ArgonMemory)
|
||||
fmt.Fprintf(ac.output, "argon_threads = %d\n", cred.ArgonThreads)
|
||||
|
||||
return nil
|
||||
}
|
||||
123
src/cmd/logwisp/commands/help.go
Normal file
123
src/cmd/logwisp/commands/help.go
Normal file
@ -0,0 +1,123 @@
|
||||
// FILE: src/cmd/logwisp/commands/help.go
|
||||
package commands
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const generalHelpTemplate = `LogWisp: A flexible log transport and processing tool.
|
||||
|
||||
Usage:
|
||||
logwisp [command] [options]
|
||||
logwisp [options]
|
||||
|
||||
Commands:
|
||||
%s
|
||||
|
||||
Application Options:
|
||||
-c, --config <path> Path to configuration file (default: logwisp.toml)
|
||||
-h, --help Display this help message and exit
|
||||
-v, --version Display version information and exit
|
||||
-b, --background Run LogWisp in the background as a daemon
|
||||
-q, --quiet Suppress all console output, including errors
|
||||
|
||||
Runtime Options:
|
||||
--disable-status-reporter Disable the periodic status reporter
|
||||
--config-auto-reload Enable config reload on file change
|
||||
|
||||
For command-specific help:
|
||||
logwisp help <command>
|
||||
logwisp <command> --help
|
||||
|
||||
Configuration Sources (Precedence: CLI > Env > File > Defaults):
|
||||
- CLI flags override all other settings
|
||||
- Environment variables override file settings
|
||||
- TOML configuration file is the primary method
|
||||
|
||||
Examples:
|
||||
# Generate password for admin user
|
||||
logwisp auth -u admin
|
||||
|
||||
# Start service with custom config
|
||||
logwisp -c /etc/logwisp/prod.toml
|
||||
|
||||
# Run in background with config reload
|
||||
logwisp -b --config-auto-reload
|
||||
|
||||
For detailed configuration options, please refer to the documentation.
|
||||
`
|
||||
|
||||
// HelpCommand handles help display
|
||||
type HelpCommand struct {
|
||||
router *CommandRouter
|
||||
}
|
||||
|
||||
// NewHelpCommand creates a new help command
|
||||
func NewHelpCommand(router *CommandRouter) *HelpCommand {
|
||||
return &HelpCommand{router: router}
|
||||
}
|
||||
|
||||
// Execute displays help information
|
||||
func (c *HelpCommand) Execute(args []string) error {
|
||||
// Check if help is requested for a specific command
|
||||
if len(args) > 0 && args[0] != "" {
|
||||
cmdName := args[0]
|
||||
|
||||
if handler, exists := c.router.GetCommand(cmdName); exists {
|
||||
fmt.Print(handler.Help())
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("unknown command: %s", cmdName)
|
||||
}
|
||||
|
||||
// Display general help with command list
|
||||
fmt.Printf(generalHelpTemplate, c.formatCommandList())
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatCommandList creates a formatted list of available commands
|
||||
func (c *HelpCommand) formatCommandList() string {
|
||||
commands := c.router.GetCommands()
|
||||
|
||||
// Sort command names for consistent output
|
||||
names := make([]string, 0, len(commands))
|
||||
maxLen := 0
|
||||
for name := range commands {
|
||||
names = append(names, name)
|
||||
if len(name) > maxLen {
|
||||
maxLen = len(name)
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
// Format each command with aligned descriptions
|
||||
var lines []string
|
||||
for _, name := range names {
|
||||
handler := commands[name]
|
||||
padding := strings.Repeat(" ", maxLen-len(name)+2)
|
||||
lines = append(lines, fmt.Sprintf(" %s%s%s", name, padding, handler.Description()))
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (c *HelpCommand) Description() string {
|
||||
return "Display help information"
|
||||
}
|
||||
|
||||
func (c *HelpCommand) Help() string {
|
||||
return `Help Command - Display help information
|
||||
|
||||
Usage:
|
||||
logwisp help Show general help
|
||||
logwisp help <command> Show help for a specific command
|
||||
|
||||
Examples:
|
||||
logwisp help # Show general help
|
||||
logwisp help auth # Show auth command help
|
||||
logwisp auth --help # Alternative way to get command help
|
||||
`
|
||||
}
|
||||
118
src/cmd/logwisp/commands/router.go
Normal file
118
src/cmd/logwisp/commands/router.go
Normal file
@ -0,0 +1,118 @@
|
||||
// FILE: src/cmd/logwisp/commands/router.go
|
||||
package commands
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Handler defines the interface for subcommands
|
||||
type Handler interface {
|
||||
Execute(args []string) error
|
||||
Description() string
|
||||
Help() string
|
||||
}
|
||||
|
||||
// CommandRouter handles subcommand routing before main app initialization
|
||||
type CommandRouter struct {
|
||||
commands map[string]Handler
|
||||
}
|
||||
|
||||
// NewCommandRouter creates and initializes the command router
|
||||
func NewCommandRouter() *CommandRouter {
|
||||
router := &CommandRouter{
|
||||
commands: make(map[string]Handler),
|
||||
}
|
||||
|
||||
// Register available commands
|
||||
router.commands["auth"] = NewAuthCommand()
|
||||
router.commands["tls"] = NewTLSCommand()
|
||||
router.commands["version"] = NewVersionCommand()
|
||||
router.commands["help"] = NewHelpCommand(router)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// Route checks for and executes subcommands
|
||||
func (r *CommandRouter) Route(args []string) (bool, error) {
|
||||
if len(args) < 2 {
|
||||
return false, nil // No command specified, let main app continue
|
||||
}
|
||||
|
||||
cmdName := args[1]
|
||||
|
||||
// Special case: help flag at any position shows general help
|
||||
for _, arg := range args[1:] {
|
||||
if arg == "-h" || arg == "--help" {
|
||||
// If it's after a valid command, show command-specific help
|
||||
if handler, exists := r.commands[cmdName]; exists && cmdName != "help" {
|
||||
fmt.Print(handler.Help())
|
||||
return true, nil
|
||||
}
|
||||
// Otherwise show general help
|
||||
return true, r.commands["help"].Execute(nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a known command
|
||||
handler, exists := r.commands[cmdName]
|
||||
if !exists {
|
||||
// Check if it looks like a mistyped command (not a flag)
|
||||
if cmdName[0] != '-' {
|
||||
return false, fmt.Errorf("unknown command: %s\n\nRun 'logwisp help' for usage", cmdName)
|
||||
}
|
||||
// It's a flag, let main app handle it
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Execute the command
|
||||
return true, handler.Execute(args[2:])
|
||||
}
|
||||
|
||||
// GetCommand returns a command handler by name
|
||||
func (r *CommandRouter) GetCommand(name string) (Handler, bool) {
|
||||
cmd, exists := r.commands[name]
|
||||
return cmd, exists
|
||||
}
|
||||
|
||||
// GetCommands returns all registered commands
|
||||
func (r *CommandRouter) GetCommands() map[string]Handler {
|
||||
return r.commands
|
||||
}
|
||||
|
||||
// ShowCommands displays available subcommands
|
||||
func (r *CommandRouter) ShowCommands() {
|
||||
for name, handler := range r.commands {
|
||||
fmt.Fprintf(os.Stderr, " %-10s %s\n", name, handler.Description())
|
||||
}
|
||||
fmt.Fprintln(os.Stderr, "\nUse 'logwisp <command> --help' for command-specific help")
|
||||
}
|
||||
|
||||
// Helper functions to merge short and long options
|
||||
func coalesceString(values ...string) string {
|
||||
for _, v := range values {
|
||||
if v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func coalesceInt(primary, secondary, defaultVal int) int {
|
||||
if primary != defaultVal {
|
||||
return primary
|
||||
}
|
||||
if secondary != defaultVal {
|
||||
return secondary
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
func coalesceBool(values ...bool) bool {
|
||||
for _, v := range values {
|
||||
if v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@ -1,5 +1,5 @@
|
||||
// FILE: src/internal/tls/generator.go
|
||||
package tls
|
||||
// FILE: src/cmd/logwisp/commands/tls.go
|
||||
package commands
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@ -17,40 +17,50 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type CertGeneratorCommand struct {
|
||||
type TLSCommand struct {
|
||||
output io.Writer
|
||||
errOut io.Writer
|
||||
}
|
||||
|
||||
func NewCertGeneratorCommand() *CertGeneratorCommand {
|
||||
return &CertGeneratorCommand{
|
||||
func NewTLSCommand() *TLSCommand {
|
||||
return &TLSCommand{
|
||||
output: os.Stdout,
|
||||
errOut: os.Stderr,
|
||||
}
|
||||
}
|
||||
|
||||
func (cg *CertGeneratorCommand) Execute(args []string) error {
|
||||
func (tc *TLSCommand) Execute(args []string) error {
|
||||
cmd := flag.NewFlagSet("tls", flag.ContinueOnError)
|
||||
cmd.SetOutput(cg.errOut)
|
||||
cmd.SetOutput(tc.errOut)
|
||||
|
||||
// Subcommands
|
||||
// Certificate type flags
|
||||
var (
|
||||
genCA = cmd.Bool("ca", false, "Generate CA certificate")
|
||||
genServer = cmd.Bool("server", false, "Generate server certificate")
|
||||
genClient = cmd.Bool("client", false, "Generate client certificate")
|
||||
selfSign = cmd.Bool("self-signed", false, "Generate self-signed certificate")
|
||||
|
||||
// Common options
|
||||
// Common options - short forms
|
||||
commonName = cmd.String("cn", "", "Common name (required)")
|
||||
org = cmd.String("org", "LogWisp", "Organization")
|
||||
country = cmd.String("country", "US", "Country code")
|
||||
validDays = cmd.Int("days", 365, "Validity period in days")
|
||||
keySize = cmd.Int("bits", 2048, "RSA key size")
|
||||
org = cmd.String("o", "LogWisp", "Organization")
|
||||
country = cmd.String("c", "US", "Country code")
|
||||
validDays = cmd.Int("d", 365, "Validity period in days")
|
||||
keySize = cmd.Int("b", 2048, "RSA key size")
|
||||
|
||||
// Server/Client specific
|
||||
hosts = cmd.String("hosts", "", "Comma-separated hostnames/IPs (server cert)")
|
||||
caFile = cmd.String("ca-cert", "", "CA certificate file (for signing)")
|
||||
caKeyFile = cmd.String("ca-key", "", "CA key file (for signing)")
|
||||
// Common options - long forms
|
||||
commonNameLong = cmd.String("common-name", "", "Common name (required)")
|
||||
orgLong = cmd.String("org", "LogWisp", "Organization")
|
||||
countryLong = cmd.String("country", "US", "Country code")
|
||||
validDaysLong = cmd.Int("days", 365, "Validity period in days")
|
||||
keySizeLong = cmd.Int("bits", 2048, "RSA key size")
|
||||
|
||||
// Server/Client specific - short forms
|
||||
hosts = cmd.String("h", "", "Comma-separated hostnames/IPs")
|
||||
caFile = cmd.String("ca-cert", "", "CA certificate file")
|
||||
caKey = cmd.String("ca-key", "", "CA key file")
|
||||
|
||||
// Server/Client specific - long forms
|
||||
hostsLong = cmd.String("hosts", "", "Comma-separated hostnames/IPs")
|
||||
|
||||
// Output files
|
||||
certOut = cmd.String("cert-out", "", "Output certificate file")
|
||||
@ -58,51 +68,135 @@ func (cg *CertGeneratorCommand) Execute(args []string) error {
|
||||
)
|
||||
|
||||
cmd.Usage = func() {
|
||||
fmt.Fprintln(cg.errOut, "Generate TLS certificates for LogWisp")
|
||||
fmt.Fprintln(cg.errOut, "\nUsage: logwisp tls [options]")
|
||||
fmt.Fprintln(cg.errOut, "\nExamples:")
|
||||
fmt.Fprintln(cg.errOut, " # Generate self-signed certificate")
|
||||
fmt.Fprintln(cg.errOut, " logwisp tls --self-signed --cn localhost --hosts localhost,127.0.0.1")
|
||||
fmt.Fprintln(cg.errOut, " ")
|
||||
fmt.Fprintln(cg.errOut, " # Generate CA certificate")
|
||||
fmt.Fprintln(cg.errOut, " logwisp tls --ca --cn \"LogWisp CA\" --cert-out ca.crt --key-out ca.key")
|
||||
fmt.Fprintln(cg.errOut, " ")
|
||||
fmt.Fprintln(cg.errOut, " # Generate server certificate signed by CA")
|
||||
fmt.Fprintln(cg.errOut, " logwisp tls --server --cn server.example.com --hosts server.example.com \\")
|
||||
fmt.Fprintln(cg.errOut, " --ca-cert ca.crt --ca-key ca.key")
|
||||
fmt.Fprintln(cg.errOut, "\nOptions:")
|
||||
fmt.Fprintln(tc.errOut, "Generate TLS certificates for LogWisp")
|
||||
fmt.Fprintln(tc.errOut, "\nUsage: logwisp tls [options]")
|
||||
fmt.Fprintln(tc.errOut, "\nExamples:")
|
||||
fmt.Fprintln(tc.errOut, " # Generate self-signed certificate")
|
||||
fmt.Fprintln(tc.errOut, " logwisp tls --self-signed --cn localhost --hosts localhost,127.0.0.1")
|
||||
fmt.Fprintln(tc.errOut, " ")
|
||||
fmt.Fprintln(tc.errOut, " # Generate CA certificate")
|
||||
fmt.Fprintln(tc.errOut, " logwisp tls --ca --cn \"LogWisp CA\" --cert-out ca.crt --key-out ca.key")
|
||||
fmt.Fprintln(tc.errOut, " ")
|
||||
fmt.Fprintln(tc.errOut, " # Generate server certificate signed by CA")
|
||||
fmt.Fprintln(tc.errOut, " logwisp tls --server --cn server.example.com --hosts server.example.com \\")
|
||||
fmt.Fprintln(tc.errOut, " --ca-cert ca.crt --ca-key ca.key")
|
||||
fmt.Fprintln(tc.errOut, "\nOptions:")
|
||||
cmd.PrintDefaults()
|
||||
fmt.Fprintln(cg.errOut)
|
||||
fmt.Fprintln(tc.errOut)
|
||||
}
|
||||
|
||||
if err := cmd.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for unparsed arguments
|
||||
if cmd.NArg() > 0 {
|
||||
return fmt.Errorf("unexpected argument(s): %s", strings.Join(cmd.Args(), " "))
|
||||
}
|
||||
|
||||
// Merge short and long options
|
||||
finalCN := coalesceString(*commonName, *commonNameLong)
|
||||
finalOrg := coalesceString(*org, *orgLong, "LogWisp")
|
||||
finalCountry := coalesceString(*country, *countryLong, "US")
|
||||
finalDays := coalesceInt(*validDays, *validDaysLong, 365)
|
||||
finalKeySize := coalesceInt(*keySize, *keySizeLong, 2048)
|
||||
finalHosts := coalesceString(*hosts, *hostsLong)
|
||||
finalCAFile := *caFile // no short form
|
||||
finalCAKey := *caKey // no short form
|
||||
finalCertOut := *certOut // no short form
|
||||
finalKeyOut := *keyOut // no short form
|
||||
|
||||
// Validate common name
|
||||
if *commonName == "" {
|
||||
if finalCN == "" {
|
||||
cmd.Usage()
|
||||
return fmt.Errorf("common name (--cn) is required")
|
||||
}
|
||||
|
||||
// Validate RSA key size
|
||||
if finalKeySize != 2048 && finalKeySize != 3072 && finalKeySize != 4096 {
|
||||
return fmt.Errorf("invalid key size: %d (valid: 2048, 3072, 4096)", finalKeySize)
|
||||
}
|
||||
|
||||
// Route to appropriate generator
|
||||
switch {
|
||||
case *genCA:
|
||||
return cg.generateCA(*commonName, *org, *country, *validDays, *keySize, *certOut, *keyOut)
|
||||
return tc.generateCA(finalCN, finalOrg, finalCountry, finalDays, finalKeySize, finalCertOut, finalKeyOut)
|
||||
case *selfSign:
|
||||
return cg.generateSelfSigned(*commonName, *org, *country, *hosts, *validDays, *keySize, *certOut, *keyOut)
|
||||
return tc.generateSelfSigned(finalCN, finalOrg, finalCountry, finalHosts, finalDays, finalKeySize, finalCertOut, finalKeyOut)
|
||||
case *genServer:
|
||||
return cg.generateServerCert(*commonName, *org, *country, *hosts, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut)
|
||||
return tc.generateServerCert(finalCN, finalOrg, finalCountry, finalHosts, finalCAFile, finalCAKey, finalDays, finalKeySize, finalCertOut, finalKeyOut)
|
||||
case *genClient:
|
||||
return cg.generateClientCert(*commonName, *org, *country, *caFile, *caKeyFile, *validDays, *keySize, *certOut, *keyOut)
|
||||
return tc.generateClientCert(finalCN, finalOrg, finalCountry, finalCAFile, finalCAKey, finalDays, finalKeySize, finalCertOut, finalKeyOut)
|
||||
default:
|
||||
cmd.Usage()
|
||||
return fmt.Errorf("specify certificate type: --ca, --self-signed, --server, or --client")
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *TLSCommand) Description() string {
|
||||
return "Generate TLS certificates (CA, server, client, self-signed)"
|
||||
}
|
||||
|
||||
func (tc *TLSCommand) Help() string {
|
||||
return `TLS Command - Generate TLS certificates for LogWisp
|
||||
|
||||
Usage:
|
||||
logwisp tls [options]
|
||||
|
||||
Certificate Types:
|
||||
--ca Generate Certificate Authority (CA) certificate
|
||||
--server Generate server certificate (requires CA or self-signed)
|
||||
--client Generate client certificate (for mTLS)
|
||||
--self-signed Generate self-signed certificate (single cert for testing)
|
||||
|
||||
Common Options:
|
||||
--cn, --common-name <name> Common Name (required)
|
||||
-o, --org <organization> Organization name (default: "LogWisp")
|
||||
-c, --country <code> Country code (default: "US")
|
||||
-d, --days <number> Validity period in days (default: 365)
|
||||
-b, --bits <size> RSA key size (default: 2048)
|
||||
|
||||
Server Certificate Options:
|
||||
-h, --hosts <list> Comma-separated hostnames/IPs
|
||||
Example: "localhost,10.0.0.1,example.com"
|
||||
--ca-cert <file> CA certificate file (for signing)
|
||||
--ca-key <file> CA key file (for signing)
|
||||
|
||||
Output Options:
|
||||
--cert-out <file> Output certificate file (default: stdout)
|
||||
--key-out <file> Output private key file (default: stdout)
|
||||
|
||||
Examples:
|
||||
# Generate self-signed certificate for testing
|
||||
logwisp tls --self-signed --cn localhost --hosts "localhost,127.0.0.1" \
|
||||
--cert-out server.crt --key-out server.key
|
||||
|
||||
# Generate CA certificate
|
||||
logwisp tls --ca --cn "LogWisp CA" --days 3650 \
|
||||
--cert-out ca.crt --key-out ca.key
|
||||
|
||||
# Generate server certificate signed by CA
|
||||
logwisp tls --server --cn "logwisp.example.com" \
|
||||
--hosts "logwisp.example.com,10.0.0.100" \
|
||||
--ca-cert ca.crt --ca-key ca.key \
|
||||
--cert-out server.crt --key-out server.key
|
||||
|
||||
# Generate client certificate for mTLS
|
||||
logwisp tls --client --cn "client1" \
|
||||
--ca-cert ca.crt --ca-key ca.key \
|
||||
--cert-out client.crt --key-out client.key
|
||||
|
||||
Security Notes:
|
||||
- Keep private keys secure and never share them
|
||||
- Use 2048-bit RSA minimum, 3072 or 4096 for higher security
|
||||
- For production, use certificates from a trusted CA
|
||||
- Self-signed certificates are only for development/testing
|
||||
- Rotate certificates before expiration
|
||||
`
|
||||
}
|
||||
|
||||
// Create and manage private CA
|
||||
func (cg *CertGeneratorCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error {
|
||||
func (tc *TLSCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error {
|
||||
// Generate RSA key
|
||||
priv, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
if err != nil {
|
||||
@ -178,7 +272,7 @@ func parseHosts(hostList string) ([]string, []net.IP) {
|
||||
}
|
||||
|
||||
// Generate self-signed certificate
|
||||
func (cg *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error {
|
||||
func (tc *TLSCommand) generateSelfSigned(cn, org, country, hosts string, days, bits int, certFile, keyFile string) error {
|
||||
// 1. Generate an RSA private key with the specified bit size
|
||||
priv, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
if err != nil {
|
||||
@ -245,7 +339,7 @@ func (cg *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts strin
|
||||
}
|
||||
|
||||
// Generate server cert with CA
|
||||
func (cg *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error {
|
||||
func (tc *TLSCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error {
|
||||
caCert, caKey, err := loadCA(caFile, caKeyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -308,7 +402,7 @@ func (cg *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFi
|
||||
}
|
||||
|
||||
// Generate client cert with CA
|
||||
func (cg *CertGeneratorCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error {
|
||||
func (tc *TLSCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error {
|
||||
caCert, caKey, err := loadCA(caFile, caKeyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
41
src/cmd/logwisp/commands/version.go
Normal file
41
src/cmd/logwisp/commands/version.go
Normal file
@ -0,0 +1,41 @@
|
||||
// FILE: src/cmd/logwisp/commands/version.go
|
||||
package commands
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"logwisp/src/internal/version"
|
||||
)
|
||||
|
||||
// VersionCommand handles version display
|
||||
type VersionCommand struct{}
|
||||
|
||||
// NewVersionCommand creates a new version command
|
||||
func NewVersionCommand() *VersionCommand {
|
||||
return &VersionCommand{}
|
||||
}
|
||||
|
||||
func (c *VersionCommand) Execute(args []string) error {
|
||||
fmt.Println(version.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *VersionCommand) Description() string {
|
||||
return "Show version information"
|
||||
}
|
||||
|
||||
func (c *VersionCommand) Help() string {
|
||||
return `Version Command - Show LogWisp version information
|
||||
|
||||
Usage:
|
||||
logwisp version
|
||||
logwisp -v
|
||||
logwisp --version
|
||||
|
||||
Output includes:
|
||||
- Version number
|
||||
- Build date
|
||||
- Git commit hash (if available)
|
||||
- Go version used for compilation
|
||||
`
|
||||
}
|
||||
@ -1,59 +0,0 @@
|
||||
// FILE: logwisp/src/cmd/logwisp/help.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
const helpText = `LogWisp: A flexible log transport and processing tool.
|
||||
|
||||
Usage:
|
||||
logwisp [command] [options]
|
||||
logwisp [options]
|
||||
|
||||
Commands:
|
||||
auth Generate authentication credentials
|
||||
version Display version information
|
||||
|
||||
Application Control:
|
||||
-c, --config <path> Path to configuration file (default: logwisp.toml)
|
||||
-h, --help Display this help message and exit
|
||||
-v, --version Display version information and exit
|
||||
-b, --background Run LogWisp in the background as a daemon
|
||||
-q, --quiet Suppress all console output, including errors
|
||||
|
||||
Runtime Behavior:
|
||||
--disable-status-reporter Disable the periodic status reporter
|
||||
--config-auto-reload Enable config reload on file change
|
||||
|
||||
For command-specific help:
|
||||
logwisp <command> --help
|
||||
|
||||
Configuration Sources (Precedence: CLI > Env > File > Defaults):
|
||||
- CLI flags override all other settings
|
||||
- Environment variables override file settings
|
||||
- TOML configuration file is the primary method
|
||||
|
||||
Examples:
|
||||
# Generate password for admin user
|
||||
logwisp auth -u admin
|
||||
|
||||
# Start service with custom config
|
||||
logwisp -c /etc/logwisp/prod.toml
|
||||
|
||||
# Run in background
|
||||
logwisp -b --config-auto-reload
|
||||
|
||||
For detailed configuration options, please refer to the documentation.
|
||||
`
|
||||
|
||||
// Scans arguments for help flags and prints help text if found.
|
||||
func CheckAndDisplayHelp(args []string) {
|
||||
for _, arg := range args {
|
||||
if arg == "-h" || arg == "--help" {
|
||||
fmt.Fprint(os.Stdout, helpText)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -11,6 +11,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"logwisp/src/cmd/logwisp/commands"
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/version"
|
||||
|
||||
@ -22,12 +23,22 @@ var logger *log.Logger
|
||||
func main() {
|
||||
// Handle subcommands before any config loading
|
||||
// This prevents flag conflicts with lixenwraith/config
|
||||
router := NewCommandRouter()
|
||||
if router.Route(os.Args) != nil {
|
||||
// Subcommand was handled, exit already called
|
||||
return
|
||||
router := commands.NewCommandRouter()
|
||||
handled, err := router.Route(os.Args)
|
||||
|
||||
if err != nil {
|
||||
// Command execution error
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if handled {
|
||||
// Command was successfully handled
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// No subcommand, continue with main application
|
||||
|
||||
// Emulates nohup
|
||||
signal.Ignore(syscall.SIGHUP)
|
||||
|
||||
@ -158,8 +169,6 @@ func main() {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Save configuration after graceful shutdown (no reload manager in static mode)
|
||||
saveConfigurationOnExit(cfg, nil, logger)
|
||||
logger.Info("msg", "Shutdown complete")
|
||||
case <-shutdownCtx.Done():
|
||||
logger.Error("msg", "Shutdown timeout exceeded - forcing exit")
|
||||
@ -172,9 +181,6 @@ func main() {
|
||||
// Wait for context cancellation
|
||||
<-ctx.Done()
|
||||
|
||||
// Save configuration before final shutdown, handled by reloadManager
|
||||
saveConfigurationOnExit(cfg, reloadManager, logger)
|
||||
|
||||
// Shutdown is handled by ReloadManager.Shutdown() in defer
|
||||
logger.Info("msg", "Shutdown complete")
|
||||
}
|
||||
@ -187,47 +193,3 @@ func shutdownLogger() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Saves the configuration to file on exist
|
||||
func saveConfigurationOnExit(cfg *config.Config, reloadManager *ReloadManager, logger *log.Logger) {
|
||||
// Only save if explicitly enabled and we have a valid path
|
||||
if !cfg.ConfigSaveOnExit || cfg.ConfigFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Create a context with timeout for save operation
|
||||
saveCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Perform save in goroutine to respect timeout
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
var err error
|
||||
if reloadManager != nil && reloadManager.lcfg != nil {
|
||||
// Use existing lconfig instance from reload manager
|
||||
// This ensures we save through the same configuration system
|
||||
err = reloadManager.lcfg.Save(cfg.ConfigFile)
|
||||
} else {
|
||||
// Static mode: create temporary lconfig for saving
|
||||
err = cfg.SaveToFile(cfg.ConfigFile)
|
||||
}
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.Error("msg", "Failed to save configuration on exit",
|
||||
"path", cfg.ConfigFile,
|
||||
"error", err)
|
||||
// Don't fail the exit on save error
|
||||
} else {
|
||||
logger.Info("msg", "Configuration saved successfully",
|
||||
"path", cfg.ConfigFile)
|
||||
}
|
||||
case <-saveCtx.Done():
|
||||
logger.Error("msg", "Configuration save timeout exceeded",
|
||||
"path", cfg.ConfigFile,
|
||||
"timeout", "5s")
|
||||
}
|
||||
}
|
||||
@ -338,14 +338,6 @@ func (rm *ReloadManager) stopStatusReporter() {
|
||||
}
|
||||
}
|
||||
|
||||
// Wrapper to save the config
|
||||
func (rm *ReloadManager) SaveConfig(path string) error {
|
||||
if rm.lcfg == nil {
|
||||
return fmt.Errorf("no lconfig instance available")
|
||||
}
|
||||
return rm.lcfg.Save(path)
|
||||
}
|
||||
|
||||
// Stops the reload manager
|
||||
func (rm *ReloadManager) Shutdown() {
|
||||
rm.logger.Info("msg", "Shutting down reload manager")
|
||||
|
||||
@ -114,81 +114,76 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
|
||||
for i, sinkCfg := range cfg.Sinks {
|
||||
switch sinkCfg.Type {
|
||||
case "tcp":
|
||||
if port, ok := sinkCfg.Options["port"].(int64); ok {
|
||||
host := "0.0.0.0" // Get host or default to 0.0.0.0
|
||||
if h, ok := sinkCfg.Options["host"].(string); ok && h != "" {
|
||||
host = h
|
||||
if sinkCfg.TCP != nil {
|
||||
host := "0.0.0.0"
|
||||
if sinkCfg.TCP.Host != "" {
|
||||
host = sinkCfg.TCP.Host
|
||||
}
|
||||
|
||||
logger.Info("msg", "TCP endpoint configured",
|
||||
"component", "main",
|
||||
"pipeline", cfg.Name,
|
||||
"sink_index", i,
|
||||
"listen", fmt.Sprintf("%s:%d", host, port))
|
||||
"listen", fmt.Sprintf("%s:%d", host, sinkCfg.TCP.Port))
|
||||
|
||||
// Display net limit info if configured
|
||||
if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
|
||||
if enabled, ok := nl["enabled"].(bool); ok && enabled {
|
||||
if sinkCfg.TCP.NetLimit != nil && sinkCfg.TCP.NetLimit.Enabled {
|
||||
logger.Info("msg", "TCP net limiting enabled",
|
||||
"pipeline", cfg.Name,
|
||||
"sink_index", i,
|
||||
"requests_per_second", nl["requests_per_second"],
|
||||
"burst_size", nl["burst_size"])
|
||||
}
|
||||
"requests_per_second", sinkCfg.TCP.NetLimit.RequestsPerSecond,
|
||||
"burst_size", sinkCfg.TCP.NetLimit.BurstSize)
|
||||
}
|
||||
}
|
||||
|
||||
case "http":
|
||||
if port, ok := sinkCfg.Options["port"].(int64); ok {
|
||||
if sinkCfg.HTTP != nil {
|
||||
host := "0.0.0.0"
|
||||
if h, ok := sinkCfg.Options["host"].(string); ok && h != "" {
|
||||
host = h
|
||||
if sinkCfg.HTTP.Host != "" {
|
||||
host = sinkCfg.HTTP.Host
|
||||
}
|
||||
|
||||
streamPath := "/stream"
|
||||
statusPath := "/status"
|
||||
if path, ok := sinkCfg.Options["stream_path"].(string); ok {
|
||||
streamPath = path
|
||||
if sinkCfg.HTTP.StreamPath != "" {
|
||||
streamPath = sinkCfg.HTTP.StreamPath
|
||||
}
|
||||
if path, ok := sinkCfg.Options["status_path"].(string); ok {
|
||||
statusPath = path
|
||||
if sinkCfg.HTTP.StatusPath != "" {
|
||||
statusPath = sinkCfg.HTTP.StatusPath
|
||||
}
|
||||
|
||||
logger.Info("msg", "HTTP endpoints configured",
|
||||
"pipeline", cfg.Name,
|
||||
"sink_index", i,
|
||||
"listen", fmt.Sprintf("%s:%d", host, port),
|
||||
"stream_url", fmt.Sprintf("http://%s:%d%s", host, port, streamPath),
|
||||
"status_url", fmt.Sprintf("http://%s:%d%s", host, port, statusPath))
|
||||
"listen", fmt.Sprintf("%s:%d", host, sinkCfg.HTTP.Port),
|
||||
"stream_url", fmt.Sprintf("http://%s:%d%s", host, sinkCfg.HTTP.Port, streamPath),
|
||||
"status_url", fmt.Sprintf("http://%s:%d%s", host, sinkCfg.HTTP.Port, statusPath))
|
||||
|
||||
// Display net limit info if configured
|
||||
if nl, ok := sinkCfg.Options["net_limit"].(map[string]any); ok {
|
||||
if enabled, ok := nl["enabled"].(bool); ok && enabled {
|
||||
if sinkCfg.HTTP.NetLimit != nil && sinkCfg.HTTP.NetLimit.Enabled {
|
||||
logger.Info("msg", "HTTP net limiting enabled",
|
||||
"pipeline", cfg.Name,
|
||||
"sink_index", i,
|
||||
"requests_per_second", nl["requests_per_second"],
|
||||
"burst_size", nl["burst_size"])
|
||||
}
|
||||
"requests_per_second", sinkCfg.HTTP.NetLimit.RequestsPerSecond,
|
||||
"burst_size", sinkCfg.HTTP.NetLimit.BurstSize)
|
||||
}
|
||||
}
|
||||
|
||||
case "file":
|
||||
if dir, ok := sinkCfg.Options["directory"].(string); ok {
|
||||
name, _ := sinkCfg.Options["name"].(string)
|
||||
if sinkCfg.File != nil {
|
||||
logger.Info("msg", "File sink configured",
|
||||
"pipeline", cfg.Name,
|
||||
"sink_index", i,
|
||||
"directory", dir,
|
||||
"name", name)
|
||||
"directory", sinkCfg.File.Directory,
|
||||
"name", sinkCfg.File.Name)
|
||||
}
|
||||
|
||||
case "console":
|
||||
if target, ok := sinkCfg.Options["target"].(string); ok {
|
||||
if sinkCfg.Console != nil {
|
||||
logger.Info("msg", "Console sink configured",
|
||||
"pipeline", cfg.Name,
|
||||
"sink_index", i,
|
||||
"target", target)
|
||||
"target", sinkCfg.Console.Target)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -197,10 +192,10 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
|
||||
for i, sourceCfg := range cfg.Sources {
|
||||
switch sourceCfg.Type {
|
||||
case "http":
|
||||
if port, ok := sourceCfg.Options["port"].(int64); ok {
|
||||
if sourceCfg.HTTP != nil {
|
||||
host := "0.0.0.0"
|
||||
if h, ok := sourceCfg.Options["host"].(string); ok && h != "" {
|
||||
host = h
|
||||
if sourceCfg.HTTP.Host != "" {
|
||||
host = sourceCfg.HTTP.Host
|
||||
}
|
||||
|
||||
displayHost := host
|
||||
@ -209,22 +204,22 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
|
||||
}
|
||||
|
||||
ingestPath := "/ingest"
|
||||
if path, ok := sourceCfg.Options["ingest_path"].(string); ok {
|
||||
ingestPath = path
|
||||
if sourceCfg.HTTP.IngestPath != "" {
|
||||
ingestPath = sourceCfg.HTTP.IngestPath
|
||||
}
|
||||
|
||||
logger.Info("msg", "HTTP source configured",
|
||||
"pipeline", cfg.Name,
|
||||
"source_index", i,
|
||||
"listen", fmt.Sprintf("%s:%d", host, port),
|
||||
"ingest_url", fmt.Sprintf("http://%s:%d%s", displayHost, port, ingestPath))
|
||||
"listen", fmt.Sprintf("%s:%d", host, sourceCfg.HTTP.Port),
|
||||
"ingest_url", fmt.Sprintf("http://%s:%d%s", displayHost, sourceCfg.HTTP.Port, ingestPath))
|
||||
}
|
||||
|
||||
case "tcp":
|
||||
if port, ok := sourceCfg.Options["port"].(int64); ok {
|
||||
if sourceCfg.TCP != nil {
|
||||
host := "0.0.0.0"
|
||||
if h, ok := sourceCfg.Options["host"].(string); ok && h != "" {
|
||||
host = h
|
||||
if sourceCfg.TCP.Host != "" {
|
||||
host = sourceCfg.TCP.Host
|
||||
}
|
||||
|
||||
displayHost := host
|
||||
@ -235,19 +230,24 @@ func displayPipelineEndpoints(cfg config.PipelineConfig) {
|
||||
logger.Info("msg", "TCP source configured",
|
||||
"pipeline", cfg.Name,
|
||||
"source_index", i,
|
||||
"listen", fmt.Sprintf("%s:%d", host, port),
|
||||
"endpoint", fmt.Sprintf("%s:%d", displayHost, port))
|
||||
"listen", fmt.Sprintf("%s:%d", host, sourceCfg.TCP.Port),
|
||||
"endpoint", fmt.Sprintf("%s:%d", displayHost, sourceCfg.TCP.Port))
|
||||
}
|
||||
|
||||
// TODO: missing other types of source, to be added
|
||||
}
|
||||
}
|
||||
|
||||
// Display authentication information
|
||||
if cfg.Auth != nil && cfg.Auth.Type != "none" {
|
||||
logger.Info("msg", "Authentication enabled",
|
||||
case "directory":
|
||||
if sourceCfg.Directory != nil {
|
||||
logger.Info("msg", "Directory source configured",
|
||||
"pipeline", cfg.Name,
|
||||
"auth_type", cfg.Auth.Type)
|
||||
"source_index", i,
|
||||
"path", sourceCfg.Directory.Path,
|
||||
"pattern", sourceCfg.Directory.Pattern)
|
||||
}
|
||||
|
||||
case "stdin":
|
||||
logger.Info("msg", "Stdin source configured",
|
||||
"pipeline", cfg.Name,
|
||||
"source_index", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Display filter information
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -13,7 +12,6 @@ import (
|
||||
"logwisp/src/internal/config"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Prevent unbounded map growth
|
||||
@ -21,40 +19,28 @@ const maxAuthTrackedIPs = 10000
|
||||
|
||||
// Handles all authentication methods for a pipeline
|
||||
type Authenticator struct {
|
||||
config *config.AuthConfig
|
||||
config *config.ServerAuthConfig
|
||||
logger *log.Logger
|
||||
bearerTokens map[string]bool // token -> valid
|
||||
tokens map[string]bool // token -> valid
|
||||
mu sync.RWMutex
|
||||
|
||||
// Session tracking
|
||||
sessions map[string]*Session
|
||||
sessionMu sync.RWMutex
|
||||
|
||||
// Brute-force protection
|
||||
ipAuthAttempts map[string]*ipAuthState
|
||||
authMu sync.RWMutex
|
||||
}
|
||||
|
||||
// Per-IP auth attempt tracking
|
||||
type ipAuthState struct {
|
||||
limiter *rate.Limiter
|
||||
failCount int
|
||||
lastAttempt time.Time
|
||||
blockedUntil time.Time
|
||||
}
|
||||
|
||||
// Represents an authenticated connection
|
||||
type Session struct {
|
||||
ID string
|
||||
Username string
|
||||
Method string // basic, bearer, mtls
|
||||
Method string // basic, token, mtls
|
||||
RemoteAddr string
|
||||
CreatedAt time.Time
|
||||
LastActivity time.Time
|
||||
}
|
||||
|
||||
// Creates a new authenticator from config
|
||||
func NewAuthenticator(cfg *config.AuthConfig, logger *log.Logger) (*Authenticator, error) {
|
||||
func NewAuthenticator(cfg *config.ServerAuthConfig, logger *log.Logger) (*Authenticator, error) {
|
||||
// SCRAM is handled by ScramManager in sources
|
||||
if cfg == nil || cfg.Type == "none" || cfg.Type == "scram" {
|
||||
return nil, nil
|
||||
@ -63,24 +49,20 @@ func NewAuthenticator(cfg *config.AuthConfig, logger *log.Logger) (*Authenticato
|
||||
a := &Authenticator{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
bearerTokens: make(map[string]bool),
|
||||
tokens: make(map[string]bool),
|
||||
sessions: make(map[string]*Session),
|
||||
ipAuthAttempts: make(map[string]*ipAuthState),
|
||||
}
|
||||
|
||||
// Initialize Bearer tokens
|
||||
if cfg.Type == "bearer" && cfg.BearerAuth != nil {
|
||||
for _, token := range cfg.BearerAuth.Tokens {
|
||||
a.bearerTokens[token] = true
|
||||
// Initialize tokens
|
||||
if cfg.Type == "token" && cfg.Token != nil {
|
||||
for _, token := range cfg.Token.Tokens {
|
||||
a.tokens[token] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Start session cleanup
|
||||
go a.sessionCleanup()
|
||||
|
||||
// Start auth attempt cleanup
|
||||
go a.authAttemptCleanup()
|
||||
|
||||
logger.Info("msg", "Authenticator initialized",
|
||||
"component", "auth",
|
||||
"type", cfg.Type)
|
||||
@ -88,129 +70,6 @@ func NewAuthenticator(cfg *config.AuthConfig, logger *log.Logger) (*Authenticato
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Check and enforce rate limits
|
||||
func (a *Authenticator) checkRateLimit(remoteAddr string) error {
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
ip = remoteAddr // Fallback for malformed addresses
|
||||
}
|
||||
|
||||
a.authMu.Lock()
|
||||
defer a.authMu.Unlock()
|
||||
|
||||
state, exists := a.ipAuthAttempts[ip]
|
||||
now := time.Now()
|
||||
|
||||
if !exists {
|
||||
// Check map size limit before creating new entry
|
||||
if len(a.ipAuthAttempts) >= maxAuthTrackedIPs {
|
||||
// Evict an old entry using simplified LRU
|
||||
// Sample 20 random entries and evict the oldest
|
||||
const sampleSize = 20
|
||||
var oldestIP string
|
||||
oldestTime := now
|
||||
|
||||
// Build sample
|
||||
sampled := 0
|
||||
for sampledIP, sampledState := range a.ipAuthAttempts {
|
||||
if sampledState.lastAttempt.Before(oldestTime) {
|
||||
oldestIP = sampledIP
|
||||
oldestTime = sampledState.lastAttempt
|
||||
}
|
||||
sampled++
|
||||
if sampled >= sampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Evict the oldest from our sample
|
||||
if oldestIP != "" {
|
||||
delete(a.ipAuthAttempts, oldestIP)
|
||||
a.logger.Debug("msg", "Evicted old auth attempt state",
|
||||
"component", "auth",
|
||||
"evicted_ip", oldestIP,
|
||||
"last_seen", oldestTime)
|
||||
}
|
||||
}
|
||||
|
||||
// Create new state for this IP
|
||||
// 5 attempts per minute, burst of 3
|
||||
state = &ipAuthState{
|
||||
limiter: rate.NewLimiter(rate.Every(12*time.Second), 3),
|
||||
lastAttempt: now,
|
||||
}
|
||||
a.ipAuthAttempts[ip] = state
|
||||
}
|
||||
|
||||
// Check if IP is temporarily blocked
|
||||
if now.Before(state.blockedUntil) {
|
||||
remaining := state.blockedUntil.Sub(now)
|
||||
a.logger.Warn("msg", "IP temporarily blocked",
|
||||
"component", "auth",
|
||||
"ip", ip,
|
||||
"remaining", remaining)
|
||||
// Sleep to slow down even blocked attempts
|
||||
time.Sleep(2 * time.Second)
|
||||
return fmt.Errorf("temporarily blocked, try again in %v", remaining.Round(time.Second))
|
||||
}
|
||||
|
||||
// Check rate limit
|
||||
if !state.limiter.Allow() {
|
||||
state.failCount++
|
||||
|
||||
// Only set new blockedUntil if not already blocked
|
||||
// This prevents indefinite block extension
|
||||
if state.blockedUntil.IsZero() || now.After(state.blockedUntil) {
|
||||
// Progressive blocking: 2^failCount minutes
|
||||
blockMinutes := 1 << min(state.failCount, 6) // Cap at 64 minutes
|
||||
state.blockedUntil = now.Add(time.Duration(blockMinutes) * time.Minute)
|
||||
|
||||
a.logger.Warn("msg", "Rate limit exceeded, blocking IP",
|
||||
"component", "auth",
|
||||
"ip", ip,
|
||||
"fail_count", state.failCount,
|
||||
"block_duration", time.Duration(blockMinutes)*time.Minute)
|
||||
}
|
||||
|
||||
return fmt.Errorf("rate limit exceeded")
|
||||
}
|
||||
|
||||
state.lastAttempt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
// Record failed attempt
|
||||
func (a *Authenticator) recordFailure(remoteAddr string) {
|
||||
ip, _, _ := net.SplitHostPort(remoteAddr)
|
||||
if ip == "" {
|
||||
ip = remoteAddr
|
||||
}
|
||||
|
||||
a.authMu.Lock()
|
||||
defer a.authMu.Unlock()
|
||||
|
||||
if state, exists := a.ipAuthAttempts[ip]; exists {
|
||||
state.failCount++
|
||||
state.lastAttempt = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// Reset failure count on success
|
||||
func (a *Authenticator) recordSuccess(remoteAddr string) {
|
||||
ip, _, _ := net.SplitHostPort(remoteAddr)
|
||||
if ip == "" {
|
||||
ip = remoteAddr
|
||||
}
|
||||
|
||||
a.authMu.Lock()
|
||||
defer a.authMu.Unlock()
|
||||
|
||||
if state, exists := a.ipAuthAttempts[ip]; exists {
|
||||
state.failCount = 0
|
||||
state.blockedUntil = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
// Handles HTTP authentication headers
|
||||
func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Session, error) {
|
||||
if a == nil || a.config.Type == "none" {
|
||||
@ -222,77 +81,27 @@ func (a *Authenticator) AuthenticateHTTP(authHeader, remoteAddr string) (*Sessio
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Check rate limit
|
||||
if err := a.checkRateLimit(remoteAddr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var session *Session
|
||||
var err error
|
||||
|
||||
switch a.config.Type {
|
||||
case "bearer":
|
||||
session, err = a.authenticateBearer(authHeader, remoteAddr)
|
||||
case "token":
|
||||
session, err = a.authenticateToken(authHeader, remoteAddr)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported auth type: %s", a.config.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
a.recordFailure(remoteAddr)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.recordSuccess(remoteAddr)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Handles TCP connection authentication
|
||||
func (a *Authenticator) AuthenticateTCP(method, credentials, remoteAddr string) (*Session, error) {
|
||||
if a == nil || a.config.Type == "none" {
|
||||
return &Session{
|
||||
ID: generateSessionID(),
|
||||
Method: "none",
|
||||
RemoteAddr: remoteAddr,
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Check rate limit first
|
||||
if err := a.checkRateLimit(remoteAddr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var session *Session
|
||||
var err error
|
||||
|
||||
// TCP auth protocol: AUTH <method> <credentials>
|
||||
switch strings.ToLower(method) {
|
||||
case "token":
|
||||
if a.config.Type != "bearer" {
|
||||
err = fmt.Errorf("token auth not configured")
|
||||
} else {
|
||||
session, err = a.validateToken(credentials, remoteAddr)
|
||||
}
|
||||
|
||||
default:
|
||||
err = fmt.Errorf("unsupported auth method: %s", method)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
a.recordFailure(remoteAddr)
|
||||
// Add delay on failure
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.recordSuccess(remoteAddr)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Session, error) {
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return nil, fmt.Errorf("invalid bearer auth header")
|
||||
func (a *Authenticator) authenticateToken(authHeader, remoteAddr string) (*Session, error) {
|
||||
if !strings.HasPrefix(authHeader, "Token") {
|
||||
return nil, fmt.Errorf("invalid token auth header")
|
||||
}
|
||||
|
||||
token := authHeader[7:]
|
||||
@ -302,7 +111,7 @@ func (a *Authenticator) authenticateBearer(authHeader, remoteAddr string) (*Sess
|
||||
func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error) {
|
||||
// Check static tokens first
|
||||
a.mu.RLock()
|
||||
isValid := a.bearerTokens[token]
|
||||
isValid := a.tokens[token]
|
||||
a.mu.RUnlock()
|
||||
|
||||
if !isValid {
|
||||
@ -311,7 +120,7 @@ func (a *Authenticator) validateToken(token, remoteAddr string) (*Session, error
|
||||
|
||||
session := &Session{
|
||||
ID: generateSessionID(),
|
||||
Method: "bearer",
|
||||
Method: "token",
|
||||
RemoteAddr: remoteAddr,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivity: time.Now(),
|
||||
@ -352,27 +161,6 @@ func (a *Authenticator) sessionCleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup old auth attempts
|
||||
func (a *Authenticator) authAttemptCleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
a.authMu.Lock()
|
||||
now := time.Now()
|
||||
for ip, state := range a.ipAuthAttempts {
|
||||
// Remove entries older than 1 hour with no recent activity
|
||||
if now.Sub(state.lastAttempt) > time.Hour {
|
||||
delete(a.ipAuthAttempts, ip)
|
||||
a.logger.Debug("msg", "Cleaned up auth attempt state",
|
||||
"component", "auth",
|
||||
"ip", ip)
|
||||
}
|
||||
}
|
||||
a.authMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func generateSessionID() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
@ -418,6 +206,6 @@ func (a *Authenticator) GetStats() map[string]any {
|
||||
"enabled": true,
|
||||
"type": a.config.Type,
|
||||
"active_sessions": sessionCount,
|
||||
"static_tokens": len(a.bearerTokens),
|
||||
"static_tokens": len(a.tokens),
|
||||
}
|
||||
}
|
||||
@ -1,207 +0,0 @@
|
||||
// FILE: src/internal/auth/generator.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"logwisp/src/internal/scram"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Argon2id parameters
|
||||
const (
|
||||
argon2Time = 3
|
||||
argon2Memory = 64 * 1024 // 64 MB
|
||||
argon2Threads = 4
|
||||
argon2SaltLen = 16
|
||||
argon2KeyLen = 32
|
||||
)
|
||||
|
||||
type AuthGeneratorCommand struct {
|
||||
output io.Writer
|
||||
errOut io.Writer
|
||||
}
|
||||
|
||||
func NewAuthGeneratorCommand() *AuthGeneratorCommand {
|
||||
return &AuthGeneratorCommand{
|
||||
output: os.Stdout,
|
||||
errOut: os.Stderr,
|
||||
}
|
||||
}
|
||||
|
||||
func (ag *AuthGeneratorCommand) Execute(args []string) error {
|
||||
cmd := flag.NewFlagSet("auth", flag.ContinueOnError)
|
||||
cmd.SetOutput(ag.errOut)
|
||||
|
||||
var (
|
||||
username = cmd.String("u", "", "Username")
|
||||
password = cmd.String("p", "", "Password (will prompt if not provided)")
|
||||
authType = cmd.String("type", "basic", "Auth type: basic (HTTP) or scram (TCP)")
|
||||
genToken = cmd.Bool("t", false, "Generate random bearer token")
|
||||
tokenLen = cmd.Int("l", 32, "Token length in bytes (min 16, max 512)")
|
||||
)
|
||||
|
||||
cmd.Usage = func() {
|
||||
fmt.Fprintln(ag.errOut, "Generate authentication credentials for LogWisp")
|
||||
fmt.Fprintln(ag.errOut, "\nUsage: logwisp auth [options]")
|
||||
fmt.Fprintln(ag.errOut, "\nExamples:")
|
||||
fmt.Fprintln(ag.errOut, " # Generate basic auth hash for HTTP sources/sinks")
|
||||
fmt.Fprintln(ag.errOut, " logwisp auth -u admin -type basic")
|
||||
fmt.Fprintln(ag.errOut, " ")
|
||||
fmt.Fprintln(ag.errOut, " # Generate SCRAM credentials for TCP sources/sinks")
|
||||
fmt.Fprintln(ag.errOut, " logwisp auth -u admin -type scram")
|
||||
fmt.Fprintln(ag.errOut, " ")
|
||||
fmt.Fprintln(ag.errOut, " # Generate 64-byte bearer token")
|
||||
fmt.Fprintln(ag.errOut, " logwisp auth -t -l 64")
|
||||
fmt.Fprintln(ag.errOut, "\nOptions:")
|
||||
cmd.PrintDefaults()
|
||||
fmt.Fprintln(ag.errOut)
|
||||
}
|
||||
|
||||
if err := cmd.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *genToken {
|
||||
return ag.generateToken(*tokenLen)
|
||||
}
|
||||
|
||||
if *username == "" {
|
||||
cmd.Usage()
|
||||
return fmt.Errorf("username required for credential generation")
|
||||
}
|
||||
|
||||
switch *authType {
|
||||
case "basic":
|
||||
return ag.generateBasicAuth(*username, *password)
|
||||
case "scram":
|
||||
return ag.generateScramAuth(*username, *password)
|
||||
default:
|
||||
return fmt.Errorf("invalid auth type: %s (use 'basic' or 'scram')", *authType)
|
||||
}
|
||||
}
|
||||
|
||||
func (ag *AuthGeneratorCommand) generateBasicAuth(username, password string) error {
|
||||
// Get password if not provided
|
||||
if password == "" {
|
||||
pass1 := ag.promptPassword("Enter password: ")
|
||||
pass2 := ag.promptPassword("Confirm password: ")
|
||||
if pass1 != pass2 {
|
||||
return fmt.Errorf("passwords don't match")
|
||||
}
|
||||
password = pass1
|
||||
}
|
||||
|
||||
// Generate salt
|
||||
salt := make([]byte, argon2SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Generate Argon2id hash
|
||||
hash := argon2.IDKey([]byte(password), salt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)
|
||||
|
||||
// Encode in PHC format: $argon2id$v=19$m=65536,t=3,p=4$salt$hash
|
||||
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
||||
hashB64 := base64.RawStdEncoding.EncodeToString(hash)
|
||||
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, argon2Memory, argon2Time, argon2Threads, saltB64, hashB64)
|
||||
|
||||
// Output configuration snippets
|
||||
fmt.Fprintln(ag.output, "\n# Basic Auth Configuration (HTTP sources/sinks)")
|
||||
fmt.Fprintln(ag.output, "# REQUIRES HTTPS/TLS for security")
|
||||
fmt.Fprintln(ag.output, "# Add to logwisp.toml under [[pipelines]]:")
|
||||
fmt.Fprintln(ag.output, "")
|
||||
fmt.Fprintln(ag.output, "[pipelines.auth]")
|
||||
fmt.Fprintln(ag.output, `type = "basic"`)
|
||||
fmt.Fprintln(ag.output, "")
|
||||
fmt.Fprintln(ag.output, "[[pipelines.auth.basic_auth.users]]")
|
||||
fmt.Fprintf(ag.output, "username = %q\n", username)
|
||||
fmt.Fprintf(ag.output, "password_hash = %q\n\n", phcHash)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ag *AuthGeneratorCommand) generateScramAuth(username, password string) error {
|
||||
// Get password if not provided
|
||||
if password == "" {
|
||||
pass1 := ag.promptPassword("Enter password: ")
|
||||
pass2 := ag.promptPassword("Confirm password: ")
|
||||
if pass1 != pass2 {
|
||||
return fmt.Errorf("passwords don't match")
|
||||
}
|
||||
password = pass1
|
||||
}
|
||||
|
||||
// Generate salt
|
||||
salt := make([]byte, 16)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Derive SCRAM credential
|
||||
cred, err := scram.DeriveCredential(username, password, salt, 3, 65536, 4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to derive SCRAM credential: %w", err)
|
||||
}
|
||||
|
||||
// Output SCRAM configuration
|
||||
fmt.Fprintln(ag.output, "\n# SCRAM Auth Configuration (for TCP sources/sinks)")
|
||||
fmt.Fprintln(ag.output, "# Add to logwisp.toml:")
|
||||
fmt.Fprintln(ag.output, "[[pipelines.auth.scram_auth.users]]")
|
||||
fmt.Fprintf(ag.output, "username = %q\n", username)
|
||||
fmt.Fprintf(ag.output, "stored_key = %q\n", base64.StdEncoding.EncodeToString(cred.StoredKey))
|
||||
fmt.Fprintf(ag.output, "server_key = %q\n", base64.StdEncoding.EncodeToString(cred.ServerKey))
|
||||
fmt.Fprintf(ag.output, "salt = %q\n", base64.StdEncoding.EncodeToString(cred.Salt))
|
||||
fmt.Fprintf(ag.output, "argon_time = %d\n", cred.ArgonTime)
|
||||
fmt.Fprintf(ag.output, "argon_memory = %d\n", cred.ArgonMemory)
|
||||
fmt.Fprintf(ag.output, "argon_threads = %d\n\n", cred.ArgonThreads)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ag *AuthGeneratorCommand) generateToken(length int) error {
|
||||
if length < 16 {
|
||||
fmt.Fprintln(ag.errOut, "Warning: tokens < 16 bytes are cryptographically weak")
|
||||
}
|
||||
if length > 512 {
|
||||
return fmt.Errorf("token length exceeds maximum (512 bytes)")
|
||||
}
|
||||
|
||||
token := make([]byte, length)
|
||||
if _, err := rand.Read(token); err != nil {
|
||||
return fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
b64 := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(token)
|
||||
hex := fmt.Sprintf("%x", token)
|
||||
|
||||
fmt.Fprintln(ag.output, "\n# Bearer Token Configuration")
|
||||
fmt.Fprintln(ag.output, "# Add to logwisp.toml:")
|
||||
fmt.Fprintf(ag.output, "tokens = [%q]\n\n", b64)
|
||||
|
||||
fmt.Fprintln(ag.output, "# Generated Token:")
|
||||
fmt.Fprintf(ag.output, "Base64: %s\n", b64)
|
||||
fmt.Fprintf(ag.output, "Hex: %s\n", hex)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ag *AuthGeneratorCommand) promptPassword(prompt string) string {
|
||||
fmt.Fprint(ag.errOut, prompt)
|
||||
password, err := term.ReadPassword(syscall.Stdin)
|
||||
fmt.Fprintln(ag.errOut)
|
||||
if err != nil {
|
||||
fmt.Fprintf(ag.errOut, "Failed to read password: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return string(password)
|
||||
}
|
||||
@ -1,5 +1,5 @@
|
||||
// FILE: src/internal/scram/client.go
|
||||
package scram
|
||||
// FILE: src/internal/auth/scram_client.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
// Client handles SCRAM client-side authentication
|
||||
type Client struct {
|
||||
type ScramClient struct {
|
||||
Username string
|
||||
Password string
|
||||
|
||||
@ -23,16 +23,16 @@ type Client struct {
|
||||
serverKey []byte
|
||||
}
|
||||
|
||||
// NewClient creates SCRAM client
|
||||
func NewClient(username, password string) *Client {
|
||||
return &Client{
|
||||
// NewScramClient creates SCRAM client
|
||||
func NewScramClient(username, password string) *ScramClient {
|
||||
return &ScramClient{
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
}
|
||||
|
||||
// StartAuthentication generates ClientFirst message
|
||||
func (c *Client) StartAuthentication() (*ClientFirst, error) {
|
||||
func (c *ScramClient) StartAuthentication() (*ClientFirst, error) {
|
||||
// Generate client nonce
|
||||
nonce := make([]byte, 32)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
@ -47,7 +47,7 @@ func (c *Client) StartAuthentication() (*ClientFirst, error) {
|
||||
}
|
||||
|
||||
// ProcessServerFirst handles server challenge
|
||||
func (c *Client) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) {
|
||||
func (c *ScramClient) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) {
|
||||
c.serverFirst = msg
|
||||
|
||||
// Decode salt
|
||||
@ -83,7 +83,7 @@ func (c *Client) ProcessServerFirst(msg *ServerFirst) (*ClientFinal, error) {
|
||||
}
|
||||
|
||||
// VerifyServerFinal validates server signature
|
||||
func (c *Client) VerifyServerFinal(msg *ServerFinal) error {
|
||||
func (c *ScramClient) VerifyServerFinal(msg *ServerFinal) error {
|
||||
if c.authMessage == "" || c.serverKey == nil {
|
||||
return fmt.Errorf("invalid handshake state")
|
||||
}
|
||||
@ -1,5 +1,5 @@
|
||||
// FILE: src/internal/scram/credential.go
|
||||
package scram
|
||||
// FILE: src/internal/auth/scram_credential.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
@ -9,6 +9,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
@ -31,7 +33,13 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
|
||||
}
|
||||
|
||||
// Derive salted password using Argon2id
|
||||
saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, 32)
|
||||
saltedPassword := argon2.IDKey([]byte(password), salt, time, memory, threads, core.Argon2KeyLen)
|
||||
|
||||
// Construct PHC format for basic auth compatibility
|
||||
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
||||
hashB64 := base64.RawStdEncoding.EncodeToString(saltedPassword)
|
||||
phcHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, memory, time, threads, saltB64, hashB64)
|
||||
|
||||
// Derive keys
|
||||
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||
@ -46,6 +54,7 @@ func DeriveCredential(username, password string, salt []byte, time, memory uint3
|
||||
ArgonThreads: threads,
|
||||
StoredKey: storedKey[:],
|
||||
ServerKey: serverKey,
|
||||
PHCHash: phcHash,
|
||||
}, nil
|
||||
}
|
||||
|
||||
83
src/internal/auth/scram_manager.go
Normal file
83
src/internal/auth/scram_manager.go
Normal file
@ -0,0 +1,83 @@
|
||||
// FILE: src/internal/auth/scram_manager.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
)
|
||||
|
||||
// ScramManager provides high-level SCRAM operations with rate limiting
|
||||
type ScramManager struct {
|
||||
server *ScramServer
|
||||
}
|
||||
|
||||
// NewScramManager creates SCRAM manager
|
||||
func NewScramManager(scramAuthCfg *config.ScramAuthConfig) *ScramManager {
|
||||
manager := &ScramManager{
|
||||
server: NewScramServer(),
|
||||
}
|
||||
|
||||
// Load users from SCRAM config
|
||||
for _, user := range scramAuthCfg.Users {
|
||||
storedKey, err := base64.StdEncoding.DecodeString(user.StoredKey)
|
||||
if err != nil {
|
||||
// Skip user with invalid stored key
|
||||
continue
|
||||
}
|
||||
|
||||
serverKey, err := base64.StdEncoding.DecodeString(user.ServerKey)
|
||||
if err != nil {
|
||||
// Skip user with invalid server key
|
||||
continue
|
||||
}
|
||||
|
||||
salt, err := base64.StdEncoding.DecodeString(user.Salt)
|
||||
if err != nil {
|
||||
// Skip user with invalid salt
|
||||
continue
|
||||
}
|
||||
|
||||
cred := &Credential{
|
||||
Username: user.Username,
|
||||
StoredKey: storedKey,
|
||||
ServerKey: serverKey,
|
||||
Salt: salt,
|
||||
ArgonTime: user.ArgonTime,
|
||||
ArgonMemory: user.ArgonMemory,
|
||||
ArgonThreads: user.ArgonThreads,
|
||||
}
|
||||
manager.server.AddCredential(cred)
|
||||
}
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
// RegisterUser creates new user credential
|
||||
func (sm *ScramManager) RegisterUser(username, password string) error {
|
||||
salt := make([]byte, 16)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return fmt.Errorf("salt generation failed: %w", err)
|
||||
}
|
||||
|
||||
cred, err := DeriveCredential(username, password, salt,
|
||||
sm.server.DefaultTime, sm.server.DefaultMemory, sm.server.DefaultThreads)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sm.server.AddCredential(cred)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleClientFirst wraps server's HandleClientFirst
|
||||
func (sm *ScramManager) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
|
||||
return sm.server.HandleClientFirst(msg)
|
||||
}
|
||||
|
||||
// HandleClientFinal wraps server's HandleClientFinal
|
||||
func (sm *ScramManager) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
|
||||
return sm.server.HandleClientFinal(msg)
|
||||
}
|
||||
38
src/internal/auth/scram_message.go
Normal file
38
src/internal/auth/scram_message.go
Normal file
@ -0,0 +1,38 @@
|
||||
// FILE: src/internal/auth/scram_message.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ClientFirst initiates authentication
|
||||
type ClientFirst struct {
|
||||
Username string `json:"u"`
|
||||
ClientNonce string `json:"n"`
|
||||
}
|
||||
|
||||
// ServerFirst contains server challenge
|
||||
type ServerFirst struct {
|
||||
FullNonce string `json:"r"` // client_nonce + server_nonce
|
||||
Salt string `json:"s"` // base64
|
||||
ArgonTime uint32 `json:"t"`
|
||||
ArgonMemory uint32 `json:"m"`
|
||||
ArgonThreads uint8 `json:"p"`
|
||||
}
|
||||
|
||||
// ClientFinal contains client proof
|
||||
type ClientFinal struct {
|
||||
FullNonce string `json:"r"`
|
||||
ClientProof string `json:"p"` // base64
|
||||
}
|
||||
|
||||
// ServerFinal contains server signature for mutual auth
|
||||
type ServerFinal struct {
|
||||
ServerSignature string `json:"v"` // base64
|
||||
SessionID string `json:"sid,omitempty"`
|
||||
}
|
||||
|
||||
func (sf *ServerFirst) Marshal() string {
|
||||
return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d",
|
||||
sf.FullNonce, sf.Salt, sf.ArgonTime, sf.ArgonMemory, sf.ArgonThreads)
|
||||
}
|
||||
117
src/internal/auth/scram_protocol.go
Normal file
117
src/internal/auth/scram_protocol.go
Normal file
@ -0,0 +1,117 @@
|
||||
// FILE: src/internal/auth/scram_protocol.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/panjf2000/gnet/v2"
|
||||
)
|
||||
|
||||
// ScramProtocolHandler handles SCRAM message exchange for TCP
|
||||
type ScramProtocolHandler struct {
|
||||
manager *ScramManager
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewScramProtocolHandler creates protocol handler
|
||||
func NewScramProtocolHandler(manager *ScramManager, logger *log.Logger) *ScramProtocolHandler {
|
||||
return &ScramProtocolHandler{
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAuthMessage processes a complete auth line from buffer
|
||||
func (sph *ScramProtocolHandler) HandleAuthMessage(line []byte, conn gnet.Conn) (authenticated bool, session *Session, err error) {
|
||||
// Parse SCRAM messages
|
||||
parts := strings.Fields(string(line))
|
||||
if len(parts) < 2 {
|
||||
conn.AsyncWrite([]byte("SCRAM-FAIL Invalid message format\n"), nil)
|
||||
return false, nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
|
||||
switch parts[0] {
|
||||
case "SCRAM-FIRST":
|
||||
// Parse ClientFirst JSON
|
||||
var clientFirst ClientFirst
|
||||
if err := json.Unmarshal([]byte(parts[1]), &clientFirst); err != nil {
|
||||
conn.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
|
||||
return false, nil, fmt.Errorf("invalid JSON")
|
||||
}
|
||||
|
||||
// Process with SCRAM server
|
||||
serverFirst, err := sph.manager.HandleClientFirst(&clientFirst)
|
||||
if err != nil {
|
||||
// Still send challenge to prevent user enumeration
|
||||
response, _ := json.Marshal(serverFirst)
|
||||
conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
// Send ServerFirst challenge
|
||||
response, _ := json.Marshal(serverFirst)
|
||||
conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
|
||||
return false, nil, nil // Not authenticated yet
|
||||
|
||||
case "SCRAM-PROOF":
|
||||
// Parse ClientFinal JSON
|
||||
var clientFinal ClientFinal
|
||||
if err := json.Unmarshal([]byte(parts[1]), &clientFinal); err != nil {
|
||||
conn.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
|
||||
return false, nil, fmt.Errorf("invalid JSON")
|
||||
}
|
||||
|
||||
// Verify proof
|
||||
serverFinal, err := sph.manager.HandleClientFinal(&clientFinal)
|
||||
if err != nil {
|
||||
conn.AsyncWrite([]byte("SCRAM-FAIL Authentication failed\n"), nil)
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
// Authentication successful
|
||||
session = &Session{
|
||||
ID: serverFinal.SessionID,
|
||||
Method: "scram-sha-256",
|
||||
RemoteAddr: conn.RemoteAddr().String(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Send ServerFinal with signature
|
||||
response, _ := json.Marshal(serverFinal)
|
||||
conn.AsyncWrite([]byte(fmt.Sprintf("SCRAM-OK %s\n", response)), nil)
|
||||
|
||||
return true, session, nil
|
||||
|
||||
default:
|
||||
conn.AsyncWrite([]byte("SCRAM-FAIL Unknown command\n"), nil)
|
||||
return false, nil, fmt.Errorf("unknown command: %s", parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
// FormatSCRAMRequest formats a SCRAM protocol message for TCP
|
||||
func FormatSCRAMRequest(command string, data interface{}) (string, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal %s: %w", command, err)
|
||||
}
|
||||
return fmt.Sprintf("%s %s\n", command, jsonData), nil
|
||||
}
|
||||
|
||||
// ParseSCRAMResponse parses a SCRAM protocol response from TCP
|
||||
func ParseSCRAMResponse(response string) (command string, data string, err error) {
|
||||
response = strings.TrimSpace(response)
|
||||
parts := strings.SplitN(response, " ", 2)
|
||||
if len(parts) < 1 {
|
||||
return "", "", fmt.Errorf("empty response")
|
||||
}
|
||||
|
||||
command = parts[0]
|
||||
if len(parts) > 1 {
|
||||
data = parts[1]
|
||||
}
|
||||
return command, data, nil
|
||||
}
|
||||
@ -1,5 +1,5 @@
|
||||
// FILE: src/internal/scram/server.go
|
||||
package scram
|
||||
// FILE: src/internal/auth/scram_server.go
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@ -9,14 +9,17 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/core"
|
||||
)
|
||||
|
||||
// Server handles SCRAM authentication
|
||||
type Server struct {
|
||||
type ScramServer struct {
|
||||
credentials map[string]*Credential
|
||||
handshakes map[string]*HandshakeState
|
||||
mu sync.RWMutex
|
||||
|
||||
// TODO: configurability useful? to be included in config or refactor to use core.const directly for simplicity
|
||||
// Default Argon2 params for new registrations
|
||||
DefaultTime uint32
|
||||
DefaultMemory uint32
|
||||
@ -29,32 +32,30 @@ type HandshakeState struct {
|
||||
ClientNonce string
|
||||
ServerNonce string
|
||||
FullNonce string
|
||||
AuthMessage string
|
||||
Credential *Credential
|
||||
CreatedAt time.Time
|
||||
ClientProof []byte
|
||||
}
|
||||
|
||||
// NewServer creates SCRAM server
|
||||
func NewServer() *Server {
|
||||
return &Server{
|
||||
// NewScramServer creates SCRAM server
|
||||
func NewScramServer() *ScramServer {
|
||||
return &ScramServer{
|
||||
credentials: make(map[string]*Credential),
|
||||
handshakes: make(map[string]*HandshakeState),
|
||||
DefaultTime: 3,
|
||||
DefaultMemory: 64 * 1024,
|
||||
DefaultThreads: 4,
|
||||
DefaultTime: core.Argon2Time,
|
||||
DefaultMemory: core.Argon2Memory,
|
||||
DefaultThreads: core.Argon2Threads,
|
||||
}
|
||||
}
|
||||
|
||||
// AddCredential registers user credential
|
||||
func (s *Server) AddCredential(cred *Credential) {
|
||||
func (s *ScramServer) AddCredential(cred *Credential) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.credentials[cred.Username] = cred
|
||||
}
|
||||
|
||||
// HandleClientFirst processes initial auth request
|
||||
func (s *Server) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
|
||||
func (s *ScramServer) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@ -103,7 +104,7 @@ func (s *Server) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
|
||||
}
|
||||
|
||||
// HandleClientFinal verifies client proof
|
||||
func (s *Server) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
|
||||
func (s *ScramServer) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@ -157,7 +158,7 @@ func (s *Server) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) cleanupHandshakes() {
|
||||
func (s *ScramServer) cleanupHandshakes() {
|
||||
cutoff := time.Now().Add(-60 * time.Second)
|
||||
for nonce, state := range s.handshakes {
|
||||
if state.CreatedAt.Before(cutoff) {
|
||||
@ -171,9 +172,3 @@ func generateNonce() string {
|
||||
rand.Read(b)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func generateSessionID() string {
|
||||
b := make([]byte, 24)
|
||||
rand.Read(b)
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
}
|
||||
@ -1,81 +0,0 @@
|
||||
// FILE: logwisp/src/internal/config/auth.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type AuthConfig struct {
|
||||
// Authentication type: "none", "basic", "scram", "bearer", "mtls"
|
||||
Type string `toml:"type"`
|
||||
|
||||
BasicAuth *BasicAuthConfig `toml:"basic_auth"`
|
||||
ScramAuth *ScramAuthConfig `toml:"scram_auth"`
|
||||
BearerAuth *BearerAuthConfig `toml:"bearer_auth"`
|
||||
}
|
||||
|
||||
type BasicAuthConfig struct {
|
||||
Users []BasicAuthUser `toml:"users"`
|
||||
Realm string `toml:"realm"`
|
||||
}
|
||||
|
||||
type BasicAuthUser struct {
|
||||
Username string `toml:"username"`
|
||||
PasswordHash string `toml:"password_hash"` // Argon2
|
||||
}
|
||||
|
||||
type ScramAuthConfig struct {
|
||||
Users []ScramUser `toml:"users"`
|
||||
}
|
||||
|
||||
type ScramUser struct {
|
||||
Username string `toml:"username"`
|
||||
StoredKey string `toml:"stored_key"` // base64
|
||||
ServerKey string `toml:"server_key"` // base64
|
||||
Salt string `toml:"salt"` // base64
|
||||
ArgonTime uint32 `toml:"argon_time"`
|
||||
ArgonMemory uint32 `toml:"argon_memory"`
|
||||
ArgonThreads uint8 `toml:"argon_threads"`
|
||||
}
|
||||
|
||||
type BearerAuthConfig struct {
|
||||
// Static tokens
|
||||
Tokens []string `toml:"tokens"`
|
||||
|
||||
// TODO: Maybe future development
|
||||
// // JWT validation
|
||||
// JWT *JWTConfig `toml:"jwt"`
|
||||
}
|
||||
|
||||
// TODO: Maybe future development
|
||||
// type JWTConfig struct {
|
||||
// JWKSURL string `toml:"jwks_url"`
|
||||
// SigningKey string `toml:"signing_key"`
|
||||
// Issuer string `toml:"issuer"`
|
||||
// Audience string `toml:"audience"`
|
||||
// }
|
||||
|
||||
func validateAuth(pipelineName string, auth *AuthConfig) error {
|
||||
if auth == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
validTypes := map[string]bool{"none": true, "basic": true, "scram": true, "bearer": true, "mtls": true}
|
||||
if !validTypes[auth.Type] {
|
||||
return fmt.Errorf("pipeline '%s': invalid auth type: %s", pipelineName, auth.Type)
|
||||
}
|
||||
|
||||
if auth.Type == "basic" && auth.BasicAuth == nil {
|
||||
return fmt.Errorf("pipeline '%s': basic auth type specified but config missing", pipelineName)
|
||||
}
|
||||
|
||||
if auth.Type == "scram" && auth.ScramAuth == nil {
|
||||
return fmt.Errorf("pipeline '%s': scram auth type specified but config missing", pipelineName)
|
||||
}
|
||||
|
||||
if auth.Type == "bearer" && auth.BearerAuth == nil {
|
||||
return fmt.Errorf("pipeline '%s': bearer auth type specified but config missing", pipelineName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,6 +1,8 @@
|
||||
// FILE: logwisp/src/internal/config/config.go
|
||||
package config
|
||||
|
||||
// --- LogWisp Configuration Options ---
|
||||
|
||||
type Config struct {
|
||||
// Top-level flags for application control
|
||||
Background bool `toml:"background"`
|
||||
@ -10,7 +12,6 @@ type Config struct {
|
||||
// Runtime behavior flags
|
||||
DisableStatusReporter bool `toml:"disable_status_reporter"`
|
||||
ConfigAutoReload bool `toml:"config_auto_reload"`
|
||||
ConfigSaveOnExit bool `toml:"config_save_on_exit"`
|
||||
|
||||
// Internal flag indicating demonized child process
|
||||
BackgroundDaemon bool `toml:"background-daemon"`
|
||||
@ -22,3 +23,364 @@ type Config struct {
|
||||
Logging *LogConfig `toml:"logging"`
|
||||
Pipelines []PipelineConfig `toml:"pipelines"`
|
||||
}
|
||||
|
||||
// --- Logging Options ---
|
||||
|
||||
// Represents logging configuration for LogWisp
|
||||
type LogConfig struct {
|
||||
// Output mode: "file", "stdout", "stderr", "split", "all", "none"
|
||||
Output string `toml:"output"`
|
||||
|
||||
// Log level: "debug", "info", "warn", "error"
|
||||
Level string `toml:"level"`
|
||||
|
||||
// File output settings (when Output includes "file" or "all")
|
||||
File *LogFileConfig `toml:"file"`
|
||||
|
||||
// Console output settings
|
||||
Console *LogConsoleConfig `toml:"console"`
|
||||
}
|
||||
|
||||
type LogFileConfig struct {
|
||||
// Directory for log files
|
||||
Directory string `toml:"directory"`
|
||||
|
||||
// Base name for log files
|
||||
Name string `toml:"name"`
|
||||
|
||||
// Maximum size per log file in MB
|
||||
MaxSizeMB int64 `toml:"max_size_mb"`
|
||||
|
||||
// Maximum total size of all logs in MB
|
||||
MaxTotalSizeMB int64 `toml:"max_total_size_mb"`
|
||||
|
||||
// Log retention in hours (0 = disabled)
|
||||
RetentionHours float64 `toml:"retention_hours"`
|
||||
}
|
||||
|
||||
type LogConsoleConfig struct {
|
||||
// Target for console output: "stdout", "stderr", "split"
|
||||
// "split": info/debug to stdout, warn/error to stderr
|
||||
Target string `toml:"target"`
|
||||
|
||||
// Format: "txt" or "json"
|
||||
Format string `toml:"format"`
|
||||
}
|
||||
|
||||
// --- Pipeline Options ---
|
||||
|
||||
type PipelineConfig struct {
|
||||
Name string `toml:"name"`
|
||||
Sources []SourceConfig `toml:"sources"`
|
||||
RateLimit *RateLimitConfig `toml:"rate_limit"`
|
||||
Filters []FilterConfig `toml:"filters"`
|
||||
Format *FormatConfig `toml:"format"`
|
||||
|
||||
Sinks []SinkConfig `toml:"sinks"`
|
||||
// Auth *ServerAuthConfig `toml:"auth"` // Global auth for pipeline
|
||||
}
|
||||
|
||||
// Common configuration structs used across components
|
||||
|
||||
type NetLimitConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
MaxConnections int64 `toml:"max_connections"`
|
||||
RequestsPerSecond float64 `toml:"requests_per_second"`
|
||||
BurstSize int64 `toml:"burst_size"`
|
||||
ResponseMessage string `toml:"response_message"`
|
||||
ResponseCode int64 `toml:"response_code"` // Default: 429
|
||||
MaxConnectionsPerIP int64 `toml:"max_connections_per_ip"`
|
||||
MaxConnectionsPerUser int64 `toml:"max_connections_per_user"`
|
||||
MaxConnectionsPerToken int64 `toml:"max_connections_per_token"`
|
||||
MaxConnectionsTotal int64 `toml:"max_connections_total"`
|
||||
IPWhitelist []string `toml:"ip_whitelist"`
|
||||
IPBlacklist []string `toml:"ip_blacklist"`
|
||||
}
|
||||
|
||||
type TLSConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
CertFile string `toml:"cert_file"`
|
||||
KeyFile string `toml:"key_file"`
|
||||
CAFile string `toml:"ca_file"`
|
||||
ServerName string `toml:"server_name"` // for client verification
|
||||
SkipVerify bool `toml:"skip_verify"`
|
||||
|
||||
// Client certificate authentication
|
||||
ClientAuth bool `toml:"client_auth"`
|
||||
ClientCAFile string `toml:"client_ca_file"`
|
||||
VerifyClientCert bool `toml:"verify_client_cert"`
|
||||
|
||||
// TLS version constraints
|
||||
MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3"
|
||||
MaxVersion string `toml:"max_version"`
|
||||
|
||||
// Cipher suites (comma-separated list)
|
||||
CipherSuites string `toml:"cipher_suites"`
|
||||
}
|
||||
|
||||
type HeartbeatConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
Interval int64 `toml:"interval_ms"`
|
||||
IncludeTimestamp bool `toml:"include_timestamp"`
|
||||
IncludeStats bool `toml:"include_stats"`
|
||||
Format string `toml:"format"`
|
||||
}
|
||||
|
||||
type ClientAuthConfig struct {
|
||||
Type string `toml:"type"` // "none", "basic", "token", "scram"
|
||||
Username string `toml:"username"`
|
||||
Password string `toml:"password"`
|
||||
Token string `toml:"token"`
|
||||
}
|
||||
|
||||
// --- Source Options ---
|
||||
|
||||
type SourceConfig struct {
|
||||
Type string `toml:"type"`
|
||||
|
||||
// Polymorphic - only one populated based on type
|
||||
Directory *DirectorySourceOptions `toml:"directory,omitempty"`
|
||||
Stdin *StdinSourceOptions `toml:"stdin,omitempty"`
|
||||
HTTP *HTTPSourceOptions `toml:"http,omitempty"`
|
||||
TCP *TCPSourceOptions `toml:"tcp,omitempty"`
|
||||
}
|
||||
|
||||
type DirectorySourceOptions struct {
|
||||
Path string `toml:"path"`
|
||||
Pattern string `toml:"pattern"` // glob pattern
|
||||
CheckIntervalMS int64 `toml:"check_interval_ms"`
|
||||
Recursive bool `toml:"recursive"`
|
||||
FollowSymlinks bool `toml:"follow_symlinks"`
|
||||
DeleteAfterRead bool `toml:"delete_after_read"`
|
||||
MoveToDirectory string `toml:"move_to_directory"` // move after processing
|
||||
}
|
||||
|
||||
type StdinSourceOptions struct {
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
}
|
||||
|
||||
type HTTPSourceOptions struct {
|
||||
Host string `toml:"host"`
|
||||
Port int64 `toml:"port"`
|
||||
IngestPath string `toml:"ingest_path"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
MaxRequestBodySize int64 `toml:"max_body_size"`
|
||||
ReadTimeout int64 `toml:"read_timeout_ms"`
|
||||
WriteTimeout int64 `toml:"write_timeout_ms"`
|
||||
NetLimit *NetLimitConfig `toml:"net_limit"`
|
||||
TLS *TLSConfig `toml:"tls"`
|
||||
Auth *ServerAuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
type TCPSourceOptions struct {
|
||||
Host string `toml:"host"`
|
||||
Port int64 `toml:"port"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
ReadTimeout int64 `toml:"read_timeout_ms"`
|
||||
KeepAlive bool `toml:"keep_alive"`
|
||||
KeepAlivePeriod int64 `toml:"keep_alive_period_ms"`
|
||||
NetLimit *NetLimitConfig `toml:"net_limit"`
|
||||
Auth *ServerAuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
// --- Sink Options ---
|
||||
|
||||
type SinkConfig struct {
|
||||
Type string `toml:"type"`
|
||||
|
||||
// Polymorphic - only one populated based on type
|
||||
Console *ConsoleSinkOptions `toml:"console,omitempty"`
|
||||
File *FileSinkOptions `toml:"file,omitempty"`
|
||||
HTTP *HTTPSinkOptions `toml:"http,omitempty"`
|
||||
TCP *TCPSinkOptions `toml:"tcp,omitempty"`
|
||||
HTTPClient *HTTPClientSinkOptions `toml:"http_client,omitempty"`
|
||||
TCPClient *TCPClientSinkOptions `toml:"tcp_client,omitempty"`
|
||||
}
|
||||
|
||||
type ConsoleSinkOptions struct {
|
||||
Target string `toml:"target"` // "stdout", "stderr", "split"
|
||||
Colorize bool `toml:"colorize"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
}
|
||||
|
||||
type FileSinkOptions struct {
|
||||
Directory string `toml:"directory"`
|
||||
Name string `toml:"name"`
|
||||
// Extension string `toml:"extension"`
|
||||
MaxSizeMB int64 `toml:"max_size_mb"`
|
||||
MaxTotalSizeMB int64 `toml:"max_total_size_mb"`
|
||||
MinDiskFreeMB int64 `toml:"min_disk_free_mb"`
|
||||
RetentionHours float64 `toml:"retention_hours"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
FlushInterval int64 `toml:"flush_interval_ms"`
|
||||
}
|
||||
|
||||
type HTTPSinkOptions struct {
|
||||
Host string `toml:"host"`
|
||||
Port int64 `toml:"port"`
|
||||
StreamPath string `toml:"stream_path"`
|
||||
StatusPath string `toml:"status_path"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
WriteTimeout int64 `toml:"write_timeout_ms"`
|
||||
Heartbeat *HeartbeatConfig `toml:"heartbeat"`
|
||||
NetLimit *NetLimitConfig `toml:"net_limit"`
|
||||
TLS *TLSConfig `toml:"tls"`
|
||||
Auth *ServerAuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
type TCPSinkOptions struct {
|
||||
Host string `toml:"host"`
|
||||
Port int64 `toml:"port"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
WriteTimeout int64 `toml:"write_timeout_ms"`
|
||||
KeepAlive bool `toml:"keep_alive"`
|
||||
KeepAlivePeriod int64 `toml:"keep_alive_period_ms"`
|
||||
Heartbeat *HeartbeatConfig `toml:"heartbeat"`
|
||||
NetLimit *NetLimitConfig `toml:"net_limit"`
|
||||
Auth *ServerAuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
type HTTPClientSinkOptions struct {
|
||||
URL string `toml:"url"`
|
||||
Headers map[string]string `toml:"headers"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
BatchSize int64 `toml:"batch_size"`
|
||||
BatchDelayMS int64 `toml:"batch_delay_ms"`
|
||||
Timeout int64 `toml:"timeout_seconds"`
|
||||
MaxRetries int64 `toml:"max_retries"`
|
||||
RetryDelayMS int64 `toml:"retry_delay_ms"`
|
||||
RetryBackoff float64 `toml:"retry_backoff"`
|
||||
InsecureSkipVerify bool `toml:"insecure_skip_verify"`
|
||||
TLS *TLSConfig `toml:"tls"`
|
||||
Auth *ClientAuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
type TCPClientSinkOptions struct {
|
||||
Host string `toml:"host"`
|
||||
Port int64 `toml:"port"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
DialTimeout int64 `toml:"dial_timeout_seconds"`
|
||||
WriteTimeout int64 `toml:"write_timeout_seconds"`
|
||||
ReadTimeout int64 `toml:"read_timeout_seconds"`
|
||||
KeepAlive int64 `toml:"keep_alive_seconds"`
|
||||
ReconnectDelayMS int64 `toml:"reconnect_delay_ms"`
|
||||
MaxReconnectDelayMS int64 `toml:"max_reconnect_delay_ms"`
|
||||
ReconnectBackoff float64 `toml:"reconnect_backoff"`
|
||||
Auth *ClientAuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
// --- Rate Limit Options ---
|
||||
|
||||
// Defines the action to take when a rate limit is exceeded.
|
||||
type RateLimitPolicy int
|
||||
|
||||
const (
|
||||
// PolicyPass allows all logs through, effectively disabling the limiter.
|
||||
PolicyPass RateLimitPolicy = iota
|
||||
// PolicyDrop drops logs that exceed the rate limit.
|
||||
PolicyDrop
|
||||
)
|
||||
|
||||
// Defines the configuration for pipeline-level rate limiting.
|
||||
type RateLimitConfig struct {
|
||||
// Rate is the number of log entries allowed per second. Default: 0 (disabled).
|
||||
Rate float64 `toml:"rate"`
|
||||
// Burst is the maximum number of log entries that can be sent in a short burst. Defaults to the Rate.
|
||||
Burst float64 `toml:"burst"`
|
||||
// Policy defines the action to take when the limit is exceeded. "pass" or "drop".
|
||||
Policy string `toml:"policy"`
|
||||
// MaxEntrySizeBytes is the maximum allowed size for a single log entry. 0 = no limit.
|
||||
MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"`
|
||||
}
|
||||
|
||||
// --- Filter Options ---
|
||||
|
||||
// Represents the filter type
|
||||
type FilterType string
|
||||
|
||||
const (
|
||||
FilterTypeInclude FilterType = "include" // Whitelist - only matching logs pass
|
||||
FilterTypeExclude FilterType = "exclude" // Blacklist - matching logs are dropped
|
||||
)
|
||||
|
||||
// Represents how multiple patterns are combined
|
||||
type FilterLogic string
|
||||
|
||||
const (
|
||||
FilterLogicOr FilterLogic = "or" // Match any pattern
|
||||
FilterLogicAnd FilterLogic = "and" // Match all patterns
|
||||
)
|
||||
|
||||
// Represents filter configuration
|
||||
type FilterConfig struct {
|
||||
Type FilterType `toml:"type"`
|
||||
Logic FilterLogic `toml:"logic"`
|
||||
Patterns []string `toml:"patterns"`
|
||||
}
|
||||
|
||||
// --- Formatter Options ---
|
||||
|
||||
type FormatConfig struct {
|
||||
// Format configuration - polymorphic like sources/sinks
|
||||
Type string `toml:"type"` // "json", "text", "raw"
|
||||
|
||||
// Only one will be populated based on format type
|
||||
JSONFormatOptions *JSONFormatterOptions `toml:"json_format,omitempty"`
|
||||
TextFormatOptions *TextFormatterOptions `toml:"text_format,omitempty"`
|
||||
RawFormatOptions *RawFormatterOptions `toml:"raw_format,omitempty"`
|
||||
}
|
||||
|
||||
type JSONFormatterOptions struct {
|
||||
Pretty bool `toml:"pretty"`
|
||||
TimestampField string `toml:"timestamp_field"`
|
||||
LevelField string `toml:"level_field"`
|
||||
MessageField string `toml:"message_field"`
|
||||
SourceField string `toml:"source_field"`
|
||||
}
|
||||
|
||||
type TextFormatterOptions struct {
|
||||
Template string `toml:"template"`
|
||||
TimestampFormat string `toml:"timestamp_format"`
|
||||
}
|
||||
|
||||
type RawFormatterOptions struct {
|
||||
AddNewLine bool `toml:"add_new_line"`
|
||||
}
|
||||
|
||||
// --- Server-side Auth (for sources) ---
|
||||
|
||||
type BasicAuthConfig struct {
|
||||
Users []BasicAuthUser `toml:"users"`
|
||||
Realm string `toml:"realm"`
|
||||
}
|
||||
|
||||
type BasicAuthUser struct {
|
||||
Username string `toml:"username"`
|
||||
PasswordHash string `toml:"password_hash"` // Argon2
|
||||
}
|
||||
|
||||
type ScramAuthConfig struct {
|
||||
Users []ScramUser `toml:"users"`
|
||||
}
|
||||
|
||||
type ScramUser struct {
|
||||
Username string `toml:"username"`
|
||||
StoredKey string `toml:"stored_key"` // base64
|
||||
ServerKey string `toml:"server_key"` // base64
|
||||
Salt string `toml:"salt"` // base64
|
||||
ArgonTime uint32 `toml:"argon_time"`
|
||||
ArgonMemory uint32 `toml:"argon_memory"`
|
||||
ArgonThreads uint8 `toml:"argon_threads"`
|
||||
}
|
||||
|
||||
type TokenAuthConfig struct {
|
||||
Tokens []string `toml:"tokens"`
|
||||
}
|
||||
|
||||
// Server auth wrapper (for sources accepting connections)
|
||||
type ServerAuthConfig struct {
|
||||
Type string `toml:"type"` // "none", "basic", "token", "scram"
|
||||
Basic *BasicAuthConfig `toml:"basic,omitempty"`
|
||||
Token *TokenAuthConfig `toml:"token,omitempty"`
|
||||
Scram *ScramAuthConfig `toml:"scram,omitempty"`
|
||||
}
|
||||
@ -1,65 +0,0 @@
|
||||
// FILE: logwisp/src/internal/config/filter.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// Represents the filter type
|
||||
type FilterType string
|
||||
|
||||
const (
|
||||
FilterTypeInclude FilterType = "include" // Whitelist - only matching logs pass
|
||||
FilterTypeExclude FilterType = "exclude" // Blacklist - matching logs are dropped
|
||||
)
|
||||
|
||||
// Represents how multiple patterns are combined
|
||||
type FilterLogic string
|
||||
|
||||
const (
|
||||
FilterLogicOr FilterLogic = "or" // Match any pattern
|
||||
FilterLogicAnd FilterLogic = "and" // Match all patterns
|
||||
)
|
||||
|
||||
// Represents filter configuration
|
||||
type FilterConfig struct {
|
||||
Type FilterType `toml:"type"`
|
||||
Logic FilterLogic `toml:"logic"`
|
||||
Patterns []string `toml:"patterns"`
|
||||
}
|
||||
|
||||
func validateFilter(pipelineName string, filterIndex int, cfg *FilterConfig) error {
|
||||
// Validate filter type
|
||||
switch cfg.Type {
|
||||
case FilterTypeInclude, FilterTypeExclude, "":
|
||||
// Valid types
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s' filter[%d]: invalid type '%s' (must be 'include' or 'exclude')",
|
||||
pipelineName, filterIndex, cfg.Type)
|
||||
}
|
||||
|
||||
// Validate filter logic
|
||||
switch cfg.Logic {
|
||||
case FilterLogicOr, FilterLogicAnd, "":
|
||||
// Valid logic
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s' filter[%d]: invalid logic '%s' (must be 'or' or 'and')",
|
||||
pipelineName, filterIndex, cfg.Logic)
|
||||
}
|
||||
|
||||
// Empty patterns is valid - passes everything
|
||||
if len(cfg.Patterns) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate regex patterns
|
||||
for i, pattern := range cfg.Patterns {
|
||||
if _, err := regexp.Compile(pattern); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' filter[%d] pattern[%d] '%s': invalid regex: %w",
|
||||
pipelineName, filterIndex, i, pattern, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,58 +0,0 @@
|
||||
// FILE: logwisp/src/internal/config/ratelimit.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Defines the action to take when a rate limit is exceeded.
|
||||
type RateLimitPolicy int
|
||||
|
||||
const (
|
||||
// PolicyPass allows all logs through, effectively disabling the limiter.
|
||||
PolicyPass RateLimitPolicy = iota
|
||||
// PolicyDrop drops logs that exceed the rate limit.
|
||||
PolicyDrop
|
||||
)
|
||||
|
||||
// Defines the configuration for pipeline-level rate limiting.
|
||||
type RateLimitConfig struct {
|
||||
// Rate is the number of log entries allowed per second. Default: 0 (disabled).
|
||||
Rate float64 `toml:"rate"`
|
||||
// Burst is the maximum number of log entries that can be sent in a short burst. Defaults to the Rate.
|
||||
Burst float64 `toml:"burst"`
|
||||
// Policy defines the action to take when the limit is exceeded. "pass" or "drop".
|
||||
Policy string `toml:"policy"`
|
||||
// MaxEntrySizeBytes is the maximum allowed size for a single log entry. 0 = no limit.
|
||||
MaxEntrySizeBytes int64 `toml:"max_entry_size_bytes"`
|
||||
}
|
||||
|
||||
func validateRateLimit(pipelineName string, cfg *RateLimitConfig) error {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.Rate < 0 {
|
||||
return fmt.Errorf("pipeline '%s': rate limit rate cannot be negative", pipelineName)
|
||||
}
|
||||
|
||||
if cfg.Burst < 0 {
|
||||
return fmt.Errorf("pipeline '%s': rate limit burst cannot be negative", pipelineName)
|
||||
}
|
||||
|
||||
if cfg.MaxEntrySizeBytes < 0 {
|
||||
return fmt.Errorf("pipeline '%s': max entry size bytes cannot be negative", pipelineName)
|
||||
}
|
||||
|
||||
// Validate policy
|
||||
switch strings.ToLower(cfg.Policy) {
|
||||
case "", "pass", "drop":
|
||||
// Valid policies
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s': invalid rate limit policy '%s' (must be 'pass' or 'drop')",
|
||||
pipelineName, cfg.Policy)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -11,6 +11,8 @@ import (
|
||||
lconfig "github.com/lixenwraith/config"
|
||||
)
|
||||
|
||||
var configManager *lconfig.Config
|
||||
|
||||
func defaults() *Config {
|
||||
return &Config{
|
||||
// Top-level flag defaults
|
||||
@ -21,41 +23,46 @@ func defaults() *Config {
|
||||
// Runtime behavior defaults
|
||||
DisableStatusReporter: false,
|
||||
ConfigAutoReload: false,
|
||||
ConfigSaveOnExit: false,
|
||||
|
||||
// Child process indicator
|
||||
BackgroundDaemon: false,
|
||||
|
||||
// Existing defaults
|
||||
Logging: DefaultLogConfig(),
|
||||
Logging: &LogConfig{
|
||||
Output: "stdout",
|
||||
Level: "info",
|
||||
File: &LogFileConfig{
|
||||
Directory: "./log",
|
||||
Name: "logwisp",
|
||||
MaxSizeMB: 100,
|
||||
MaxTotalSizeMB: 1000,
|
||||
RetentionHours: 168, // 7 days
|
||||
},
|
||||
Console: &LogConsoleConfig{
|
||||
Target: "stdout",
|
||||
Format: "txt",
|
||||
},
|
||||
},
|
||||
Pipelines: []PipelineConfig{
|
||||
{
|
||||
Name: "default",
|
||||
Sources: []SourceConfig{
|
||||
{
|
||||
Type: "directory",
|
||||
Options: map[string]any{
|
||||
"path": "./",
|
||||
"pattern": "*.log",
|
||||
"check_interval_ms": int64(100),
|
||||
Directory: &DirectorySourceOptions{
|
||||
Path: "./",
|
||||
Pattern: "*.log",
|
||||
CheckIntervalMS: int64(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
Sinks: []SinkConfig{
|
||||
{
|
||||
Type: "http",
|
||||
Options: map[string]any{
|
||||
"port": int64(8080),
|
||||
"buffer_size": int64(1000),
|
||||
"stream_path": "/stream",
|
||||
"status_path": "/status",
|
||||
"heartbeat": map[string]any{
|
||||
"enabled": true,
|
||||
"interval_seconds": int64(30),
|
||||
"include_timestamp": true,
|
||||
"include_stats": false,
|
||||
"format": "comment",
|
||||
},
|
||||
Type: "console",
|
||||
Console: &ConsoleSinkOptions{
|
||||
Target: "stdout",
|
||||
Colorize: false,
|
||||
BufferSize: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -68,18 +75,30 @@ func defaults() *Config {
|
||||
func Load(args []string) (*Config, error) {
|
||||
configPath, isExplicit := resolveConfigPath(args)
|
||||
// Build configuration with all sources
|
||||
|
||||
// Create target config instance that will be populated
|
||||
finalConfig := &Config{}
|
||||
|
||||
// The builder now handles loading, populating the target struct, and validation
|
||||
cfg, err := lconfig.NewBuilder().
|
||||
WithDefaults(defaults()).
|
||||
WithEnvPrefix("LOGWISP_").
|
||||
WithEnvTransform(customEnvTransform).
|
||||
WithArgs(args).
|
||||
WithFile(configPath).
|
||||
WithTarget(finalConfig). // Typed target struct
|
||||
WithDefaults(defaults()). // Default values
|
||||
WithSources(
|
||||
lconfig.SourceCLI,
|
||||
lconfig.SourceEnv,
|
||||
lconfig.SourceFile,
|
||||
lconfig.SourceDefault,
|
||||
).
|
||||
WithEnvTransform(customEnvTransform). // Convert '.' to '_' in env separation
|
||||
WithEnvPrefix("LOGWISP_"). // Environment variable prefix
|
||||
WithArgs(args). // Command-line arguments
|
||||
WithFile(configPath). // TOML config file
|
||||
WithFileFormat("toml"). // Explicit format
|
||||
WithTypedValidator(validateConfig). // Centralized validation
|
||||
WithSecurityOptions(lconfig.SecurityOptions{
|
||||
PreventPathTraversal: true,
|
||||
MaxFileSize: 10 * 1024 * 1024, // 10MB max config
|
||||
}).
|
||||
Build()
|
||||
|
||||
if err != nil {
|
||||
@ -88,42 +107,28 @@ func Load(args []string) (*Config, error) {
|
||||
if isExplicit {
|
||||
return nil, fmt.Errorf("config file not found: %s", configPath)
|
||||
}
|
||||
// If the default config file is not found, it's not an error
|
||||
// If the default config file is not found, it's not an error, default/cli/env will be used
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to load config: %w", err)
|
||||
return nil, fmt.Errorf("failed to load or validate config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Scan into final config struct - using new interface
|
||||
finalConfig := &Config{}
|
||||
if err := cfg.Scan(finalConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan config: %w", err)
|
||||
}
|
||||
|
||||
// Set config file path if it exists
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
// Store the config file path for hot reload
|
||||
finalConfig.ConfigFile = configPath
|
||||
|
||||
// Store the manager for hot reload
|
||||
if cfg != nil {
|
||||
configManager = cfg
|
||||
}
|
||||
|
||||
// Ensure critical fields are not nil
|
||||
if finalConfig.Logging == nil {
|
||||
finalConfig.Logging = DefaultLogConfig()
|
||||
}
|
||||
|
||||
// Apply console target overrides if needed
|
||||
if err := applyConsoleTargetOverrides(finalConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply console target overrides: %w", err)
|
||||
}
|
||||
|
||||
// Validate configuration
|
||||
return finalConfig, finalConfig.validate()
|
||||
return finalConfig, nil
|
||||
}
|
||||
|
||||
// Returns the configuration file path
|
||||
func resolveConfigPath(args []string) (path string, isExplicit bool) {
|
||||
// 1. Check for --config flag in command-line arguments (highest precedence)
|
||||
for i, arg := range args {
|
||||
if (arg == "--config" || arg == "-c") && i+1 < len(args) {
|
||||
if arg == "-c" {
|
||||
return args[i+1], true
|
||||
}
|
||||
if strings.HasPrefix(arg, "--config=") {
|
||||
@ -161,37 +166,3 @@ func customEnvTransform(path string) string {
|
||||
// env = "LOGWISP_" + env // already added by WithEnvPrefix
|
||||
return env
|
||||
}
|
||||
|
||||
// Centralizes console target configuration
|
||||
func applyConsoleTargetOverrides(cfg *Config) error {
|
||||
// Check environment variable for console target override
|
||||
consoleTarget := os.Getenv("LOGWISP_CONSOLE_TARGET")
|
||||
if consoleTarget == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate console target value
|
||||
validTargets := map[string]bool{
|
||||
"stdout": true,
|
||||
"stderr": true,
|
||||
"split": true,
|
||||
}
|
||||
if !validTargets[consoleTarget] {
|
||||
return fmt.Errorf("invalid LOGWISP_CONSOLE_TARGET value: %s", consoleTarget)
|
||||
}
|
||||
|
||||
// Apply to console sinks
|
||||
for i, pipeline := range cfg.Pipelines {
|
||||
for j, sink := range pipeline.Sinks {
|
||||
if sink.Type == "console" {
|
||||
if sink.Options == nil {
|
||||
cfg.Pipelines[i].Sinks[j].Options = make(map[string]any)
|
||||
}
|
||||
// Set target for split mode handling
|
||||
cfg.Pipelines[i].Sinks[j].Options["target"] = consoleTarget
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,99 +0,0 @@
|
||||
// FILE: logwisp/src/internal/config/logging.go
|
||||
package config
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Represents logging configuration for LogWisp
|
||||
type LogConfig struct {
|
||||
// Output mode: "file", "stdout", "stderr", "split", "all", "none"
|
||||
Output string `toml:"output"`
|
||||
|
||||
// Log level: "debug", "info", "warn", "error"
|
||||
Level string `toml:"level"`
|
||||
|
||||
// File output settings (when Output includes "file" or "all")
|
||||
File *LogFileConfig `toml:"file"`
|
||||
|
||||
// Console output settings
|
||||
Console *LogConsoleConfig `toml:"console"`
|
||||
}
|
||||
|
||||
type LogFileConfig struct {
|
||||
// Directory for log files
|
||||
Directory string `toml:"directory"`
|
||||
|
||||
// Base name for log files
|
||||
Name string `toml:"name"`
|
||||
|
||||
// Maximum size per log file in MB
|
||||
MaxSizeMB int64 `toml:"max_size_mb"`
|
||||
|
||||
// Maximum total size of all logs in MB
|
||||
MaxTotalSizeMB int64 `toml:"max_total_size_mb"`
|
||||
|
||||
// Log retention in hours (0 = disabled)
|
||||
RetentionHours float64 `toml:"retention_hours"`
|
||||
}
|
||||
|
||||
type LogConsoleConfig struct {
|
||||
// Target for console output: "stdout", "stderr", "split"
|
||||
// "split": info/debug to stdout, warn/error to stderr
|
||||
Target string `toml:"target"`
|
||||
|
||||
// Format: "txt" or "json"
|
||||
Format string `toml:"format"`
|
||||
}
|
||||
|
||||
// Returns sensible logging defaults
|
||||
func DefaultLogConfig() *LogConfig {
|
||||
return &LogConfig{
|
||||
Output: "stdout",
|
||||
Level: "info",
|
||||
File: &LogFileConfig{
|
||||
Directory: "./log",
|
||||
Name: "logwisp",
|
||||
MaxSizeMB: 100,
|
||||
MaxTotalSizeMB: 1000,
|
||||
RetentionHours: 168, // 7 days
|
||||
},
|
||||
Console: &LogConsoleConfig{
|
||||
Target: "stdout",
|
||||
Format: "txt",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func validateLogConfig(cfg *LogConfig) error {
|
||||
validOutputs := map[string]bool{
|
||||
"file": true, "stdout": true, "stderr": true,
|
||||
"split": true, "all": true, "none": true,
|
||||
}
|
||||
if !validOutputs[cfg.Output] {
|
||||
return fmt.Errorf("invalid log output mode: %s", cfg.Output)
|
||||
}
|
||||
|
||||
validLevels := map[string]bool{
|
||||
"debug": true, "info": true, "warn": true, "error": true,
|
||||
}
|
||||
if !validLevels[cfg.Level] {
|
||||
return fmt.Errorf("invalid log level: %s", cfg.Level)
|
||||
}
|
||||
|
||||
if cfg.Console != nil {
|
||||
validTargets := map[string]bool{
|
||||
"stdout": true, "stderr": true, "split": true,
|
||||
}
|
||||
if !validTargets[cfg.Console.Target] {
|
||||
return fmt.Errorf("invalid console target: %s", cfg.Console.Target)
|
||||
}
|
||||
|
||||
validFormats := map[string]bool{
|
||||
"txt": true, "json": true, "": true,
|
||||
}
|
||||
if !validFormats[cfg.Console.Format] {
|
||||
return fmt.Errorf("invalid console format: %s", cfg.Console.Format)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,416 +0,0 @@
|
||||
// FILE: logwisp/src/internal/config/pipeline.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Represents a data processing pipeline
|
||||
type PipelineConfig struct {
|
||||
// Pipeline identifier (used in logs and metrics)
|
||||
Name string `toml:"name"`
|
||||
|
||||
// Data sources for this pipeline
|
||||
Sources []SourceConfig `toml:"sources"`
|
||||
|
||||
// Rate limiting
|
||||
RateLimit *RateLimitConfig `toml:"rate_limit"`
|
||||
|
||||
// Filter configuration
|
||||
Filters []FilterConfig `toml:"filters"`
|
||||
|
||||
// Log formatting configuration
|
||||
Format string `toml:"format"`
|
||||
FormatOptions map[string]any `toml:"format_options"`
|
||||
|
||||
// Output sinks for this pipeline
|
||||
Sinks []SinkConfig `toml:"sinks"`
|
||||
|
||||
// Authentication/Authorization (applies to network sinks)
|
||||
Auth *AuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
// Represents an input data source
|
||||
type SourceConfig struct {
|
||||
// Source type
|
||||
Type string `toml:"type"`
|
||||
|
||||
// Type-specific configuration options
|
||||
Options map[string]any `toml:"options"`
|
||||
}
|
||||
|
||||
// Represents an output destination
|
||||
type SinkConfig struct {
|
||||
// Sink type
|
||||
Type string `toml:"type"`
|
||||
|
||||
// Type-specific configuration options
|
||||
Options map[string]any `toml:"options"`
|
||||
}
|
||||
|
||||
func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) error {
|
||||
if cfg.Type == "" {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: missing type", pipelineName, sourceIndex)
|
||||
}
|
||||
|
||||
switch cfg.Type {
|
||||
case "directory":
|
||||
// Validate path
|
||||
path, ok := cfg.Options["path"].(string)
|
||||
if !ok || path == "" {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: directory source requires 'path' option",
|
||||
pipelineName, sourceIndex)
|
||||
}
|
||||
|
||||
// Check for directory traversal
|
||||
if strings.Contains(path, "..") {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: path contains directory traversal",
|
||||
pipelineName, sourceIndex)
|
||||
}
|
||||
|
||||
// Validate pattern
|
||||
if pattern, ok := cfg.Options["pattern"].(string); ok && pattern != "" {
|
||||
// Try to compile as glob pattern (will be converted to regex internally)
|
||||
if strings.Count(pattern, "*") == 0 && strings.Count(pattern, "?") == 0 {
|
||||
// If no wildcards, ensure it's a valid filename
|
||||
if filepath.Base(pattern) != pattern {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: pattern contains path separators",
|
||||
pipelineName, sourceIndex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate check interval
|
||||
if interval, ok := cfg.Options["check_interval_ms"]; ok {
|
||||
if intVal, ok := interval.(int64); ok {
|
||||
if intVal < 10 {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: check interval too small: %d ms (min: 10ms)",
|
||||
pipelineName, sourceIndex, intVal)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: invalid check_interval_ms type",
|
||||
pipelineName, sourceIndex)
|
||||
}
|
||||
}
|
||||
|
||||
case "stdin":
|
||||
// Validate buffer size
|
||||
if bufSize, ok := cfg.Options["buffer_size"].(int64); ok {
|
||||
if bufSize < 1 {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: stdin buffer_size must be positive: %d",
|
||||
pipelineName, sourceIndex, bufSize)
|
||||
}
|
||||
}
|
||||
|
||||
case "http":
|
||||
// Validate host
|
||||
if host, ok := cfg.Options["host"].(string); ok && host != "" {
|
||||
if net.ParseIP(host) == nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: invalid IP address: %s",
|
||||
pipelineName, sourceIndex, host)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate port
|
||||
port, ok := cfg.Options["port"].(int64)
|
||||
if !ok || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: invalid or missing HTTP port",
|
||||
pipelineName, sourceIndex)
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if path, ok := cfg.Options["ingest_path"].(string); ok {
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: ingest path must start with /: %s",
|
||||
pipelineName, sourceIndex, path)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate net_limit
|
||||
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("HTTP source", pipelineName, sourceIndex, nl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate TLS
|
||||
if tls, ok := cfg.Options["tls"].(map[string]any); ok {
|
||||
if err := validateTLSOptions("HTTP source", pipelineName, sourceIndex, tls); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
case "tcp":
|
||||
// Validate host
|
||||
if host, ok := cfg.Options["host"].(string); ok && host != "" {
|
||||
if net.ParseIP(host) == nil {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: invalid IP address: %s",
|
||||
pipelineName, sourceIndex, host)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate port
|
||||
port, ok := cfg.Options["port"].(int64)
|
||||
if !ok || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: invalid or missing TCP port",
|
||||
pipelineName, sourceIndex)
|
||||
}
|
||||
|
||||
// Validate net_limit
|
||||
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("TCP source", pipelineName, sourceIndex, nl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s' source[%d]: unknown source type '%s'",
|
||||
pipelineName, sourceIndex, cfg.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts map[int64]string) error {
|
||||
if cfg.Type == "" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: missing type", pipelineName, sinkIndex)
|
||||
}
|
||||
|
||||
switch cfg.Type {
|
||||
case "http":
|
||||
// Extract and validate HTTP configuration
|
||||
port, ok := cfg.Options["port"].(int64)
|
||||
if !ok || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid or missing HTTP port",
|
||||
pipelineName, sinkIndex)
|
||||
}
|
||||
|
||||
// Validate host
|
||||
if host, ok := cfg.Options["host"].(string); ok && host != "" {
|
||||
if net.ParseIP(host) == nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
|
||||
pipelineName, sinkIndex, host)
|
||||
}
|
||||
}
|
||||
|
||||
// Check port conflicts
|
||||
if existing, exists := allPorts[port]; exists {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: HTTP port %d already used by %s",
|
||||
pipelineName, sinkIndex, port, existing)
|
||||
}
|
||||
allPorts[port] = fmt.Sprintf("%s-http[%d]", pipelineName, sinkIndex)
|
||||
|
||||
// Validate buffer size
|
||||
if bufSize, ok := cfg.Options["buffer_size"].(int64); ok {
|
||||
if bufSize < 1 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: HTTP buffer size must be positive: %d",
|
||||
pipelineName, sinkIndex, bufSize)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate paths
|
||||
if streamPath, ok := cfg.Options["stream_path"].(string); ok {
|
||||
if !strings.HasPrefix(streamPath, "/") {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: stream path must start with /: %s",
|
||||
pipelineName, sinkIndex, streamPath)
|
||||
}
|
||||
}
|
||||
|
||||
if statusPath, ok := cfg.Options["status_path"].(string); ok {
|
||||
if !strings.HasPrefix(statusPath, "/") {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: status path must start with /: %s",
|
||||
pipelineName, sinkIndex, statusPath)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate heartbeat
|
||||
if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
|
||||
if err := validateHeartbeatOptions("HTTP", pipelineName, sinkIndex, hb); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate TLS if present
|
||||
if tls, ok := cfg.Options["tls"].(map[string]any); ok {
|
||||
if err := validateTLSOptions("HTTP", pipelineName, sinkIndex, tls); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate net limit
|
||||
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("HTTP", pipelineName, sinkIndex, nl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
case "tcp":
|
||||
// Extract and validate TCP configuration
|
||||
port, ok := cfg.Options["port"].(int64)
|
||||
if !ok || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid or missing TCP port",
|
||||
pipelineName, sinkIndex)
|
||||
}
|
||||
|
||||
// Validate host
|
||||
if host, ok := cfg.Options["host"].(string); ok && host != "" {
|
||||
if net.ParseIP(host) == nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid IP address: %s",
|
||||
pipelineName, sinkIndex, host)
|
||||
}
|
||||
}
|
||||
|
||||
// Check port conflicts
|
||||
if existing, exists := allPorts[port]; exists {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: TCP port %d already used by %s",
|
||||
pipelineName, sinkIndex, port, existing)
|
||||
}
|
||||
allPorts[port] = fmt.Sprintf("%s-tcp[%d]", pipelineName, sinkIndex)
|
||||
|
||||
// Validate buffer size
|
||||
if bufSize, ok := cfg.Options["buffer_size"].(int64); ok {
|
||||
if bufSize < 1 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: TCP buffer size must be positive: %d",
|
||||
pipelineName, sinkIndex, bufSize)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate heartbeat
|
||||
if hb, ok := cfg.Options["heartbeat"].(map[string]any); ok {
|
||||
if err := validateHeartbeatOptions("TCP", pipelineName, sinkIndex, hb); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate net limit
|
||||
if nl, ok := cfg.Options["net_limit"].(map[string]any); ok {
|
||||
if err := validateNetLimitOptions("TCP", pipelineName, sinkIndex, nl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
case "http_client":
|
||||
// Validate URL
|
||||
urlStr, ok := cfg.Options["url"].(string)
|
||||
if !ok || urlStr == "" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: http_client sink requires 'url' option",
|
||||
pipelineName, sinkIndex)
|
||||
}
|
||||
|
||||
// Validate URL format
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid URL: %w",
|
||||
pipelineName, sinkIndex, err)
|
||||
}
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: URL must use http or https scheme",
|
||||
pipelineName, sinkIndex)
|
||||
}
|
||||
|
||||
// Validate batch size
|
||||
if batchSize, ok := cfg.Options["batch_size"].(int64); ok {
|
||||
if batchSize < 1 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: batch_size must be positive: %d",
|
||||
pipelineName, sinkIndex, batchSize)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate timeout
|
||||
if timeout, ok := cfg.Options["timeout_seconds"].(int64); ok {
|
||||
if timeout < 1 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: timeout_seconds must be positive: %d",
|
||||
pipelineName, sinkIndex, timeout)
|
||||
}
|
||||
}
|
||||
|
||||
case "tcp_client":
|
||||
// Added validation for TCP client sink
|
||||
// Validate address
|
||||
address, ok := cfg.Options["address"].(string)
|
||||
if !ok || address == "" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: tcp_client sink requires 'address' option",
|
||||
pipelineName, sinkIndex)
|
||||
}
|
||||
|
||||
// Validate address format
|
||||
_, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: invalid address format (expected host:port): %w",
|
||||
pipelineName, sinkIndex, err)
|
||||
}
|
||||
|
||||
// Validate timeouts
|
||||
if dialTimeout, ok := cfg.Options["dial_timeout_seconds"].(int64); ok {
|
||||
if dialTimeout < 1 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: dial_timeout_seconds must be positive: %d",
|
||||
pipelineName, sinkIndex, dialTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
if writeTimeout, ok := cfg.Options["write_timeout_seconds"].(int64); ok {
|
||||
if writeTimeout < 1 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: write_timeout_seconds must be positive: %d",
|
||||
pipelineName, sinkIndex, writeTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
case "file":
|
||||
// Validate directory
|
||||
directory, ok := cfg.Options["directory"].(string)
|
||||
if !ok || directory == "" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: file sink requires 'directory' option",
|
||||
pipelineName, sinkIndex)
|
||||
}
|
||||
|
||||
// Validate filename
|
||||
name, ok := cfg.Options["name"].(string)
|
||||
if !ok || name == "" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: file sink requires 'name' option",
|
||||
pipelineName, sinkIndex)
|
||||
}
|
||||
|
||||
// Validate size options
|
||||
if maxSize, ok := cfg.Options["max_size_mb"].(int64); ok {
|
||||
if maxSize < 1 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: max_size_mb must be positive: %d",
|
||||
pipelineName, sinkIndex, maxSize)
|
||||
}
|
||||
}
|
||||
|
||||
if maxTotalSize, ok := cfg.Options["max_total_size_mb"].(int64); ok {
|
||||
if maxTotalSize < 0 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: max_total_size_mb cannot be negative: %d",
|
||||
pipelineName, sinkIndex, maxTotalSize)
|
||||
}
|
||||
}
|
||||
|
||||
if minDiskFree, ok := cfg.Options["min_disk_free_mb"].(int64); ok {
|
||||
if minDiskFree < 0 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: min_disk_free_mb cannot be negative: %d",
|
||||
pipelineName, sinkIndex, minDiskFree)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate retention period
|
||||
if retention, ok := cfg.Options["retention_hours"].(float64); ok {
|
||||
if retention < 0 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: retention_hours cannot be negative: %f",
|
||||
pipelineName, sinkIndex, retention)
|
||||
}
|
||||
}
|
||||
|
||||
case "console":
|
||||
// No specific validation needed for console sinks
|
||||
|
||||
default:
|
||||
return fmt.Errorf("pipeline '%s' sink[%d]: unknown sink type '%s'",
|
||||
pipelineName, sinkIndex, cfg.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,33 +0,0 @@
|
||||
// FILE: logwisp/src/internal/config/saver.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
lconfig "github.com/lixenwraith/config"
|
||||
)
|
||||
|
||||
// Saves the configuration to the specified file path.
|
||||
func (c *Config) SaveToFile(path string) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("cannot save config: path is empty")
|
||||
}
|
||||
|
||||
// Create a temporary lconfig instance just for saving
|
||||
// This avoids the need to track lconfig throughout the application
|
||||
lcfg, err := lconfig.NewBuilder().
|
||||
WithFile(path).
|
||||
WithTarget(c).
|
||||
WithFileFormat("toml").
|
||||
Build()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config builder: %w", err)
|
||||
}
|
||||
|
||||
// Use lconfig's Save method which handles atomic writes
|
||||
if err := lcfg.Save(path); err != nil {
|
||||
return fmt.Errorf("failed to save config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,203 +0,0 @@
|
||||
// FILE: logwisp/src/internal/config/server.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TCPConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
Port int64 `toml:"port"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
|
||||
// Net limiting
|
||||
NetLimit *NetLimitConfig `toml:"net_limit"`
|
||||
|
||||
// Heartbeat
|
||||
Heartbeat *HeartbeatConfig `toml:"heartbeat"`
|
||||
}
|
||||
|
||||
type HTTPConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
Port int64 `toml:"port"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
|
||||
// Endpoint paths
|
||||
StreamPath string `toml:"stream_path"`
|
||||
StatusPath string `toml:"status_path"`
|
||||
|
||||
// TLS Configuration
|
||||
TLS *TLSConfig `toml:"tls"`
|
||||
|
||||
// Nate limiting
|
||||
NetLimit *NetLimitConfig `toml:"net_limit"`
|
||||
|
||||
// Heartbeat
|
||||
Heartbeat *HeartbeatConfig `toml:"heartbeat"`
|
||||
}
|
||||
|
||||
type HeartbeatConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
IntervalSeconds int64 `toml:"interval_seconds"`
|
||||
IncludeTimestamp bool `toml:"include_timestamp"`
|
||||
IncludeStats bool `toml:"include_stats"`
|
||||
Format string `toml:"format"`
|
||||
}
|
||||
|
||||
type NetLimitConfig struct {
|
||||
// Enable net limiting
|
||||
Enabled bool `toml:"enabled"`
|
||||
|
||||
// IP Access Control Lists
|
||||
IPWhitelist []string `toml:"ip_whitelist"`
|
||||
IPBlacklist []string `toml:"ip_blacklist"`
|
||||
|
||||
// Requests per second per client
|
||||
RequestsPerSecond float64 `toml:"requests_per_second"`
|
||||
|
||||
// Burst size (token bucket)
|
||||
BurstSize int64 `toml:"burst_size"`
|
||||
|
||||
// Response when net limited
|
||||
ResponseCode int64 `toml:"response_code"` // Default: 429
|
||||
ResponseMessage string `toml:"response_message"` // Default: "Net limit exceeded"
|
||||
|
||||
// Connection limits
|
||||
MaxConnectionsPerIP int64 `toml:"max_connections_per_ip"`
|
||||
MaxConnectionsPerUser int64 `toml:"max_connections_per_user"`
|
||||
MaxConnectionsPerToken int64 `toml:"max_connections_per_token"`
|
||||
MaxConnectionsTotal int64 `toml:"max_connections_total"`
|
||||
}
|
||||
|
||||
func validateHeartbeatOptions(serverType, pipelineName string, sinkIndex int, hb map[string]any) error {
|
||||
if enabled, ok := hb["enabled"].(bool); ok && enabled {
|
||||
interval, ok := hb["interval_seconds"].(int64)
|
||||
if !ok || interval < 1 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: heartbeat interval must be positive",
|
||||
pipelineName, sinkIndex, serverType)
|
||||
}
|
||||
|
||||
if format, ok := hb["format"].(string); ok {
|
||||
if format != "json" && format != "comment" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: heartbeat format must be 'json' or 'comment': %s",
|
||||
pipelineName, sinkIndex, serverType, format)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNetLimitOptions(serverType, pipelineName string, sinkIndex int, nl map[string]any) error {
|
||||
if enabled, ok := nl["enabled"].(bool); !ok || !enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate IP lists if present
|
||||
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
|
||||
for i, entry := range ipWhitelist {
|
||||
entryStr, ok := entry.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := validateIPv4Entry(entryStr); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: whitelist[%d] %v",
|
||||
pipelineName, sinkIndex, serverType, i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
|
||||
for i, entry := range ipBlacklist {
|
||||
entryStr, ok := entry.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := validateIPv4Entry(entryStr); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: blacklist[%d] %v",
|
||||
pipelineName, sinkIndex, serverType, i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate requests per second
|
||||
rps, ok := nl["requests_per_second"].(float64)
|
||||
if !ok || rps <= 0 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: requests_per_second must be positive",
|
||||
pipelineName, sinkIndex, serverType)
|
||||
}
|
||||
|
||||
// Validate burst size
|
||||
burst, ok := nl["burst_size"].(int64)
|
||||
if !ok || burst < 1 {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: burst_size must be at least 1",
|
||||
pipelineName, sinkIndex, serverType)
|
||||
}
|
||||
|
||||
// Validate response code
|
||||
if respCode, ok := nl["response_code"].(int64); ok {
|
||||
if respCode > 0 && (respCode < 400 || respCode >= 600) {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: response_code must be 4xx or 5xx: %d",
|
||||
pipelineName, sinkIndex, serverType, respCode)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate connection limits
|
||||
maxPerIP, perIPOk := nl["max_connections_per_ip"].(int64)
|
||||
maxPerUser, perUserOk := nl["max_connections_per_user"].(int64)
|
||||
maxPerToken, perTokenOk := nl["max_connections_per_token"].(int64)
|
||||
maxTotal, totalOk := nl["max_connections_total"].(int64)
|
||||
|
||||
if perIPOk && perUserOk && perTokenOk && totalOk &&
|
||||
maxPerIP > 0 && maxPerUser > 0 && maxPerToken > 0 && maxTotal > 0 {
|
||||
if maxPerIP > maxTotal {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_ip (%d) cannot exceed max_connections_total (%d)",
|
||||
pipelineName, sinkIndex, serverType, maxPerIP, maxTotal)
|
||||
}
|
||||
if maxPerUser > maxTotal {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_user (%d) cannot exceed max_connections_total (%d)",
|
||||
pipelineName, sinkIndex, serverType, maxPerUser, maxTotal)
|
||||
}
|
||||
if maxPerToken > maxTotal {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: max_connections_per_token (%d) cannot exceed max_connections_total (%d)",
|
||||
pipelineName, sinkIndex, serverType, maxPerToken, maxTotal)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensures an IP or CIDR is IPv4
|
||||
func validateIPv4Entry(entry string) error {
|
||||
// Handle single IP
|
||||
if !strings.Contains(entry, "/") {
|
||||
ip := net.ParseIP(entry)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("invalid IP address: %s", entry)
|
||||
}
|
||||
if ip.To4() == nil {
|
||||
return fmt.Errorf("IPv6 not supported (IPv4-only): %s", entry)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle CIDR
|
||||
ipAddr, ipNet, err := net.ParseCIDR(entry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR: %s", entry)
|
||||
}
|
||||
|
||||
// Check if the IP is IPv4
|
||||
if ipAddr.To4() == nil {
|
||||
return fmt.Errorf("IPv6 CIDR not supported (IPv4-only): %s", entry)
|
||||
}
|
||||
|
||||
// Verify the network mask is appropriate for IPv4
|
||||
_, bits := ipNet.Mask.Size()
|
||||
if bits != 32 {
|
||||
return fmt.Errorf("invalid IPv4 CIDR mask (got %d bits, expected 32): %s", bits, entry)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,82 +0,0 @@
|
||||
// FILE: logwisp/src/internal/config/tls.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
type TLSConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
CertFile string `toml:"cert_file"`
|
||||
KeyFile string `toml:"key_file"`
|
||||
|
||||
// Client certificate authentication
|
||||
ClientAuth bool `toml:"client_auth"`
|
||||
ClientCAFile string `toml:"client_ca_file"`
|
||||
VerifyClientCert bool `toml:"verify_client_cert"`
|
||||
|
||||
// Option to skip verification for clients
|
||||
InsecureSkipVerify bool `toml:"insecure_skip_verify"`
|
||||
|
||||
// CA file for client to trust specific server certificates
|
||||
CAFile string `toml:"ca_file"`
|
||||
|
||||
// TLS version constraints
|
||||
MinVersion string `toml:"min_version"` // "TLS1.2", "TLS1.3"
|
||||
MaxVersion string `toml:"max_version"`
|
||||
|
||||
// Cipher suites (comma-separated list)
|
||||
CipherSuites string `toml:"cipher_suites"`
|
||||
}
|
||||
|
||||
func validateTLSOptions(serverType, pipelineName string, sinkIndex int, tls map[string]any) error {
|
||||
if enabled, ok := tls["enabled"].(bool); ok && enabled {
|
||||
certFile, certOk := tls["cert_file"].(string)
|
||||
keyFile, keyOk := tls["key_file"].(string)
|
||||
|
||||
if !certOk || certFile == "" || !keyOk || keyFile == "" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: TLS enabled but cert/key files not specified",
|
||||
pipelineName, sinkIndex, serverType)
|
||||
}
|
||||
|
||||
// Validate that certificate files exist and are readable
|
||||
if _, err := os.Stat(certFile); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: cert_file is not accessible: %w",
|
||||
pipelineName, sinkIndex, serverType, err)
|
||||
}
|
||||
if _, err := os.Stat(keyFile); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: key_file is not accessible: %w",
|
||||
pipelineName, sinkIndex, serverType, err)
|
||||
}
|
||||
|
||||
if clientAuth, ok := tls["client_auth"].(bool); ok && clientAuth {
|
||||
caFile, caOk := tls["client_ca_file"].(string)
|
||||
if !caOk || caFile == "" {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: client auth enabled but CA file not specified",
|
||||
pipelineName, sinkIndex, serverType)
|
||||
}
|
||||
// Validate that the client CA file exists and is readable
|
||||
if _, err := os.Stat(caFile); err != nil {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: client_ca_file is not accessible: %w",
|
||||
pipelineName, sinkIndex, serverType, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate TLS versions
|
||||
validVersions := map[string]bool{"TLS1.0": true, "TLS1.1": true, "TLS1.2": true, "TLS1.3": true}
|
||||
if minVer, ok := tls["min_version"].(string); ok && minVer != "" {
|
||||
if !validVersions[minVer] {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: invalid min TLS version: %s",
|
||||
pipelineName, sinkIndex, serverType, minVer)
|
||||
}
|
||||
}
|
||||
if maxVer, ok := tls["max_version"].(string); ok && maxVer != "" {
|
||||
if !validVersions[maxVer] {
|
||||
return fmt.Errorf("pipeline '%s' sink[%d] %s: invalid max TLS version: %s",
|
||||
pipelineName, sinkIndex, serverType, maxVer)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
13
src/internal/core/const.go
Normal file
13
src/internal/core/const.go
Normal file
@ -0,0 +1,13 @@
|
||||
// FILE: logwisp/src/internal/core/const.go
|
||||
package core
|
||||
|
||||
// Argon2id parameters
|
||||
const (
|
||||
Argon2Time = 3
|
||||
Argon2Memory = 64 * 1024 // 64 MB
|
||||
Argon2Threads = 4
|
||||
Argon2SaltLen = 16
|
||||
Argon2KeyLen = 32
|
||||
)
|
||||
|
||||
const DefaultTokenLength = 32
|
||||
@ -1,80 +0,0 @@
|
||||
// FILE: logwisp/src/internal/filter/chain_test.go
|
||||
package filter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewChain(t *testing.T) {
|
||||
logger := newTestLogger()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
configs := []config.FilterConfig{
|
||||
{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
|
||||
{Type: config.FilterTypeExclude, Patterns: []string{"banana"}},
|
||||
}
|
||||
chain, err := NewChain(configs, logger)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, chain)
|
||||
assert.Len(t, chain.filters, 2)
|
||||
})
|
||||
|
||||
t.Run("ErrorInvalidRegexInChain", func(t *testing.T) {
|
||||
configs := []config.FilterConfig{
|
||||
{Patterns: []string{"apple"}},
|
||||
{Patterns: []string{"["}},
|
||||
}
|
||||
chain, err := NewChain(configs, logger)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, chain)
|
||||
assert.Contains(t, err.Error(), "filter[1]")
|
||||
})
|
||||
}
|
||||
|
||||
func TestChain_Apply(t *testing.T) {
|
||||
logger := newTestLogger()
|
||||
entry := core.LogEntry{Message: "an apple a day"}
|
||||
|
||||
t.Run("EmptyChain", func(t *testing.T) {
|
||||
chain, err := NewChain([]config.FilterConfig{}, logger)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, chain.Apply(entry))
|
||||
})
|
||||
|
||||
t.Run("AllFiltersPass", func(t *testing.T) {
|
||||
configs := []config.FilterConfig{
|
||||
{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
|
||||
{Type: config.FilterTypeInclude, Patterns: []string{"day"}},
|
||||
{Type: config.FilterTypeExclude, Patterns: []string{"banana"}},
|
||||
}
|
||||
chain, err := NewChain(configs, logger)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, chain.Apply(entry))
|
||||
})
|
||||
|
||||
t.Run("OneFilterFails", func(t *testing.T) {
|
||||
configs := []config.FilterConfig{
|
||||
{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
|
||||
{Type: config.FilterTypeExclude, Patterns: []string{"day"}}, // This one will fail
|
||||
{Type: config.FilterTypeInclude, Patterns: []string{"a"}},
|
||||
}
|
||||
chain, err := NewChain(configs, logger)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, chain.Apply(entry))
|
||||
})
|
||||
|
||||
t.Run("FirstFilterFails", func(t *testing.T) {
|
||||
configs := []config.FilterConfig{
|
||||
{Type: config.FilterTypeInclude, Patterns: []string{"banana"}}, // This one will fail
|
||||
{Type: config.FilterTypeInclude, Patterns: []string{"apple"}},
|
||||
}
|
||||
chain, err := NewChain(configs, logger)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, chain.Apply(entry))
|
||||
})
|
||||
}
|
||||
@ -1,159 +0,0 @@
|
||||
// FILE: logwisp/src/internal/filter/filter_test.go
|
||||
package filter
|
||||
|
||||
import (
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"testing"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func newTestLogger() *log.Logger {
|
||||
return log.NewLogger()
|
||||
}
|
||||
|
||||
func TestNewFilter(t *testing.T) {
|
||||
logger := newTestLogger()
|
||||
|
||||
t.Run("SuccessWithDefaults", func(t *testing.T) {
|
||||
cfg := config.FilterConfig{Patterns: []string{"test"}}
|
||||
f, err := NewFilter(cfg, logger)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, f)
|
||||
assert.Equal(t, config.FilterTypeInclude, f.config.Type)
|
||||
assert.Equal(t, config.FilterLogicOr, f.config.Logic)
|
||||
})
|
||||
|
||||
t.Run("SuccessWithCustomConfig", func(t *testing.T) {
|
||||
cfg := config.FilterConfig{
|
||||
Type: config.FilterTypeExclude,
|
||||
Logic: config.FilterLogicAnd,
|
||||
Patterns: []string{"test", "pattern"},
|
||||
}
|
||||
f, err := NewFilter(cfg, logger)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, f)
|
||||
assert.Equal(t, config.FilterTypeExclude, f.config.Type)
|
||||
assert.Equal(t, config.FilterLogicAnd, f.config.Logic)
|
||||
assert.Len(t, f.patterns, 2)
|
||||
})
|
||||
|
||||
t.Run("ErrorInvalidRegex", func(t *testing.T) {
|
||||
cfg := config.FilterConfig{Patterns: []string{"["}}
|
||||
f, err := NewFilter(cfg, logger)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, f)
|
||||
assert.Contains(t, err.Error(), "invalid regex pattern")
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilter_Apply(t *testing.T) {
|
||||
logger := newTestLogger()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
cfg config.FilterConfig
|
||||
entry core.LogEntry
|
||||
expected bool
|
||||
}{
|
||||
// Include OR logic
|
||||
{
|
||||
name: "IncludeOR_MatchOne",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicOr, Patterns: []string{"apple", "banana"}},
|
||||
entry: core.LogEntry{Message: "this is an apple"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "IncludeOR_NoMatch",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicOr, Patterns: []string{"apple", "banana"}},
|
||||
entry: core.LogEntry{Message: "this is a pear"},
|
||||
expected: false,
|
||||
},
|
||||
// Include AND logic
|
||||
{
|
||||
name: "IncludeAND_MatchAll",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicAnd, Patterns: []string{"apple", "doctor"}},
|
||||
entry: core.LogEntry{Message: "an apple keeps the doctor away"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "IncludeAND_MatchOne",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Logic: config.FilterLogicAnd, Patterns: []string{"apple", "doctor"}},
|
||||
entry: core.LogEntry{Message: "this is an apple"},
|
||||
expected: false,
|
||||
},
|
||||
// Exclude OR logic
|
||||
{
|
||||
name: "ExcludeOR_MatchOne",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicOr, Patterns: []string{"error", "fatal"}},
|
||||
entry: core.LogEntry{Message: "this is an error"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "ExcludeOR_NoMatch",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicOr, Patterns: []string{"error", "fatal"}},
|
||||
entry: core.LogEntry{Message: "this is a warning"},
|
||||
expected: true,
|
||||
},
|
||||
// Exclude AND logic
|
||||
{
|
||||
name: "ExcludeAND_MatchAll",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicAnd, Patterns: []string{"critical", "database"}},
|
||||
entry: core.LogEntry{Message: "critical error in database"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "ExcludeAND_MatchOne",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeExclude, Logic: config.FilterLogicAnd, Patterns: []string{"critical", "database"}},
|
||||
entry: core.LogEntry{Message: "critical error in app"},
|
||||
expected: true,
|
||||
},
|
||||
// Edge Cases
|
||||
{
|
||||
name: "NoPatterns",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{}},
|
||||
entry: core.LogEntry{Message: "any message"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "EmptyEntry_NoPatterns",
|
||||
cfg: config.FilterConfig{Patterns: []string{}},
|
||||
entry: core.LogEntry{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "EmptyEntry_DoesNotMatchSpace",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{" "}},
|
||||
entry: core.LogEntry{Level: "", Source: "", Message: ""},
|
||||
expected: false, // CORRECTED: An empty entry results in an empty string, which doesn't match a space.
|
||||
},
|
||||
{
|
||||
name: "MatchOnLevel",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"ERROR"}},
|
||||
entry: core.LogEntry{Level: "ERROR", Message: "A message"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "MatchOnSource",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"database"}},
|
||||
entry: core.LogEntry{Source: "database", Message: "A message"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "MatchOnCombinedFields",
|
||||
cfg: config.FilterConfig{Type: config.FilterTypeInclude, Patterns: []string{"^app ERROR"}},
|
||||
entry: core.LogEntry{Source: "app", Level: "ERROR", Message: "A message"},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
f, err := NewFilter(tc.cfg, logger)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, f.Apply(tc.entry))
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -3,6 +3,7 @@ package format
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"logwisp/src/internal/config"
|
||||
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
@ -19,20 +20,15 @@ type Formatter interface {
|
||||
}
|
||||
|
||||
// Creates a new Formatter based on the provided configuration.
|
||||
func NewFormatter(name string, options map[string]any, logger *log.Logger) (Formatter, error) {
|
||||
// Default to raw if no format specified
|
||||
if name == "" {
|
||||
name = "raw"
|
||||
}
|
||||
|
||||
switch name {
|
||||
func NewFormatter(cfg *config.FormatConfig, logger *log.Logger) (Formatter, error) {
|
||||
switch cfg.Type {
|
||||
case "json":
|
||||
return NewJSONFormatter(options, logger)
|
||||
return NewJSONFormatter(cfg.JSONFormatOptions, logger)
|
||||
case "txt":
|
||||
return NewTextFormatter(options, logger)
|
||||
case "raw":
|
||||
return NewRawFormatter(options, logger)
|
||||
return NewTextFormatter(cfg.TextFormatOptions, logger)
|
||||
case "raw", "":
|
||||
return NewRawFormatter(cfg.RawFormatOptions, logger)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown formatter type: %s", name)
|
||||
return nil, fmt.Errorf("unknown formatter type: %s", cfg.Type)
|
||||
}
|
||||
}
|
||||
@ -1,65 +0,0 @@
|
||||
// FILE: logwisp/src/internal/format/format_test.go
|
||||
package format
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestLogger() *log.Logger {
|
||||
return log.NewLogger()
|
||||
}
|
||||
|
||||
func TestNewFormatter(t *testing.T) {
|
||||
logger := newTestLogger()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
formatName string
|
||||
expected string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "JSONFormatter",
|
||||
formatName: "json",
|
||||
expected: "json",
|
||||
},
|
||||
{
|
||||
name: "TextFormatter",
|
||||
formatName: "txt",
|
||||
expected: "txt",
|
||||
},
|
||||
{
|
||||
name: "RawFormatter",
|
||||
formatName: "raw",
|
||||
expected: "raw",
|
||||
},
|
||||
{
|
||||
name: "DefaultToRaw",
|
||||
formatName: "",
|
||||
expected: "raw",
|
||||
},
|
||||
{
|
||||
name: "UnknownFormatter",
|
||||
formatName: "xml",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
formatter, err := NewFormatter(tc.formatName, nil, logger)
|
||||
if tc.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, formatter)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, formatter)
|
||||
assert.Equal(t, tc.expected, formatter.Name())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
@ -13,41 +14,17 @@ import (
|
||||
|
||||
// Produces structured JSON logs
|
||||
type JSONFormatter struct {
|
||||
pretty bool
|
||||
timestampField string
|
||||
levelField string
|
||||
messageField string
|
||||
sourceField string
|
||||
config *config.JSONFormatterOptions
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// Creates a new JSON formatter
|
||||
func NewJSONFormatter(options map[string]any, logger *log.Logger) (*JSONFormatter, error) {
|
||||
func NewJSONFormatter(opts *config.JSONFormatterOptions, logger *log.Logger) (*JSONFormatter, error) {
|
||||
f := &JSONFormatter{
|
||||
timestampField: "timestamp",
|
||||
levelField: "level",
|
||||
messageField: "message",
|
||||
sourceField: "source",
|
||||
config: opts,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Extract options
|
||||
if pretty, ok := options["pretty"].(bool); ok {
|
||||
f.pretty = pretty
|
||||
}
|
||||
if field, ok := options["timestamp_field"].(string); ok && field != "" {
|
||||
f.timestampField = field
|
||||
}
|
||||
if field, ok := options["level_field"].(string); ok && field != "" {
|
||||
f.levelField = field
|
||||
}
|
||||
if field, ok := options["message_field"].(string); ok && field != "" {
|
||||
f.messageField = field
|
||||
}
|
||||
if field, ok := options["source_field"].(string); ok && field != "" {
|
||||
f.sourceField = field
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
@ -57,9 +34,9 @@ func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) {
|
||||
output := make(map[string]any)
|
||||
|
||||
// First, populate with LogWisp metadata
|
||||
output[f.timestampField] = entry.Time.Format(time.RFC3339Nano)
|
||||
output[f.levelField] = entry.Level
|
||||
output[f.sourceField] = entry.Source
|
||||
output[f.config.TimestampField] = entry.Time.Format(time.RFC3339Nano)
|
||||
output[f.config.LevelField] = entry.Level
|
||||
output[f.config.SourceField] = entry.Source
|
||||
|
||||
// Try to parse the message as JSON
|
||||
var msgData map[string]any
|
||||
@ -68,21 +45,21 @@ func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) {
|
||||
// LogWisp metadata takes precedence
|
||||
for k, v := range msgData {
|
||||
// Don't overwrite our standard fields
|
||||
if k != f.timestampField && k != f.levelField && k != f.sourceField {
|
||||
if k != f.config.TimestampField && k != f.config.LevelField && k != f.config.SourceField {
|
||||
output[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// If the original JSON had these fields, log that we're overriding
|
||||
if _, hasTime := msgData[f.timestampField]; hasTime {
|
||||
if _, hasTime := msgData[f.config.TimestampField]; hasTime {
|
||||
f.logger.Debug("msg", "Overriding timestamp from JSON message",
|
||||
"component", "json_formatter",
|
||||
"original", msgData[f.timestampField],
|
||||
"logwisp", output[f.timestampField])
|
||||
"original", msgData[f.config.TimestampField],
|
||||
"logwisp", output[f.config.TimestampField])
|
||||
}
|
||||
} else {
|
||||
// Message is not valid JSON - add as message field
|
||||
output[f.messageField] = entry.Message
|
||||
output[f.config.MessageField] = entry.Message
|
||||
}
|
||||
|
||||
// Add any additional fields from LogEntry.Fields
|
||||
@ -101,7 +78,7 @@ func (f *JSONFormatter) Format(entry core.LogEntry) ([]byte, error) {
|
||||
// Marshal to JSON
|
||||
var result []byte
|
||||
var err error
|
||||
if f.pretty {
|
||||
if f.config.Pretty {
|
||||
result, err = json.MarshalIndent(output, "", " ")
|
||||
} else {
|
||||
result, err = json.Marshal(output)
|
||||
@ -147,7 +124,7 @@ func (f *JSONFormatter) FormatBatch(entries []core.LogEntry) ([]byte, error) {
|
||||
// Marshal the entire batch as an array
|
||||
var result []byte
|
||||
var err error
|
||||
if f.pretty {
|
||||
if f.config.Pretty {
|
||||
result, err = json.MarshalIndent(batch, "", " ")
|
||||
} else {
|
||||
result, err = json.Marshal(batch)
|
||||
|
||||
@ -1,129 +0,0 @@
|
||||
// FILE: logwisp/src/internal/format/json_test.go
|
||||
package format
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJSONFormatter_Format(t *testing.T) {
|
||||
logger := newTestLogger()
|
||||
testTime := time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
entry := core.LogEntry{
|
||||
Time: testTime,
|
||||
Source: "test-app",
|
||||
Level: "INFO",
|
||||
Message: "this is a test",
|
||||
}
|
||||
|
||||
t.Run("BasicFormatting", func(t *testing.T) {
|
||||
formatter, err := NewJSONFormatter(nil, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := formatter.Format(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(output, &result)
|
||||
require.NoError(t, err, "Output should be valid JSON")
|
||||
|
||||
assert.Equal(t, testTime.Format(time.RFC3339Nano), result["timestamp"])
|
||||
assert.Equal(t, "INFO", result["level"])
|
||||
assert.Equal(t, "test-app", result["source"])
|
||||
assert.Equal(t, "this is a test", result["message"])
|
||||
assert.True(t, strings.HasSuffix(string(output), "\n"), "Output should end with a newline")
|
||||
})
|
||||
|
||||
t.Run("PrettyFormatting", func(t *testing.T) {
|
||||
formatter, err := NewJSONFormatter(map[string]any{"pretty": true}, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := formatter.Format(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, string(output), ` "level": "INFO"`)
|
||||
assert.True(t, strings.HasSuffix(string(output), "\n"))
|
||||
})
|
||||
|
||||
t.Run("MessageIsJSON", func(t *testing.T) {
|
||||
jsonMessageEntry := entry
|
||||
jsonMessageEntry.Message = `{"user":"test","request_id":"abc-123"}`
|
||||
formatter, err := NewJSONFormatter(nil, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := formatter.Format(jsonMessageEntry)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(output, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test", result["user"])
|
||||
assert.Equal(t, "abc-123", result["request_id"])
|
||||
_, messageExists := result["message"]
|
||||
assert.False(t, messageExists, "message field should not exist when message is merged JSON")
|
||||
})
|
||||
|
||||
t.Run("MessageIsJSONWithConflicts", func(t *testing.T) {
|
||||
jsonMessageEntry := entry
|
||||
jsonMessageEntry.Level = "INFO" // top-level
|
||||
jsonMessageEntry.Message = `{"level":"DEBUG","msg":"hello"}`
|
||||
formatter, err := NewJSONFormatter(nil, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := formatter.Format(jsonMessageEntry)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(output, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "INFO", result["level"], "Top-level LogEntry field should take precedence")
|
||||
})
|
||||
|
||||
t.Run("CustomFieldNames", func(t *testing.T) {
|
||||
options := map[string]any{"timestamp_field": "@timestamp"}
|
||||
formatter, err := NewJSONFormatter(options, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := formatter.Format(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(output, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, defaultExists := result["timestamp"]
|
||||
assert.False(t, defaultExists)
|
||||
assert.Equal(t, testTime.Format(time.RFC3339Nano), result["@timestamp"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestJSONFormatter_FormatBatch(t *testing.T) {
|
||||
logger := newTestLogger()
|
||||
formatter, err := NewJSONFormatter(nil, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
entries := []core.LogEntry{
|
||||
{Time: time.Now(), Level: "INFO", Message: "First message"},
|
||||
{Time: time.Now(), Level: "WARN", Message: "Second message"},
|
||||
}
|
||||
|
||||
output, err := formatter.FormatBatch(entries)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result []map[string]interface{}
|
||||
err = json.Unmarshal(output, &result)
|
||||
require.NoError(t, err, "Batch output should be a valid JSON array")
|
||||
require.Len(t, result, 2)
|
||||
|
||||
assert.Equal(t, "First message", result[0]["message"])
|
||||
assert.Equal(t, "WARN", result[1]["level"])
|
||||
}
|
||||
@ -2,6 +2,7 @@
|
||||
package format
|
||||
|
||||
import (
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
@ -9,20 +10,26 @@ import (
|
||||
|
||||
// Outputs the log message as-is with a newline
|
||||
type RawFormatter struct {
|
||||
config *config.RawFormatterOptions
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// Creates a new raw formatter
|
||||
func NewRawFormatter(options map[string]any, logger *log.Logger) (*RawFormatter, error) {
|
||||
func NewRawFormatter(cfg *config.RawFormatterOptions, logger *log.Logger) (*RawFormatter, error) {
|
||||
return &RawFormatter{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Returns the message with a newline appended
|
||||
func (f *RawFormatter) Format(entry core.LogEntry) ([]byte, error) {
|
||||
// Simply return the message with newline
|
||||
// TODO: Standardize not to add "\n" when processing raw, check lixenwraith/log for consistency
|
||||
if f.config.AddNewLine {
|
||||
return append([]byte(entry.Message), '\n'), nil
|
||||
} else {
|
||||
return []byte(entry.Message), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the formatter name
|
||||
|
||||
@ -1,29 +0,0 @@
|
||||
// FILE: logwisp/src/internal/format/raw_test.go
|
||||
package format
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRawFormatter_Format(t *testing.T) {
|
||||
logger := newTestLogger()
|
||||
formatter, err := NewRawFormatter(nil, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
entry := core.LogEntry{
|
||||
Time: time.Now(),
|
||||
Message: "This is a raw log line.",
|
||||
}
|
||||
|
||||
output, err := formatter.Format(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := "This is a raw log line.\n"
|
||||
assert.Equal(t, expected, string(output))
|
||||
}
|
||||
@ -4,6 +4,7 @@ package format
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"logwisp/src/internal/config"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
@ -15,41 +16,29 @@ import (
|
||||
|
||||
// Produces human-readable text logs using templates
|
||||
type TextFormatter struct {
|
||||
config *config.TextFormatterOptions
|
||||
template *template.Template
|
||||
timestampFormat string
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// Creates a new text formatter
|
||||
func NewTextFormatter(options map[string]any, logger *log.Logger) (*TextFormatter, error) {
|
||||
// Default template
|
||||
templateStr := "[{{.Timestamp | FmtTime}}] [{{.Level | ToUpper}}] {{.Source}} - {{.Message}}{{ if .Fields }} {{.Fields}}{{ end }}"
|
||||
if tmpl, ok := options["template"].(string); ok && tmpl != "" {
|
||||
templateStr = tmpl
|
||||
}
|
||||
|
||||
// Default timestamp format
|
||||
timestampFormat := time.RFC3339
|
||||
if tsFormat, ok := options["timestamp_format"].(string); ok && tsFormat != "" {
|
||||
timestampFormat = tsFormat
|
||||
}
|
||||
|
||||
func NewTextFormatter(opts *config.TextFormatterOptions, logger *log.Logger) (*TextFormatter, error) {
|
||||
f := &TextFormatter{
|
||||
timestampFormat: timestampFormat,
|
||||
config: opts,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Create template with helper functions
|
||||
funcMap := template.FuncMap{
|
||||
"FmtTime": func(t time.Time) string {
|
||||
return t.Format(f.timestampFormat)
|
||||
return t.Format(f.config.TimestampFormat)
|
||||
},
|
||||
"ToUpper": strings.ToUpper,
|
||||
"ToLower": strings.ToLower,
|
||||
"TrimSpace": strings.TrimSpace,
|
||||
}
|
||||
|
||||
tmpl, err := template.New("log").Funcs(funcMap).Parse(templateStr)
|
||||
tmpl, err := template.New("log").Funcs(funcMap).Parse(f.config.Template)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid template: %w", err)
|
||||
}
|
||||
@ -86,7 +75,7 @@ func (f *TextFormatter) Format(entry core.LogEntry) ([]byte, error) {
|
||||
"error", err)
|
||||
|
||||
fallback := fmt.Sprintf("[%s] [%s] %s - %s\n",
|
||||
entry.Time.Format(f.timestampFormat),
|
||||
entry.Time.Format(f.config.TimestampFormat),
|
||||
strings.ToUpper(entry.Level),
|
||||
entry.Source,
|
||||
entry.Message)
|
||||
|
||||
@ -1,81 +0,0 @@
|
||||
// FILE: logwisp/src/internal/format/text_test.go
|
||||
package format
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/core"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTextFormatter(t *testing.T) {
|
||||
logger := newTestLogger()
|
||||
t.Run("InvalidTemplate", func(t *testing.T) {
|
||||
options := map[string]any{"template": "{{ .Timestamp | InvalidFunc }}"}
|
||||
_, err := NewTextFormatter(options, logger)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid template")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTextFormatter_Format(t *testing.T) {
|
||||
logger := newTestLogger()
|
||||
testTime := time.Date(2023, 10, 27, 10, 30, 0, 0, time.UTC)
|
||||
entry := core.LogEntry{
|
||||
Time: testTime,
|
||||
Source: "api",
|
||||
Level: "WARN",
|
||||
Message: "rate limit exceeded",
|
||||
}
|
||||
|
||||
t.Run("DefaultTemplate", func(t *testing.T) {
|
||||
formatter, err := NewTextFormatter(nil, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := formatter.Format(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := fmt.Sprintf("[%s] [WARN] api - rate limit exceeded\n", testTime.Format(time.RFC3339))
|
||||
assert.Equal(t, expected, string(output))
|
||||
})
|
||||
|
||||
t.Run("CustomTemplate", func(t *testing.T) {
|
||||
options := map[string]any{"template": "{{.Level}}:{{.Source}}:{{.Message}}"}
|
||||
formatter, err := NewTextFormatter(options, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := formatter.Format(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := "WARN:api:rate limit exceeded\n"
|
||||
assert.Equal(t, expected, string(output))
|
||||
})
|
||||
|
||||
t.Run("CustomTimestampFormat", func(t *testing.T) {
|
||||
options := map[string]any{"timestamp_format": "2006-01-02"}
|
||||
formatter, err := NewTextFormatter(options, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := formatter.Format(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, strings.HasPrefix(string(output), "[2023-10-27]"))
|
||||
})
|
||||
|
||||
t.Run("EmptyLevelDefaultsToInfo", func(t *testing.T) {
|
||||
emptyLevelEntry := entry
|
||||
emptyLevelEntry.Level = ""
|
||||
formatter, err := NewTextFormatter(nil, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := formatter.Format(emptyLevelEntry)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, string(output), "[INFO]")
|
||||
})
|
||||
}
|
||||
@ -34,7 +34,7 @@ const (
|
||||
|
||||
// NetLimiter manages net limiting for a transport
|
||||
type NetLimiter struct {
|
||||
config config.NetLimitConfig
|
||||
config *config.NetLimitConfig
|
||||
logger *log.Logger
|
||||
|
||||
// IP Access Control Lists
|
||||
@ -89,7 +89,11 @@ type connTracker struct {
|
||||
}
|
||||
|
||||
// Creates a new net limiter
|
||||
func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
||||
func NewNetLimiter(cfg *config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return nil only if nothing is configured
|
||||
hasACL := len(cfg.IPWhitelist) > 0 || len(cfg.IPBlacklist) > 0
|
||||
hasRateLimit := cfg.Enabled
|
||||
@ -120,7 +124,7 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
||||
}
|
||||
|
||||
// Parse IP lists
|
||||
l.parseIPLists(cfg)
|
||||
l.parseIPLists()
|
||||
|
||||
// Start cleanup goroutine only if rate limiting is enabled
|
||||
if cfg.Enabled {
|
||||
@ -144,16 +148,16 @@ func NewNetLimiter(cfg config.NetLimitConfig, logger *log.Logger) *NetLimiter {
|
||||
}
|
||||
|
||||
// parseIPLists parses and validates IP whitelist/blacklist
|
||||
func (l *NetLimiter) parseIPLists(cfg config.NetLimitConfig) {
|
||||
func (l *NetLimiter) parseIPLists() {
|
||||
// Parse whitelist
|
||||
for _, entry := range cfg.IPWhitelist {
|
||||
for _, entry := range l.config.IPWhitelist {
|
||||
if ipNet := l.parseIPEntry(entry, "whitelist"); ipNet != nil {
|
||||
l.ipWhitelist = append(l.ipWhitelist, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse blacklist
|
||||
for _, entry := range cfg.IPBlacklist {
|
||||
for _, entry := range l.config.IPBlacklist {
|
||||
if ipNet := l.parseIPEntry(entry, "blacklist"); ipNet != nil {
|
||||
l.ipBlacklist = append(l.ipBlacklist, ipNet)
|
||||
}
|
||||
|
||||
@ -1,117 +0,0 @@
|
||||
// FILE: src/internal/scram/integration.go
|
||||
package scram
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// ScramManager provides high-level SCRAM operations with rate limiting
|
||||
type ScramManager struct {
|
||||
server *Server
|
||||
sessions map[string]*SessionInfo
|
||||
limiter map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SessionInfo tracks authenticated sessions
|
||||
type SessionInfo struct {
|
||||
Username string
|
||||
RemoteAddr string
|
||||
SessionID string
|
||||
CreatedAt time.Time
|
||||
LastActivity time.Time
|
||||
Method string // "scram-sha-256"
|
||||
}
|
||||
|
||||
// NewScramManager creates SCRAM manager
|
||||
func NewScramManager() *ScramManager {
|
||||
m := &ScramManager{
|
||||
server: NewServer(),
|
||||
sessions: make(map[string]*SessionInfo),
|
||||
limiter: make(map[string]*rate.Limiter),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go m.cleanupLoop()
|
||||
return m
|
||||
}
|
||||
|
||||
// RegisterUser creates new user credential
|
||||
func (sm *ScramManager) RegisterUser(username, password string) error {
|
||||
salt := make([]byte, 16)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return fmt.Errorf("salt generation failed: %w", err)
|
||||
}
|
||||
|
||||
cred, err := DeriveCredential(username, password, salt,
|
||||
sm.server.DefaultTime, sm.server.DefaultMemory, sm.server.DefaultThreads)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sm.server.AddCredential(cred)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRateLimiter returns per-IP rate limiter
|
||||
func (sm *ScramManager) GetRateLimiter(remoteAddr string) *rate.Limiter {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if limiter, exists := sm.limiter[remoteAddr]; exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// 10 attempts per minute, burst of 3
|
||||
limiter := rate.NewLimiter(rate.Every(6*time.Second), 3)
|
||||
sm.limiter[remoteAddr] = limiter
|
||||
|
||||
// Prevent unbounded growth
|
||||
if len(sm.limiter) > 10000 {
|
||||
// Remove oldest entries
|
||||
for addr := range sm.limiter {
|
||||
delete(sm.limiter, addr)
|
||||
if len(sm.limiter) < 8000 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (sm *ScramManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
sm.mu.Lock()
|
||||
cutoff := time.Now().Add(-30 * time.Minute)
|
||||
for sid, session := range sm.sessions {
|
||||
if session.LastActivity.Before(cutoff) {
|
||||
delete(sm.sessions, sid)
|
||||
}
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// HandleClientFirst wraps server's HandleClientFirst
|
||||
func (sm *ScramManager) HandleClientFirst(msg *ClientFirst) (*ServerFirst, error) {
|
||||
return sm.server.HandleClientFirst(msg)
|
||||
}
|
||||
|
||||
// HandleClientFinal wraps server's HandleClientFinal
|
||||
func (sm *ScramManager) HandleClientFinal(msg *ClientFinal) (*ServerFinal, error) {
|
||||
return sm.server.HandleClientFinal(msg)
|
||||
}
|
||||
|
||||
// AddCredential wraps server's AddCredential
|
||||
func (sm *ScramManager) AddCredential(cred *Credential) {
|
||||
sm.server.AddCredential(cred)
|
||||
}
|
||||
@ -1,101 +0,0 @@
|
||||
// FILE: src/internal/scram/message.go
|
||||
package scram
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ClientFirst initiates authentication
|
||||
type ClientFirst struct {
|
||||
Username string `json:"u"`
|
||||
ClientNonce string `json:"n"`
|
||||
}
|
||||
|
||||
// ServerFirst contains server challenge
|
||||
type ServerFirst struct {
|
||||
FullNonce string `json:"r"` // client_nonce + server_nonce
|
||||
Salt string `json:"s"` // base64
|
||||
ArgonTime uint32 `json:"t"`
|
||||
ArgonMemory uint32 `json:"m"`
|
||||
ArgonThreads uint8 `json:"p"`
|
||||
}
|
||||
|
||||
// ClientFinal contains client proof
|
||||
type ClientFinal struct {
|
||||
FullNonce string `json:"r"`
|
||||
ClientProof string `json:"p"` // base64
|
||||
}
|
||||
|
||||
// ServerFinal contains server signature for mutual auth
|
||||
type ServerFinal struct {
|
||||
ServerSignature string `json:"v"` // base64
|
||||
SessionID string `json:"sid,omitempty"`
|
||||
}
|
||||
|
||||
// Marshal/Unmarshal helpers for TCP protocol (line-based)
|
||||
func (cf *ClientFirst) Marshal() string {
|
||||
return fmt.Sprintf("u=%s,n=%s", cf.Username, cf.ClientNonce)
|
||||
}
|
||||
|
||||
func ParseClientFirst(data string) (*ClientFirst, error) {
|
||||
parts := strings.Split(data, ",")
|
||||
msg := &ClientFirst{}
|
||||
for _, part := range parts {
|
||||
kv := strings.SplitN(part, "=", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
switch kv[0] {
|
||||
case "u":
|
||||
msg.Username = kv[1]
|
||||
case "n":
|
||||
msg.ClientNonce = kv[1]
|
||||
}
|
||||
}
|
||||
if msg.Username == "" || msg.ClientNonce == "" {
|
||||
return nil, fmt.Errorf("missing required fields")
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (sf *ServerFirst) Marshal() string {
|
||||
return fmt.Sprintf("r=%s,s=%s,t=%d,m=%d,p=%d",
|
||||
sf.FullNonce, sf.Salt, sf.ArgonTime, sf.ArgonMemory, sf.ArgonThreads)
|
||||
}
|
||||
|
||||
func ParseServerFirst(data string) (*ServerFirst, error) {
|
||||
parts := strings.Split(data, ",")
|
||||
msg := &ServerFirst{}
|
||||
for _, part := range parts {
|
||||
kv := strings.SplitN(part, "=", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
switch kv[0] {
|
||||
case "r":
|
||||
msg.FullNonce = kv[1]
|
||||
case "s":
|
||||
msg.Salt = kv[1]
|
||||
case "t":
|
||||
fmt.Sscanf(kv[1], "%d", &msg.ArgonTime)
|
||||
case "m":
|
||||
fmt.Sscanf(kv[1], "%d", &msg.ArgonMemory)
|
||||
case "p":
|
||||
fmt.Sscanf(kv[1], "%d", &msg.ArgonThreads)
|
||||
}
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// JSON variants for HTTP
|
||||
func (cf *ClientFirst) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(*cf)
|
||||
}
|
||||
|
||||
func (sf *ServerFirst) MarshalJSON() ([]byte, error) {
|
||||
sf.Salt = base64.StdEncoding.EncodeToString([]byte(sf.Salt))
|
||||
return json.Marshal(*sf)
|
||||
}
|
||||
@ -1,228 +0,0 @@
|
||||
// FILE: src/internal/scram/scram_test.go
|
||||
package scram
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
func TestCredentialDerivation(t *testing.T) {
|
||||
salt := make([]byte, 16)
|
||||
_, err := rand.Read(salt)
|
||||
require.NoError(t, err)
|
||||
|
||||
cred, err := DeriveCredential("testuser", "testpass123", salt, 3, 64*1024, 4)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "testuser", cred.Username)
|
||||
assert.Equal(t, uint32(3), cred.ArgonTime)
|
||||
assert.Equal(t, uint32(64*1024), cred.ArgonMemory)
|
||||
assert.Equal(t, uint8(4), cred.ArgonThreads)
|
||||
assert.Len(t, cred.StoredKey, 32)
|
||||
assert.Len(t, cred.ServerKey, 32)
|
||||
}
|
||||
|
||||
func TestFullHandshake(t *testing.T) {
|
||||
// Setup server with credential
|
||||
server := NewServer()
|
||||
salt := make([]byte, 16)
|
||||
_, err := rand.Read(salt)
|
||||
require.NoError(t, err)
|
||||
|
||||
cred, err := DeriveCredential("alice", "secret123", salt, 3, 64*1024, 4)
|
||||
require.NoError(t, err)
|
||||
server.AddCredential(cred)
|
||||
|
||||
// Setup client
|
||||
client := NewClient("alice", "secret123")
|
||||
|
||||
// Step 1: Client starts
|
||||
clientFirst, err := client.StartAuthentication()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "alice", clientFirst.Username)
|
||||
assert.NotEmpty(t, clientFirst.ClientNonce)
|
||||
|
||||
// Step 2: Server responds
|
||||
serverFirst, err := server.HandleClientFirst(clientFirst)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, serverFirst.FullNonce, clientFirst.ClientNonce)
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString(salt), serverFirst.Salt)
|
||||
|
||||
// Step 3: Client proves
|
||||
clientFinal, err := client.ProcessServerFirst(serverFirst)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, serverFirst.FullNonce, clientFinal.FullNonce)
|
||||
assert.NotEmpty(t, clientFinal.ClientProof)
|
||||
|
||||
// Step 4: Server verifies and signs
|
||||
serverFinal, err := server.HandleClientFinal(clientFinal)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, serverFinal.ServerSignature)
|
||||
assert.NotEmpty(t, serverFinal.SessionID)
|
||||
|
||||
// Step 5: Client verifies server
|
||||
err = client.VerifyServerFinal(serverFinal)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInvalidPassword(t *testing.T) {
|
||||
server := NewServer()
|
||||
salt := make([]byte, 16)
|
||||
rand.Read(salt)
|
||||
|
||||
cred, _ := DeriveCredential("bob", "correct", salt, 3, 64*1024, 4)
|
||||
server.AddCredential(cred)
|
||||
|
||||
// Client with wrong password
|
||||
client := NewClient("bob", "wrong")
|
||||
|
||||
clientFirst, _ := client.StartAuthentication()
|
||||
serverFirst, _ := server.HandleClientFirst(clientFirst)
|
||||
clientFinal, _ := client.ProcessServerFirst(serverFirst)
|
||||
|
||||
clientFinal, err := client.ProcessServerFirst(serverFirst)
|
||||
require.NoError(t, err) // Check error to prevent panic
|
||||
|
||||
// Server should reject
|
||||
_, err = server.HandleClientFinal(clientFinal)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "authentication failed")
|
||||
}
|
||||
|
||||
func TestHandshakeTimeout(t *testing.T) {
|
||||
server := NewServer()
|
||||
salt := make([]byte, 16)
|
||||
rand.Read(salt)
|
||||
|
||||
cred, _ := DeriveCredential("charlie", "pass", salt, 3, 64*1024, 4)
|
||||
server.AddCredential(cred)
|
||||
|
||||
client := NewClient("charlie", "pass")
|
||||
clientFirst, _ := client.StartAuthentication()
|
||||
serverFirst, _ := server.HandleClientFirst(clientFirst)
|
||||
|
||||
// Manipulate handshake timestamp
|
||||
server.mu.Lock()
|
||||
if state, exists := server.handshakes[serverFirst.FullNonce]; exists {
|
||||
state.CreatedAt = time.Now().Add(-61 * time.Second)
|
||||
}
|
||||
server.mu.Unlock()
|
||||
|
||||
clientFinal, _ := client.ProcessServerFirst(serverFirst)
|
||||
_, err := server.HandleClientFinal(clientFinal)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout")
|
||||
}
|
||||
|
||||
func TestMessageParsing(t *testing.T) {
|
||||
// Test ClientFirst parsing
|
||||
cf, err := ParseClientFirst("u=alice,n=abc123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "alice", cf.Username)
|
||||
assert.Equal(t, "abc123", cf.ClientNonce)
|
||||
|
||||
// Test marshal round-trip
|
||||
marshaled := cf.Marshal()
|
||||
parsed, err := ParseClientFirst(marshaled)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cf.Username, parsed.Username)
|
||||
assert.Equal(t, cf.ClientNonce, parsed.ClientNonce)
|
||||
}
|
||||
|
||||
func TestPHCMigration(t *testing.T) {
|
||||
// Create PHC hash using known values
|
||||
password := "testpass"
|
||||
salt := []byte("saltsaltsaltsalt")
|
||||
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
||||
|
||||
// This would be generated by the existing auth system
|
||||
phcHash := "$argon2id$v=19$m=65536,t=3,p=4$" + saltB64 + "$" + "dGVzdGhhc2g" // dummy hash
|
||||
|
||||
// For testing, we need a valid hash - generate it
|
||||
hash := argon2.IDKey([]byte(password), salt, 3, 65536, 4, 32)
|
||||
hashB64 := base64.RawStdEncoding.EncodeToString(hash)
|
||||
phcHash = "$argon2id$v=19$m=65536,t=3,p=4$" + saltB64 + "$" + hashB64
|
||||
|
||||
cred, err := MigrateFromPHC("alice", password, phcHash)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "alice", cred.Username)
|
||||
assert.Equal(t, salt, cred.Salt)
|
||||
}
|
||||
|
||||
func TestRateLimiting(t *testing.T) {
|
||||
manager := NewScramManager()
|
||||
|
||||
limiter := manager.GetRateLimiter("192.168.1.1")
|
||||
|
||||
// Should allow burst
|
||||
for i := 0; i < 3; i++ {
|
||||
assert.True(t, limiter.Allow())
|
||||
}
|
||||
|
||||
// 4th should be rate limited
|
||||
assert.False(t, limiter.Allow())
|
||||
}
|
||||
|
||||
func TestNonceUniqueness(t *testing.T) {
|
||||
nonces := make(map[string]bool)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
nonce := generateNonce()
|
||||
assert.False(t, nonces[nonce], "Duplicate nonce generated")
|
||||
nonces[nonce] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstantTimeComparison(t *testing.T) {
|
||||
server := NewServer()
|
||||
salt := make([]byte, 16)
|
||||
rand.Read(salt)
|
||||
|
||||
// Add real user
|
||||
cred, _ := DeriveCredential("realuser", "realpass", salt, 3, 64*1024, 4)
|
||||
server.AddCredential(cred)
|
||||
|
||||
// Test timing independence for invalid user
|
||||
fakeClient := NewClient("fakeuser", "fakepass")
|
||||
clientFirst, _ := fakeClient.StartAuthentication()
|
||||
|
||||
// Should still return ServerFirst (no user enumeration)
|
||||
serverFirst, err := server.HandleClientFirst(clientFirst)
|
||||
assert.NotNil(t, serverFirst)
|
||||
assert.Error(t, err) // Internal error, not exposed to client
|
||||
}
|
||||
|
||||
func BenchmarkArgon2Derivation(b *testing.B) {
|
||||
salt := make([]byte, 16)
|
||||
rand.Read(salt)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
DeriveCredential("user", "password", salt, 3, 64*1024, 4)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFullHandshake(b *testing.B) {
|
||||
server := NewServer()
|
||||
salt := make([]byte, 16)
|
||||
rand.Read(salt)
|
||||
cred, _ := DeriveCredential("bench", "pass", salt, 3, 64*1024, 4)
|
||||
server.AddCredential(cred)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
client := NewClient("bench", "pass")
|
||||
clientFirst, _ := client.StartAuthentication()
|
||||
serverFirst, _ := server.HandleClientFirst(clientFirst)
|
||||
clientFinal, _ := client.ProcessServerFirst(serverFirst)
|
||||
serverFinal, _ := server.HandleClientFinal(clientFinal)
|
||||
client.VerifyServerFinal(serverFinal)
|
||||
}
|
||||
}
|
||||
@ -3,12 +3,14 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/filter"
|
||||
"logwisp/src/internal/format"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/sink"
|
||||
"logwisp/src/internal/source"
|
||||
@ -18,8 +20,7 @@ import (
|
||||
|
||||
// Manages the flow of data from sources through filters to sinks
|
||||
type Pipeline struct {
|
||||
Name string
|
||||
Config config.PipelineConfig
|
||||
Config *config.PipelineConfig
|
||||
Sources []source.Source
|
||||
RateLimiter *limit.RateLimiter
|
||||
FilterChain *filter.Chain
|
||||
@ -43,11 +44,116 @@ type PipelineStats struct {
|
||||
FilterStats map[string]any
|
||||
}
|
||||
|
||||
// Creates and starts a new pipeline
|
||||
func (s *Service) NewPipeline(cfg *config.PipelineConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.pipelines[cfg.Name]; exists {
|
||||
err := fmt.Errorf("pipeline '%s' already exists", cfg.Name)
|
||||
s.logger.Error("msg", "Failed to create pipeline - duplicate name",
|
||||
"component", "service",
|
||||
"pipeline", cfg.Name,
|
||||
"error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Debug("msg", "Creating pipeline", "pipeline", cfg.Name)
|
||||
|
||||
// Create pipeline context
|
||||
pipelineCtx, pipelineCancel := context.WithCancel(s.ctx)
|
||||
|
||||
// Create pipeline instance
|
||||
pipeline := &Pipeline{
|
||||
Config: cfg,
|
||||
Stats: &PipelineStats{
|
||||
StartTime: time.Now(),
|
||||
},
|
||||
ctx: pipelineCtx,
|
||||
cancel: pipelineCancel,
|
||||
logger: s.logger,
|
||||
}
|
||||
|
||||
// Create sources
|
||||
for i, srcCfg := range cfg.Sources {
|
||||
src, err := s.createSource(&srcCfg)
|
||||
if err != nil {
|
||||
pipelineCancel()
|
||||
return fmt.Errorf("failed to create source[%d]: %w", i, err)
|
||||
}
|
||||
pipeline.Sources = append(pipeline.Sources, src)
|
||||
}
|
||||
|
||||
// Create pipeline rate limiter
|
||||
if cfg.RateLimit != nil {
|
||||
limiter, err := limit.NewRateLimiter(*cfg.RateLimit, s.logger)
|
||||
if err != nil {
|
||||
pipelineCancel()
|
||||
return fmt.Errorf("failed to create pipeline rate limiter: %w", err)
|
||||
}
|
||||
pipeline.RateLimiter = limiter
|
||||
}
|
||||
|
||||
// Create filter chain
|
||||
if len(cfg.Filters) > 0 {
|
||||
chain, err := filter.NewChain(cfg.Filters, s.logger)
|
||||
if err != nil {
|
||||
pipelineCancel()
|
||||
return fmt.Errorf("failed to create filter chain: %w", err)
|
||||
}
|
||||
pipeline.FilterChain = chain
|
||||
}
|
||||
|
||||
// Create formatter for the pipeline
|
||||
formatter, err := format.NewFormatter(cfg.Format, s.logger)
|
||||
if err != nil {
|
||||
pipelineCancel()
|
||||
return fmt.Errorf("failed to create formatter: %w", err)
|
||||
}
|
||||
|
||||
// Create sinks
|
||||
for i, sinkCfg := range cfg.Sinks {
|
||||
sinkInst, err := s.createSink(sinkCfg, formatter)
|
||||
if err != nil {
|
||||
pipelineCancel()
|
||||
return fmt.Errorf("failed to create sink[%d]: %w", i, err)
|
||||
}
|
||||
pipeline.Sinks = append(pipeline.Sinks, sinkInst)
|
||||
}
|
||||
|
||||
// Start all sources
|
||||
for i, src := range pipeline.Sources {
|
||||
if err := src.Start(); err != nil {
|
||||
pipeline.Shutdown()
|
||||
return fmt.Errorf("failed to start source[%d]: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start all sinks
|
||||
for i, sinkInst := range pipeline.Sinks {
|
||||
if err := sinkInst.Start(pipelineCtx); err != nil {
|
||||
pipeline.Shutdown()
|
||||
return fmt.Errorf("failed to start sink[%d]: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Wire sources to sinks through filters
|
||||
s.wirePipeline(pipeline)
|
||||
|
||||
// Start stats updater
|
||||
pipeline.startStatsUpdater(pipelineCtx)
|
||||
|
||||
s.pipelines[cfg.Name] = pipeline
|
||||
s.logger.Info("msg", "Pipeline created successfully",
|
||||
"pipeline", cfg.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Gracefully stops the pipeline
|
||||
func (p *Pipeline) Shutdown() {
|
||||
p.logger.Info("msg", "Shutting down pipeline",
|
||||
"component", "pipeline",
|
||||
"pipeline", p.Name)
|
||||
"pipeline", p.Config.Name)
|
||||
|
||||
// Cancel context to stop processing
|
||||
p.cancel()
|
||||
@ -78,7 +184,7 @@ func (p *Pipeline) Shutdown() {
|
||||
|
||||
p.logger.Info("msg", "Pipeline shutdown complete",
|
||||
"component", "pipeline",
|
||||
"pipeline", p.Name)
|
||||
"pipeline", p.Config.Name)
|
||||
}
|
||||
|
||||
// Returns pipeline statistics
|
||||
@ -88,7 +194,7 @@ func (p *Pipeline) GetStats() map[string]any {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
p.logger.Error("msg", "Panic getting pipeline stats",
|
||||
"pipeline", p.Name,
|
||||
"pipeline", p.Config.Name,
|
||||
"panic", r)
|
||||
}
|
||||
}()
|
||||
@ -142,7 +248,7 @@ func (p *Pipeline) GetStats() map[string]any {
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"name": p.Name,
|
||||
"name": p.Config.Name,
|
||||
"uptime_seconds": int(time.Since(p.Stats.StartTime).Seconds()),
|
||||
"total_processed": p.Stats.TotalEntriesProcessed.Load(),
|
||||
"total_dropped_rate_limit": p.Stats.TotalEntriesDroppedByRateLimit.Load(),
|
||||
|
||||
@ -5,13 +5,10 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/filter"
|
||||
"logwisp/src/internal/format"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/sink"
|
||||
"logwisp/src/internal/source"
|
||||
|
||||
@ -39,127 +36,6 @@ func NewService(ctx context.Context, logger *log.Logger) *Service {
|
||||
}
|
||||
}
|
||||
|
||||
// Creates and starts a new pipeline
|
||||
func (s *Service) NewPipeline(cfg config.PipelineConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.pipelines[cfg.Name]; exists {
|
||||
err := fmt.Errorf("pipeline '%s' already exists", cfg.Name)
|
||||
s.logger.Error("msg", "Failed to create pipeline - duplicate name",
|
||||
"component", "service",
|
||||
"pipeline", cfg.Name,
|
||||
"error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Debug("msg", "Creating pipeline", "pipeline", cfg.Name)
|
||||
|
||||
// Create pipeline context
|
||||
pipelineCtx, pipelineCancel := context.WithCancel(s.ctx)
|
||||
|
||||
// Create pipeline instance
|
||||
pipeline := &Pipeline{
|
||||
Name: cfg.Name,
|
||||
Config: cfg,
|
||||
Stats: &PipelineStats{
|
||||
StartTime: time.Now(),
|
||||
},
|
||||
ctx: pipelineCtx,
|
||||
cancel: pipelineCancel,
|
||||
logger: s.logger,
|
||||
}
|
||||
|
||||
// Create sources
|
||||
for i, srcCfg := range cfg.Sources {
|
||||
src, err := s.createSource(srcCfg)
|
||||
if err != nil {
|
||||
pipelineCancel()
|
||||
return fmt.Errorf("failed to create source[%d]: %w", i, err)
|
||||
}
|
||||
pipeline.Sources = append(pipeline.Sources, src)
|
||||
}
|
||||
|
||||
// Create pipeline rate limiter
|
||||
if cfg.RateLimit != nil {
|
||||
limiter, err := limit.NewRateLimiter(*cfg.RateLimit, s.logger)
|
||||
if err != nil {
|
||||
pipelineCancel()
|
||||
return fmt.Errorf("failed to create pipeline rate limiter: %w", err)
|
||||
}
|
||||
pipeline.RateLimiter = limiter
|
||||
}
|
||||
|
||||
// Create filter chain
|
||||
if len(cfg.Filters) > 0 {
|
||||
chain, err := filter.NewChain(cfg.Filters, s.logger)
|
||||
if err != nil {
|
||||
pipelineCancel()
|
||||
return fmt.Errorf("failed to create filter chain: %w", err)
|
||||
}
|
||||
pipeline.FilterChain = chain
|
||||
}
|
||||
|
||||
// Create formatter for the pipeline
|
||||
var formatter format.Formatter
|
||||
var err error
|
||||
if cfg.Format != "" || len(cfg.FormatOptions) > 0 {
|
||||
formatter, err = format.NewFormatter(cfg.Format, cfg.FormatOptions, s.logger)
|
||||
if err != nil {
|
||||
pipelineCancel()
|
||||
return fmt.Errorf("failed to create formatter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create sinks
|
||||
for i, sinkCfg := range cfg.Sinks {
|
||||
sinkInst, err := s.createSink(sinkCfg, formatter)
|
||||
if err != nil {
|
||||
pipelineCancel()
|
||||
return fmt.Errorf("failed to create sink[%d]: %w", i, err)
|
||||
}
|
||||
pipeline.Sinks = append(pipeline.Sinks, sinkInst)
|
||||
}
|
||||
|
||||
// Configure authentication for sources that support it before starting them
|
||||
for _, sourceInst := range pipeline.Sources {
|
||||
sourceInst.SetAuth(cfg.Auth)
|
||||
}
|
||||
|
||||
// Start all sources
|
||||
for i, src := range pipeline.Sources {
|
||||
if err := src.Start(); err != nil {
|
||||
pipeline.Shutdown()
|
||||
return fmt.Errorf("failed to start source[%d]: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Configure authentication for sinks that support it before starting them
|
||||
for _, sinkInst := range pipeline.Sinks {
|
||||
sinkInst.SetAuth(cfg.Auth)
|
||||
}
|
||||
|
||||
// Start all sinks
|
||||
for i, sinkInst := range pipeline.Sinks {
|
||||
if err := sinkInst.Start(pipelineCtx); err != nil {
|
||||
pipeline.Shutdown()
|
||||
return fmt.Errorf("failed to start sink[%d]: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Wire sources to sinks through filters
|
||||
s.wirePipeline(pipeline)
|
||||
|
||||
// Start stats updater
|
||||
pipeline.startStatsUpdater(pipelineCtx)
|
||||
|
||||
s.pipelines[cfg.Name] = pipeline
|
||||
s.logger.Info("msg", "Pipeline created successfully",
|
||||
"pipeline", cfg.Name,
|
||||
"auth_enabled", cfg.Auth != nil && cfg.Auth.Type != "none")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connects sources to sinks through filters
|
||||
func (s *Service) wirePipeline(p *Pipeline) {
|
||||
// For each source, subscribe and process entries
|
||||
@ -175,17 +51,17 @@ func (s *Service) wirePipeline(p *Pipeline) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.logger.Error("msg", "Panic in pipeline processing",
|
||||
"pipeline", p.Name,
|
||||
"pipeline", p.Config.Name,
|
||||
"source", source.GetStats().Type,
|
||||
"panic", r)
|
||||
|
||||
// Ensure failed pipelines don't leave resources hanging
|
||||
go func() {
|
||||
s.logger.Warn("msg", "Shutting down pipeline due to panic",
|
||||
"pipeline", p.Name)
|
||||
if err := s.RemovePipeline(p.Name); err != nil {
|
||||
"pipeline", p.Config.Name)
|
||||
if err := s.RemovePipeline(p.Config.Name); err != nil {
|
||||
s.logger.Error("msg", "Failed to remove panicked pipeline",
|
||||
"pipeline", p.Name,
|
||||
"pipeline", p.Config.Name,
|
||||
"error", err)
|
||||
}
|
||||
}()
|
||||
@ -228,7 +104,7 @@ func (s *Service) wirePipeline(p *Pipeline) {
|
||||
default:
|
||||
// Drop if sink buffer is full, may flood logging for slow client
|
||||
s.logger.Debug("msg", "Dropped log entry - sink buffer full",
|
||||
"pipeline", p.Name)
|
||||
"pipeline", p.Config.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -238,16 +114,16 @@ func (s *Service) wirePipeline(p *Pipeline) {
|
||||
}
|
||||
|
||||
// Creates a source instance based on configuration
|
||||
func (s *Service) createSource(cfg config.SourceConfig) (source.Source, error) {
|
||||
func (s *Service) createSource(cfg *config.SourceConfig) (source.Source, error) {
|
||||
switch cfg.Type {
|
||||
case "directory":
|
||||
return source.NewDirectorySource(cfg.Options, s.logger)
|
||||
return source.NewDirectorySource(cfg.Directory, s.logger)
|
||||
case "stdin":
|
||||
return source.NewStdinSource(cfg.Options, s.logger)
|
||||
return source.NewStdinSource(cfg.Stdin, s.logger)
|
||||
case "http":
|
||||
return source.NewHTTPSource(cfg.Options, s.logger)
|
||||
return source.NewHTTPSource(cfg.HTTP, s.logger)
|
||||
case "tcp":
|
||||
return source.NewTCPSource(cfg.Options, s.logger)
|
||||
return source.NewTCPSource(cfg.TCP, s.logger)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown source type: %s", cfg.Type)
|
||||
}
|
||||
@ -255,34 +131,28 @@ func (s *Service) createSource(cfg config.SourceConfig) (source.Source, error) {
|
||||
|
||||
// Creates a sink instance based on configuration
|
||||
func (s *Service) createSink(cfg config.SinkConfig, formatter format.Formatter) (sink.Sink, error) {
|
||||
if formatter == nil {
|
||||
// Default formatters for different sink types
|
||||
defaultFormat := "raw"
|
||||
switch cfg.Type {
|
||||
case "http", "tcp", "http_client", "tcp_client":
|
||||
defaultFormat = "json"
|
||||
}
|
||||
|
||||
var err error
|
||||
formatter, err = format.NewFormatter(defaultFormat, nil, s.logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create default formatter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
switch cfg.Type {
|
||||
case "http":
|
||||
return sink.NewHTTPSink(cfg.Options, s.logger, formatter)
|
||||
if cfg.HTTP == nil {
|
||||
return nil, fmt.Errorf("HTTP sink configuration missing")
|
||||
}
|
||||
return sink.NewHTTPSink(cfg.HTTP, s.logger, formatter)
|
||||
|
||||
case "tcp":
|
||||
return sink.NewTCPSink(cfg.Options, s.logger, formatter)
|
||||
if cfg.TCP == nil {
|
||||
return nil, fmt.Errorf("TCP sink configuration missing")
|
||||
}
|
||||
return sink.NewTCPSink(cfg.TCP, s.logger, formatter)
|
||||
|
||||
case "http_client":
|
||||
return sink.NewHTTPClientSink(cfg.Options, s.logger, formatter)
|
||||
return sink.NewHTTPClientSink(cfg.HTTPClient, s.logger, formatter)
|
||||
case "tcp_client":
|
||||
return sink.NewTCPClientSink(cfg.Options, s.logger, formatter)
|
||||
return sink.NewTCPClientSink(cfg.TCPClient, s.logger, formatter)
|
||||
case "file":
|
||||
return sink.NewFileSink(cfg.Options, s.logger, formatter)
|
||||
return sink.NewFileSink(cfg.File, s.logger, formatter)
|
||||
case "console":
|
||||
return sink.NewConsoleSink(cfg.Options, s.logger, formatter)
|
||||
return sink.NewConsoleSink(cfg.Console, s.logger, formatter)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown sink type: %s", cfg.Type)
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
|
||||
// ConsoleSink writes log entries to the console (stdout/stderr) using an dedicated logger instance
|
||||
type ConsoleSink struct {
|
||||
config *config.ConsoleSinkOptions
|
||||
input chan core.LogEntry
|
||||
writer *log.Logger // Dedicated internal logger instance for console writing
|
||||
done chan struct{}
|
||||
@ -31,22 +32,24 @@ type ConsoleSink struct {
|
||||
}
|
||||
|
||||
// Creates a new console sink
|
||||
func NewConsoleSink(options map[string]any, appLogger *log.Logger, formatter format.Formatter) (*ConsoleSink, error) {
|
||||
target := "stdout"
|
||||
if t, ok := options["target"].(string); ok {
|
||||
target = t
|
||||
func NewConsoleSink(opts *config.ConsoleSinkOptions, appLogger *log.Logger, formatter format.Formatter) (*ConsoleSink, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("console sink options cannot be nil")
|
||||
}
|
||||
|
||||
bufferSize := int64(1000)
|
||||
if buf, ok := options["buffer_size"].(int64); ok && buf > 0 {
|
||||
bufferSize = buf
|
||||
// Set defaults if not configured
|
||||
if opts.Target == "" {
|
||||
opts.Target = "stdout"
|
||||
}
|
||||
if opts.BufferSize <= 0 {
|
||||
opts.BufferSize = 1000
|
||||
}
|
||||
|
||||
// Dedicated logger instance as console writer
|
||||
writer, err := log.NewBuilder().
|
||||
EnableFile(false).
|
||||
EnableConsole(true).
|
||||
ConsoleTarget(target).
|
||||
ConsoleTarget(opts.Target).
|
||||
Format("raw"). // Passthrough pre-formatted messages
|
||||
ShowTimestamp(false). // Disable writer's own timestamp
|
||||
ShowLevel(false). // Disable writer's own level prefix
|
||||
@ -57,7 +60,8 @@ func NewConsoleSink(options map[string]any, appLogger *log.Logger, formatter for
|
||||
}
|
||||
|
||||
s := &ConsoleSink{
|
||||
input: make(chan core.LogEntry, bufferSize),
|
||||
config: opts,
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
writer: writer,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
@ -157,7 +161,3 @@ func (s *ConsoleSink) processLoop(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ConsoleSink) SetAuth(auth *config.AuthConfig) {
|
||||
// Authentication does not apply to the console sink.
|
||||
}
|
||||
@ -5,10 +5,10 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"logwisp/src/internal/config"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
|
||||
@ -17,6 +17,7 @@ import (
|
||||
|
||||
// Writes log entries to files with rotation
|
||||
type FileSink struct {
|
||||
config *config.FileSinkOptions
|
||||
input chan core.LogEntry
|
||||
writer *log.Logger // Internal logger instance for file writing
|
||||
done chan struct{}
|
||||
@ -30,64 +31,27 @@ type FileSink struct {
|
||||
}
|
||||
|
||||
// Creates a new file sink
|
||||
func NewFileSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*FileSink, error) {
|
||||
directory, ok := options["directory"].(string)
|
||||
if !ok || directory == "" {
|
||||
directory = "./"
|
||||
logger.Warn("No directory or invalid directory provided, current directory will be used")
|
||||
}
|
||||
|
||||
name, ok := options["name"].(string)
|
||||
if !ok || name == "" {
|
||||
name = "logwisp.output"
|
||||
logger.Warn(fmt.Sprintf("No filename provided, %s will be used", name))
|
||||
func NewFileSink(opts *config.FileSinkOptions, logger *log.Logger, formatter format.Formatter) (*FileSink, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("file sink options cannot be nil")
|
||||
}
|
||||
|
||||
// Create configuration for the internal log writer
|
||||
writerConfig := log.DefaultConfig()
|
||||
writerConfig.Directory = directory
|
||||
writerConfig.Name = name
|
||||
writerConfig.Directory = opts.Directory
|
||||
writerConfig.Name = opts.Name
|
||||
writerConfig.EnableConsole = false // File only
|
||||
writerConfig.ShowTimestamp = false // We already have timestamps in entries
|
||||
writerConfig.ShowLevel = false // We already have levels in entries
|
||||
|
||||
// Add optional configurations
|
||||
if maxSize, ok := options["max_size_mb"].(int64); ok && maxSize > 0 {
|
||||
writerConfig.MaxSizeKB = maxSize * 1000
|
||||
}
|
||||
|
||||
if maxTotalSize, ok := options["max_total_size_mb"].(int64); ok && maxTotalSize >= 0 {
|
||||
writerConfig.MaxTotalSizeKB = maxTotalSize * 1000
|
||||
}
|
||||
|
||||
if retention, ok := options["retention_hours"].(int64); ok && retention > 0 {
|
||||
writerConfig.RetentionPeriodHrs = float64(retention)
|
||||
}
|
||||
|
||||
if minDiskFree, ok := options["min_disk_free_mb"].(int64); ok && minDiskFree > 0 {
|
||||
writerConfig.MinDiskFreeKB = minDiskFree * 1000
|
||||
}
|
||||
|
||||
// Create internal logger for file writing
|
||||
writer := log.NewLogger()
|
||||
if err := writer.ApplyConfig(writerConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize file writer: %w", err)
|
||||
}
|
||||
|
||||
// Start the internal file writer
|
||||
if err := writer.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start file writer: %w", err)
|
||||
}
|
||||
|
||||
// Buffer size for input channel
|
||||
// TODO: Centralized constant file in core package
|
||||
bufferSize := int64(1000)
|
||||
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
|
||||
bufferSize = bufSize
|
||||
}
|
||||
|
||||
fs := &FileSink{
|
||||
input: make(chan core.LogEntry, bufferSize),
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
writer: writer,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
@ -104,6 +68,11 @@ func (fs *FileSink) Input() chan<- core.LogEntry {
|
||||
}
|
||||
|
||||
func (fs *FileSink) Start(ctx context.Context) error {
|
||||
// Start the internal file writer
|
||||
if err := fs.writer.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start sink file writer: %w", err)
|
||||
}
|
||||
|
||||
go fs.processLoop(ctx)
|
||||
fs.logger.Info("msg", "File sink started", "component", "file_sink")
|
||||
return nil
|
||||
@ -167,7 +136,3 @@ func (fs *FileSink) processLoop(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (fs *FileSink) SetAuth(auth *config.AuthConfig) {
|
||||
// Authentication does not apply to file sink
|
||||
}
|
||||
@ -26,8 +26,11 @@ import (
|
||||
|
||||
// Streams log entries via Server-Sent Events
|
||||
type HTTPSink struct {
|
||||
// Configuration reference (NOT a copy)
|
||||
config *config.HTTPSinkOptions
|
||||
|
||||
// Runtime
|
||||
input chan core.LogEntry
|
||||
config HTTPConfig
|
||||
server *fasthttp.Server
|
||||
activeClients atomic.Int64
|
||||
mu sync.RWMutex
|
||||
@ -46,11 +49,7 @@ type HTTPSink struct {
|
||||
// Security components
|
||||
authenticator *auth.Authenticator
|
||||
tlsManager *tls.Manager
|
||||
authConfig *config.AuthConfig
|
||||
|
||||
// Path configuration
|
||||
streamPath string
|
||||
statusPath string
|
||||
authConfig *config.ServerAuthConfig
|
||||
|
||||
// Net limiting
|
||||
netLimiter *limit.NetLimiter
|
||||
@ -62,151 +61,58 @@ type HTTPSink struct {
|
||||
authSuccesses atomic.Uint64
|
||||
}
|
||||
|
||||
// Holds HTTP sink configuration
|
||||
type HTTPConfig struct {
|
||||
Host string
|
||||
Port int64
|
||||
BufferSize int64
|
||||
StreamPath string
|
||||
StatusPath string
|
||||
Heartbeat *config.HeartbeatConfig
|
||||
TLS *config.TLSConfig
|
||||
NetLimit *config.NetLimitConfig
|
||||
}
|
||||
|
||||
// Creates a new HTTP streaming sink
|
||||
func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*HTTPSink, error) {
|
||||
cfg := HTTPConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
BufferSize: 1000,
|
||||
StreamPath: "/stream",
|
||||
StatusPath: "/status",
|
||||
}
|
||||
|
||||
// Extract configuration from options
|
||||
if host, ok := options["host"].(string); ok && host != "" {
|
||||
cfg.Host = host
|
||||
}
|
||||
if port, ok := options["port"].(int64); ok {
|
||||
cfg.Port = port
|
||||
}
|
||||
if bufSize, ok := options["buffer_size"].(int64); ok {
|
||||
cfg.BufferSize = bufSize
|
||||
}
|
||||
if path, ok := options["stream_path"].(string); ok {
|
||||
cfg.StreamPath = path
|
||||
}
|
||||
if path, ok := options["status_path"].(string); ok {
|
||||
cfg.StatusPath = path
|
||||
}
|
||||
|
||||
// Extract heartbeat config
|
||||
if hb, ok := options["heartbeat"].(map[string]any); ok {
|
||||
cfg.Heartbeat = &config.HeartbeatConfig{}
|
||||
cfg.Heartbeat.Enabled, _ = hb["enabled"].(bool)
|
||||
if interval, ok := hb["interval_seconds"].(int64); ok {
|
||||
cfg.Heartbeat.IntervalSeconds = interval
|
||||
}
|
||||
cfg.Heartbeat.IncludeTimestamp, _ = hb["include_timestamp"].(bool)
|
||||
cfg.Heartbeat.IncludeStats, _ = hb["include_stats"].(bool)
|
||||
if hbFormat, ok := hb["format"].(string); ok {
|
||||
cfg.Heartbeat.Format = hbFormat
|
||||
}
|
||||
}
|
||||
|
||||
// Extract TLS config
|
||||
if tc, ok := options["tls"].(map[string]any); ok {
|
||||
cfg.TLS = &config.TLSConfig{}
|
||||
cfg.TLS.Enabled, _ = tc["enabled"].(bool)
|
||||
if certFile, ok := tc["cert_file"].(string); ok {
|
||||
cfg.TLS.CertFile = certFile
|
||||
}
|
||||
if keyFile, ok := tc["key_file"].(string); ok {
|
||||
cfg.TLS.KeyFile = keyFile
|
||||
}
|
||||
cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool)
|
||||
if caFile, ok := tc["client_ca_file"].(string); ok {
|
||||
cfg.TLS.ClientCAFile = caFile
|
||||
}
|
||||
cfg.TLS.VerifyClientCert, _ = tc["verify_client_cert"].(bool)
|
||||
if minVer, ok := tc["min_version"].(string); ok {
|
||||
cfg.TLS.MinVersion = minVer
|
||||
}
|
||||
if maxVer, ok := tc["max_version"].(string); ok {
|
||||
cfg.TLS.MaxVersion = maxVer
|
||||
}
|
||||
if ciphers, ok := tc["cipher_suites"].(string); ok {
|
||||
cfg.TLS.CipherSuites = ciphers
|
||||
}
|
||||
}
|
||||
|
||||
// Extract net limit config
|
||||
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||
cfg.NetLimit = &config.NetLimitConfig{}
|
||||
cfg.NetLimit.Enabled, _ = nl["enabled"].(bool)
|
||||
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||
cfg.NetLimit.RequestsPerSecond = rps
|
||||
}
|
||||
if burst, ok := nl["burst_size"].(int64); ok {
|
||||
cfg.NetLimit.BurstSize = burst
|
||||
}
|
||||
if respCode, ok := nl["response_code"].(int64); ok {
|
||||
cfg.NetLimit.ResponseCode = respCode
|
||||
}
|
||||
if msg, ok := nl["response_message"].(string); ok {
|
||||
cfg.NetLimit.ResponseMessage = msg
|
||||
}
|
||||
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
|
||||
}
|
||||
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsTotal = maxTotal
|
||||
}
|
||||
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
|
||||
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
||||
for _, entry := range ipWhitelist {
|
||||
if str, ok := entry.(string); ok {
|
||||
cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
|
||||
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
||||
for _, entry := range ipBlacklist {
|
||||
if str, ok := entry.(string); ok {
|
||||
cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
func NewHTTPSink(opts *config.HTTPSinkOptions, logger *log.Logger, formatter format.Formatter) (*HTTPSink, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("HTTP sink options cannot be nil")
|
||||
}
|
||||
|
||||
h := &HTTPSink{
|
||||
input: make(chan core.LogEntry, cfg.BufferSize),
|
||||
config: cfg,
|
||||
config: opts, // Direct reference to config struct
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
startTime: time.Now(),
|
||||
done: make(chan struct{}),
|
||||
streamPath: cfg.StreamPath,
|
||||
statusPath: cfg.StatusPath,
|
||||
logger: logger,
|
||||
formatter: formatter,
|
||||
clients: make(map[uint64]chan core.LogEntry),
|
||||
unregister: make(chan uint64, 10), // Buffered for non-blocking
|
||||
}
|
||||
|
||||
h.lastProcessed.Store(time.Time{})
|
||||
|
||||
// Initialize TLS manager
|
||||
if cfg.TLS != nil && cfg.TLS.Enabled {
|
||||
tlsManager, err := tls.NewManager(cfg.TLS, logger)
|
||||
// Initialize TLS manager if configured
|
||||
if opts.TLS != nil && opts.TLS.Enabled {
|
||||
tlsManager, err := tls.NewManager(opts.TLS, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
|
||||
}
|
||||
h.tlsManager = tlsManager
|
||||
logger.Info("msg", "TLS enabled",
|
||||
"component", "http_sink")
|
||||
}
|
||||
|
||||
// Initialize net limiter if configured
|
||||
if cfg.NetLimit != nil && cfg.NetLimit.Enabled {
|
||||
h.netLimiter = limit.NewNetLimiter(*cfg.NetLimit, logger)
|
||||
if opts.NetLimit != nil && (opts.NetLimit.Enabled ||
|
||||
len(opts.NetLimit.IPWhitelist) > 0 ||
|
||||
len(opts.NetLimit.IPBlacklist) > 0) {
|
||||
h.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger)
|
||||
}
|
||||
|
||||
// Initialize authenticator if auth is not "none"
|
||||
if opts.Auth != nil && opts.Auth.Type != "none" {
|
||||
// Only "basic" and "token" are valid for HTTP sink
|
||||
if opts.Auth.Type != "basic" && opts.Auth.Type != "token" {
|
||||
return nil, fmt.Errorf("invalid auth type '%s' for HTTP sink (valid: none, basic, token)", opts.Auth.Type)
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(opts.Auth, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create authenticator: %w", err)
|
||||
}
|
||||
h.authenticator = authenticator
|
||||
h.authConfig = opts.Auth
|
||||
logger.Info("msg", "Authentication enabled",
|
||||
"component", "http_sink",
|
||||
"type", opts.Auth.Type)
|
||||
}
|
||||
|
||||
return h, nil
|
||||
@ -230,6 +136,9 @@ func (h *HTTPSink) Start(ctx context.Context) error {
|
||||
DisableKeepalive: false,
|
||||
StreamRequestBody: true,
|
||||
Logger: fasthttpLogger,
|
||||
// ReadTimeout: time.Duration(h.config.ReadTimeout) * time.Millisecond,
|
||||
WriteTimeout: time.Duration(h.config.WriteTimeout) * time.Millisecond,
|
||||
// MaxRequestBodySize: int(h.config.MaxBodySize),
|
||||
}
|
||||
|
||||
// Configure TLS if enabled
|
||||
@ -250,8 +159,8 @@ func (h *HTTPSink) Start(ctx context.Context) error {
|
||||
"component", "http_sink",
|
||||
"host", h.config.Host,
|
||||
"port", h.config.Port,
|
||||
"stream_path", h.streamPath,
|
||||
"status_path", h.statusPath,
|
||||
"stream_path", h.config.StreamPath,
|
||||
"status_path", h.config.StatusPath,
|
||||
"tls_enabled", h.tlsManager != nil)
|
||||
|
||||
var err error
|
||||
@ -296,7 +205,7 @@ func (h *HTTPSink) brokerLoop(ctx context.Context) {
|
||||
var tickerChan <-chan time.Time
|
||||
|
||||
if h.config.Heartbeat != nil && h.config.Heartbeat.Enabled {
|
||||
ticker = time.NewTicker(time.Duration(h.config.Heartbeat.IntervalSeconds) * time.Second)
|
||||
ticker = time.NewTicker(time.Duration(h.config.Heartbeat.Interval) * time.Second)
|
||||
tickerChan = ticker.C
|
||||
defer ticker.Stop()
|
||||
}
|
||||
@ -441,8 +350,8 @@ func (h *HTTPSink) GetStats() SinkStats {
|
||||
"port": h.config.Port,
|
||||
"buffer_size": h.config.BufferSize,
|
||||
"endpoints": map[string]string{
|
||||
"stream": h.streamPath,
|
||||
"status": h.statusPath,
|
||||
"stream": h.config.StreamPath,
|
||||
"status": h.config.StatusPath,
|
||||
},
|
||||
"net_limit": netLimitStats,
|
||||
"auth": authStats,
|
||||
@ -489,7 +398,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
path := string(ctx.Path())
|
||||
|
||||
// Status endpoint doesn't require auth
|
||||
if path == h.statusPath {
|
||||
if path == h.config.StatusPath {
|
||||
h.handleStatus(ctx)
|
||||
return
|
||||
}
|
||||
@ -509,14 +418,14 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
|
||||
// Return 401 with WWW-Authenticate header
|
||||
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
|
||||
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
|
||||
realm := h.authConfig.BasicAuth.Realm
|
||||
if h.authConfig.Type == "basic" && h.authConfig.Basic != nil {
|
||||
realm := h.authConfig.Basic.Realm
|
||||
if realm == "" {
|
||||
realm = "Restricted"
|
||||
}
|
||||
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", realm))
|
||||
} else if h.authConfig.Type == "bearer" {
|
||||
ctx.Response.Header.Set("WWW-Authenticate", "Bearer")
|
||||
} else if h.authConfig.Type == "token" {
|
||||
ctx.Response.Header.Set("WWW-Authenticate", "Token")
|
||||
}
|
||||
|
||||
ctx.SetContentType("application/json")
|
||||
@ -538,7 +447,7 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
|
||||
switch path {
|
||||
case h.streamPath:
|
||||
case h.config.StreamPath:
|
||||
h.handleStream(ctx, session)
|
||||
default:
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
@ -547,6 +456,15 @@ func (h *HTTPSink) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
"error": "Not Found",
|
||||
})
|
||||
}
|
||||
// Handle stream endpoint
|
||||
// if path == h.config.StreamPath {
|
||||
// h.handleStream(ctx, session)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// // Unknown path
|
||||
// ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
// ctx.SetBody([]byte("Not Found"))
|
||||
}
|
||||
|
||||
func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session) {
|
||||
@ -611,8 +529,8 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session)
|
||||
"client_id": fmt.Sprintf("%d", clientID),
|
||||
"username": session.Username,
|
||||
"auth_method": session.Method,
|
||||
"stream_path": h.streamPath,
|
||||
"status_path": h.statusPath,
|
||||
"stream_path": h.config.StreamPath,
|
||||
"status_path": h.config.StatusPath,
|
||||
"buffer_size": h.config.BufferSize,
|
||||
"tls": h.tlsManager != nil,
|
||||
}
|
||||
@ -627,7 +545,7 @@ func (h *HTTPSink) handleStream(ctx *fasthttp.RequestCtx, session *auth.Session)
|
||||
var tickerChan <-chan time.Time
|
||||
|
||||
if h.config.Heartbeat != nil && h.config.Heartbeat.Enabled {
|
||||
ticker = time.NewTicker(time.Duration(h.config.Heartbeat.IntervalSeconds) * time.Second)
|
||||
ticker = time.NewTicker(time.Duration(h.config.Heartbeat.Interval) * time.Second)
|
||||
tickerChan = ticker.C
|
||||
defer ticker.Stop()
|
||||
}
|
||||
@ -716,7 +634,7 @@ func (h *HTTPSink) createHeartbeatEntry() core.LogEntry {
|
||||
fields := make(map[string]any)
|
||||
fields["type"] = "heartbeat"
|
||||
|
||||
if h.config.Heartbeat.IncludeStats {
|
||||
if h.config.Heartbeat.Enabled {
|
||||
fields["active_clients"] = h.activeClients.Load()
|
||||
fields["uptime_seconds"] = int(time.Since(h.startTime).Seconds())
|
||||
}
|
||||
@ -775,13 +693,13 @@ func (h *HTTPSink) handleStatus(ctx *fasthttp.RequestCtx) {
|
||||
"uptime_seconds": int(time.Since(h.startTime).Seconds()),
|
||||
},
|
||||
"endpoints": map[string]string{
|
||||
"transport": h.streamPath,
|
||||
"status": h.statusPath,
|
||||
"transport": h.config.StreamPath,
|
||||
"status": h.config.StatusPath,
|
||||
},
|
||||
"features": map[string]any{
|
||||
"heartbeat": map[string]any{
|
||||
"enabled": h.config.Heartbeat.Enabled,
|
||||
"interval": h.config.Heartbeat.IntervalSeconds,
|
||||
"interval": h.config.Heartbeat.Interval,
|
||||
"format": h.config.Heartbeat.Format,
|
||||
},
|
||||
"tls": tlsStats,
|
||||
@ -806,37 +724,15 @@ func (h *HTTPSink) GetActiveConnections() int64 {
|
||||
|
||||
// Returns the configured transport endpoint path
|
||||
func (h *HTTPSink) GetStreamPath() string {
|
||||
return h.streamPath
|
||||
return h.config.StreamPath
|
||||
}
|
||||
|
||||
// Returns the configured status endpoint path
|
||||
func (h *HTTPSink) GetStatusPath() string {
|
||||
return h.statusPath
|
||||
return h.config.StatusPath
|
||||
}
|
||||
|
||||
// Returns the configured host
|
||||
func (h *HTTPSink) GetHost() string {
|
||||
return h.config.Host
|
||||
}
|
||||
|
||||
// Configures http sink auth
|
||||
func (h *HTTPSink) SetAuth(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type == "none" {
|
||||
return
|
||||
}
|
||||
|
||||
h.authConfig = authCfg
|
||||
authenticator, err := auth.NewAuthenticator(authCfg, h.logger)
|
||||
if err != nil {
|
||||
h.logger.Error("msg", "Failed to initialize authenticator for HTTP sink",
|
||||
"component", "http_sink",
|
||||
"error", err)
|
||||
// Continue without auth
|
||||
return
|
||||
}
|
||||
h.authenticator = authenticator
|
||||
|
||||
h.logger.Info("msg", "Authentication configured for HTTP sink",
|
||||
"component", "http_sink",
|
||||
"auth_type", authCfg.Type)
|
||||
}
|
||||
@ -8,7 +8,6 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -28,7 +27,7 @@ import (
|
||||
// Forwards log entries to a remote HTTP endpoint
|
||||
type HTTPClientSink struct {
|
||||
input chan core.LogEntry
|
||||
config HTTPClientConfig
|
||||
config *config.HTTPClientSinkOptions
|
||||
client *fasthttp.Client
|
||||
batch []core.LogEntry
|
||||
batchMu sync.Mutex
|
||||
@ -48,195 +47,16 @@ type HTTPClientSink struct {
|
||||
activeConnections atomic.Int64
|
||||
}
|
||||
|
||||
// Holds HTTP client sink configuration
|
||||
// TODO: missing toml tags
|
||||
type HTTPClientConfig struct {
|
||||
// Config
|
||||
URL string `toml:"url"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
BatchSize int64 `toml:"batch_size"`
|
||||
BatchDelay time.Duration `toml:"batch_delay_ms"`
|
||||
Timeout time.Duration `toml:"timeout_seconds"`
|
||||
Headers map[string]string `toml:"headers"`
|
||||
|
||||
// Retry configuration
|
||||
MaxRetries int64 `toml:"max_retries"`
|
||||
RetryDelay time.Duration `toml:"retry_delay"`
|
||||
RetryBackoff float64 `toml:"retry_backoff"` // Multiplier for exponential backoff
|
||||
|
||||
// Security
|
||||
AuthType string `toml:"auth_type"` // "none", "basic", "bearer", "mtls"
|
||||
Username string `toml:"username"` // For basic auth
|
||||
Password string `toml:"password"` // For basic auth
|
||||
BearerToken string `toml:"bearer_token"` // For bearer auth
|
||||
|
||||
// TLS configuration
|
||||
InsecureSkipVerify bool `toml:"insecure_skip_verify"`
|
||||
CAFile string `toml:"ca_file"`
|
||||
CertFile string `toml:"cert_file"`
|
||||
KeyFile string `toml:"key_file"`
|
||||
}
|
||||
|
||||
// Creates a new HTTP client sink
|
||||
func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*HTTPClientSink, error) {
|
||||
cfg := HTTPClientConfig{
|
||||
BufferSize: int64(1000),
|
||||
BatchSize: int64(100),
|
||||
BatchDelay: time.Second,
|
||||
Timeout: 30 * time.Second,
|
||||
MaxRetries: int64(3),
|
||||
RetryDelay: time.Second,
|
||||
RetryBackoff: float64(2.0),
|
||||
Headers: make(map[string]string),
|
||||
}
|
||||
|
||||
// Extract URL
|
||||
urlStr, ok := options["url"].(string)
|
||||
if !ok || urlStr == "" {
|
||||
return nil, fmt.Errorf("http_client sink requires 'url' option")
|
||||
}
|
||||
|
||||
// Validate URL
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return nil, fmt.Errorf("URL must use http or https scheme")
|
||||
}
|
||||
cfg.URL = urlStr
|
||||
|
||||
// Extract other options
|
||||
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
|
||||
cfg.BufferSize = bufSize
|
||||
}
|
||||
if batchSize, ok := options["batch_size"].(int64); ok && batchSize > 0 {
|
||||
cfg.BatchSize = batchSize
|
||||
}
|
||||
if delayMs, ok := options["batch_delay_ms"].(int64); ok && delayMs > 0 {
|
||||
cfg.BatchDelay = time.Duration(delayMs) * time.Millisecond
|
||||
}
|
||||
if timeoutSec, ok := options["timeout_seconds"].(int64); ok && timeoutSec > 0 {
|
||||
cfg.Timeout = time.Duration(timeoutSec) * time.Second
|
||||
}
|
||||
if maxRetries, ok := options["max_retries"].(int64); ok && maxRetries >= 0 {
|
||||
cfg.MaxRetries = maxRetries
|
||||
}
|
||||
if retryDelayMs, ok := options["retry_delay_ms"].(int64); ok && retryDelayMs > 0 {
|
||||
cfg.RetryDelay = time.Duration(retryDelayMs) * time.Millisecond
|
||||
}
|
||||
if backoff, ok := options["retry_backoff"].(float64); ok && backoff >= 1.0 {
|
||||
cfg.RetryBackoff = backoff
|
||||
}
|
||||
if insecure, ok := options["insecure_skip_verify"].(bool); ok {
|
||||
cfg.InsecureSkipVerify = insecure
|
||||
}
|
||||
if authType, ok := options["auth_type"].(string); ok {
|
||||
switch authType {
|
||||
case "none", "basic", "bearer", "mtls":
|
||||
cfg.AuthType = authType
|
||||
default:
|
||||
return nil, fmt.Errorf("http_client sink: invalid auth_type '%s'", authType)
|
||||
}
|
||||
} else {
|
||||
cfg.AuthType = "none"
|
||||
}
|
||||
if username, ok := options["username"].(string); ok {
|
||||
cfg.Username = username
|
||||
}
|
||||
if password, ok := options["password"].(string); ok {
|
||||
cfg.Password = password // TODO: change to Argon2 hashed password
|
||||
}
|
||||
if token, ok := options["bearer_token"].(string); ok {
|
||||
cfg.BearerToken = token
|
||||
}
|
||||
|
||||
// Validate auth configuration and TLS enforcement
|
||||
isHTTPS := strings.HasPrefix(cfg.URL, "https://")
|
||||
|
||||
switch cfg.AuthType {
|
||||
case "basic":
|
||||
if cfg.Username == "" || cfg.Password == "" {
|
||||
return nil, fmt.Errorf("http_client sink: username and password required for basic auth")
|
||||
}
|
||||
if !isHTTPS {
|
||||
return nil, fmt.Errorf("http_client sink: basic auth requires HTTPS (security: credentials would be sent in plaintext)")
|
||||
}
|
||||
|
||||
case "bearer":
|
||||
if cfg.BearerToken == "" {
|
||||
return nil, fmt.Errorf("http_client sink: bearer_token required for bearer auth")
|
||||
}
|
||||
if !isHTTPS {
|
||||
return nil, fmt.Errorf("http_client sink: bearer auth requires HTTPS (security: token would be sent in plaintext)")
|
||||
}
|
||||
|
||||
case "mtls":
|
||||
if !isHTTPS {
|
||||
return nil, fmt.Errorf("http_client sink: mTLS requires HTTPS")
|
||||
}
|
||||
if cfg.CertFile == "" || cfg.KeyFile == "" {
|
||||
return nil, fmt.Errorf("http_client sink: cert_file and key_file required for mTLS")
|
||||
}
|
||||
|
||||
case "none":
|
||||
// Clear any credentials if auth is "none"
|
||||
if cfg.Username != "" || cfg.Password != "" || cfg.BearerToken != "" {
|
||||
logger.Warn("msg", "Credentials provided but auth_type is 'none', ignoring",
|
||||
"component", "http_client_sink")
|
||||
cfg.Username = ""
|
||||
cfg.Password = ""
|
||||
cfg.BearerToken = ""
|
||||
}
|
||||
}
|
||||
|
||||
// Extract headers
|
||||
if headers, ok := options["headers"].(map[string]any); ok {
|
||||
for k, v := range headers {
|
||||
if strVal, ok := v.(string); ok {
|
||||
cfg.Headers[k] = strVal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set default Content-Type if not specified
|
||||
if _, exists := cfg.Headers["Content-Type"]; !exists {
|
||||
cfg.Headers["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
// Extract TLS options
|
||||
if caFile, ok := options["ca_file"].(string); ok && caFile != "" {
|
||||
cfg.CAFile = caFile
|
||||
}
|
||||
|
||||
// Extract client certificate options from TLS config
|
||||
if tc, ok := options["tls"].(map[string]any); ok {
|
||||
if enabled, _ := tc["enabled"].(bool); enabled {
|
||||
// Extract client certificate files for mTLS
|
||||
if certFile, ok := tc["cert_file"].(string); ok && certFile != "" {
|
||||
if keyFile, ok := tc["key_file"].(string); ok && keyFile != "" {
|
||||
// These will be used below when configuring TLS
|
||||
cfg.CertFile = certFile // Need to add these fields to HTTPClientConfig
|
||||
cfg.KeyFile = keyFile
|
||||
}
|
||||
}
|
||||
// Extract CA file from TLS config if not already set
|
||||
if cfg.CAFile == "" {
|
||||
if caFile, ok := tc["ca_file"].(string); ok {
|
||||
cfg.CAFile = caFile
|
||||
}
|
||||
}
|
||||
// Extract insecure skip verify from TLS config
|
||||
if insecure, ok := tc["insecure_skip_verify"].(bool); ok {
|
||||
cfg.InsecureSkipVerify = insecure
|
||||
}
|
||||
}
|
||||
func NewHTTPClientSink(opts *config.HTTPClientSinkOptions, logger *log.Logger, formatter format.Formatter) (*HTTPClientSink, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("HTTP client sink options cannot be nil")
|
||||
}
|
||||
|
||||
h := &HTTPClientSink{
|
||||
input: make(chan core.LogEntry, cfg.BufferSize),
|
||||
config: cfg,
|
||||
batch: make([]core.LogEntry, 0, cfg.BatchSize),
|
||||
config: opts,
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
batch: make([]core.LogEntry, 0, opts.BatchSize),
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
@ -249,46 +69,48 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for
|
||||
h.client = &fasthttp.Client{
|
||||
MaxConnsPerHost: 10,
|
||||
MaxIdleConnDuration: 10 * time.Second,
|
||||
ReadTimeout: cfg.Timeout,
|
||||
WriteTimeout: cfg.Timeout,
|
||||
ReadTimeout: time.Duration(opts.Timeout) * time.Second,
|
||||
WriteTimeout: time.Duration(opts.Timeout) * time.Second,
|
||||
DisableHeaderNamesNormalizing: true,
|
||||
}
|
||||
|
||||
// Configure TLS if using HTTPS
|
||||
if strings.HasPrefix(cfg.URL, "https://") {
|
||||
if strings.HasPrefix(opts.URL, "https://") {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
InsecureSkipVerify: opts.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
// Load custom CA for server verification if provided
|
||||
if cfg.CAFile != "" {
|
||||
caCert, err := os.ReadFile(cfg.CAFile)
|
||||
// Use TLS config if provided
|
||||
if opts.TLS != nil {
|
||||
// Load custom CA for server verification
|
||||
if opts.TLS.CAFile != "" {
|
||||
caCert, err := os.ReadFile(opts.TLS.CAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA file '%s': %w", cfg.CAFile, err)
|
||||
return nil, fmt.Errorf("failed to read CA file '%s': %w", opts.TLS.CAFile, err)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate from '%s'", cfg.CAFile)
|
||||
return nil, fmt.Errorf("failed to parse CA certificate from '%s'", opts.TLS.CAFile)
|
||||
}
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
logger.Debug("msg", "Custom CA loaded for server verification",
|
||||
"component", "http_client_sink",
|
||||
"ca_file", cfg.CAFile)
|
||||
"ca_file", opts.TLS.CAFile)
|
||||
}
|
||||
|
||||
// Load client certificate for mTLS if provided
|
||||
if cfg.CertFile != "" && cfg.KeyFile != "" {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
|
||||
if opts.TLS.CertFile != "" && opts.TLS.KeyFile != "" {
|
||||
cert, err := tls.LoadX509KeyPair(opts.TLS.CertFile, opts.TLS.KeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load client certificate: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
logger.Info("msg", "Client certificate loaded for mTLS",
|
||||
"component", "http_client_sink",
|
||||
"cert_file", cfg.CertFile)
|
||||
"cert_file", opts.TLS.CertFile)
|
||||
}
|
||||
}
|
||||
|
||||
// Set TLS config directly on the client
|
||||
h.client.TLSConfig = tlsConfig
|
||||
}
|
||||
|
||||
@ -308,7 +130,7 @@ func (h *HTTPClientSink) Start(ctx context.Context) error {
|
||||
"component", "http_client_sink",
|
||||
"url", h.config.URL,
|
||||
"batch_size", h.config.BatchSize,
|
||||
"batch_delay", h.config.BatchDelay)
|
||||
"batch_delay_ms", h.config.BatchDelayMS)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -399,7 +221,7 @@ func (h *HTTPClientSink) processLoop(ctx context.Context) {
|
||||
func (h *HTTPClientSink) batchTimer(ctx context.Context) {
|
||||
defer h.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(h.config.BatchDelay)
|
||||
ticker := time.NewTicker(time.Duration(h.config.BatchDelayMS) * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
@ -468,7 +290,7 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
|
||||
// Retry logic
|
||||
var lastErr error
|
||||
retryDelay := h.config.RetryDelay
|
||||
retryDelay := time.Duration(h.config.RetryDelayMS) * time.Millisecond
|
||||
|
||||
// TODO: verify retry loop placement is correct or should it be after acquiring resources (req :=....)
|
||||
for attempt := int64(0); attempt <= h.config.MaxRetries; attempt++ {
|
||||
@ -480,9 +302,10 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
newDelay := time.Duration(float64(retryDelay) * h.config.RetryBackoff)
|
||||
|
||||
// Cap at maximum to prevent integer overflow
|
||||
if newDelay > h.config.Timeout || newDelay < retryDelay {
|
||||
timeout := time.Duration(h.config.Timeout) * time.Second
|
||||
if newDelay > timeout || newDelay < retryDelay {
|
||||
// Either exceeded max or overflowed (negative/wrapped)
|
||||
retryDelay = h.config.Timeout
|
||||
retryDelay = timeout
|
||||
} else {
|
||||
retryDelay = newDelay
|
||||
}
|
||||
@ -500,14 +323,14 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("LogWisp/%s", version.Short()))
|
||||
|
||||
// Add authentication based on auth type
|
||||
switch h.config.AuthType {
|
||||
switch h.config.Auth.Type {
|
||||
case "basic":
|
||||
creds := h.config.Username + ":" + h.config.Password
|
||||
creds := h.config.Auth.Username + ":" + h.config.Auth.Password
|
||||
encodedCreds := base64.StdEncoding.EncodeToString([]byte(creds))
|
||||
req.Header.Set("Authorization", "Basic "+encodedCreds)
|
||||
|
||||
case "bearer":
|
||||
req.Header.Set("Authorization", "Bearer "+h.config.BearerToken)
|
||||
case "token":
|
||||
req.Header.Set("Authorization", "Token "+h.config.Auth.Token)
|
||||
|
||||
case "mtls":
|
||||
// mTLS auth is handled at TLS layer via client certificates
|
||||
@ -523,7 +346,7 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
}
|
||||
|
||||
// Send request
|
||||
err := h.client.DoTimeout(req, resp, h.config.Timeout)
|
||||
err := h.client.DoTimeout(req, resp, time.Duration(h.config.Timeout)*time.Second)
|
||||
|
||||
// Capture response before releasing
|
||||
statusCode := resp.StatusCode()
|
||||
@ -588,9 +411,3 @@ func (h *HTTPClientSink) sendBatch(batch []core.LogEntry) {
|
||||
"last_error", lastErr)
|
||||
h.failedBatches.Add(1)
|
||||
}
|
||||
|
||||
// Not applicable, Clients authenticate to remote servers using Username/Password in config
|
||||
func (h *HTTPClientSink) SetAuth(authCfg *config.AuthConfig) {
|
||||
// No-op: client sinks don't validate incoming connections
|
||||
// They authenticate to remote servers using Username/Password fields
|
||||
}
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
)
|
||||
|
||||
@ -22,9 +21,6 @@ type Sink interface {
|
||||
|
||||
// Returns sink statistics
|
||||
GetStats() SinkStats
|
||||
|
||||
// Configure authentication
|
||||
SetAuth(auth *config.AuthConfig)
|
||||
}
|
||||
|
||||
// Contains statistics about a sink
|
||||
|
||||
@ -7,7 +7,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -25,9 +24,8 @@ import (
|
||||
|
||||
// Streams log entries via TCP
|
||||
type TCPSink struct {
|
||||
// C
|
||||
input chan core.LogEntry
|
||||
config TCPConfig
|
||||
config *config.TCPSinkOptions
|
||||
server *tcpServer
|
||||
done chan struct{}
|
||||
activeConns atomic.Int64
|
||||
@ -38,13 +36,10 @@ type TCPSink struct {
|
||||
netLimiter *limit.NetLimiter
|
||||
logger *log.Logger
|
||||
formatter format.Formatter
|
||||
authenticator *auth.Authenticator
|
||||
|
||||
// Statistics
|
||||
totalProcessed atomic.Uint64
|
||||
lastProcessed atomic.Value // time.Time
|
||||
authFailures atomic.Uint64
|
||||
authSuccesses atomic.Uint64
|
||||
|
||||
// Write error tracking
|
||||
writeErrors atomic.Uint64
|
||||
@ -62,87 +57,14 @@ type TCPConfig struct {
|
||||
}
|
||||
|
||||
// Creates a new TCP streaming sink
|
||||
func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*TCPSink, error) {
|
||||
cfg := TCPConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: int64(9090),
|
||||
BufferSize: int64(1000),
|
||||
}
|
||||
|
||||
// Extract configuration from options
|
||||
if host, ok := options["host"].(string); ok && host != "" {
|
||||
cfg.Host = host
|
||||
}
|
||||
if port, ok := options["port"].(int64); ok {
|
||||
cfg.Port = port
|
||||
}
|
||||
if bufSize, ok := options["buffer_size"].(int64); ok {
|
||||
cfg.BufferSize = bufSize
|
||||
}
|
||||
|
||||
// Extract heartbeat config
|
||||
if hb, ok := options["heartbeat"].(map[string]any); ok {
|
||||
cfg.Heartbeat = &config.HeartbeatConfig{}
|
||||
cfg.Heartbeat.Enabled, _ = hb["enabled"].(bool)
|
||||
if interval, ok := hb["interval_seconds"].(int64); ok {
|
||||
cfg.Heartbeat.IntervalSeconds = interval
|
||||
}
|
||||
cfg.Heartbeat.IncludeTimestamp, _ = hb["include_timestamp"].(bool)
|
||||
cfg.Heartbeat.IncludeStats, _ = hb["include_stats"].(bool)
|
||||
if hbFormat, ok := hb["format"].(string); ok {
|
||||
cfg.Heartbeat.Format = hbFormat
|
||||
}
|
||||
}
|
||||
|
||||
// Extract net limit config
|
||||
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||
cfg.NetLimit = &config.NetLimitConfig{}
|
||||
cfg.NetLimit.Enabled, _ = nl["enabled"].(bool)
|
||||
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||
cfg.NetLimit.RequestsPerSecond = rps
|
||||
}
|
||||
if burst, ok := nl["burst_size"].(int64); ok {
|
||||
cfg.NetLimit.BurstSize = burst
|
||||
}
|
||||
if respCode, ok := nl["response_code"].(int64); ok {
|
||||
cfg.NetLimit.ResponseCode = respCode
|
||||
}
|
||||
if msg, ok := nl["response_message"].(string); ok {
|
||||
cfg.NetLimit.ResponseMessage = msg
|
||||
}
|
||||
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsPerIP = maxPerIP
|
||||
}
|
||||
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsPerUser = maxPerUser
|
||||
}
|
||||
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsPerToken = maxPerToken
|
||||
}
|
||||
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||
cfg.NetLimit.MaxConnectionsTotal = maxTotal
|
||||
}
|
||||
if ipWhitelist, ok := nl["ip_whitelist"].([]any); ok {
|
||||
cfg.NetLimit.IPWhitelist = make([]string, 0, len(ipWhitelist))
|
||||
for _, entry := range ipWhitelist {
|
||||
if str, ok := entry.(string); ok {
|
||||
cfg.NetLimit.IPWhitelist = append(cfg.NetLimit.IPWhitelist, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ipBlacklist, ok := nl["ip_blacklist"].([]any); ok {
|
||||
cfg.NetLimit.IPBlacklist = make([]string, 0, len(ipBlacklist))
|
||||
for _, entry := range ipBlacklist {
|
||||
if str, ok := entry.(string); ok {
|
||||
cfg.NetLimit.IPBlacklist = append(cfg.NetLimit.IPBlacklist, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
func NewTCPSink(opts *config.TCPSinkOptions, logger *log.Logger, formatter format.Formatter) (*TCPSink, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("TCP sink options cannot be nil")
|
||||
}
|
||||
|
||||
t := &TCPSink{
|
||||
input: make(chan core.LogEntry, cfg.BufferSize),
|
||||
config: cfg,
|
||||
config: opts, // Direct reference to config
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
@ -150,9 +72,11 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For
|
||||
}
|
||||
t.lastProcessed.Store(time.Time{})
|
||||
|
||||
// Initialize net limiter
|
||||
if cfg.NetLimit != nil && cfg.NetLimit.Enabled {
|
||||
t.netLimiter = limit.NewNetLimiter(*cfg.NetLimit, logger)
|
||||
// Initialize net limiter with pointer
|
||||
if opts.NetLimit != nil && (opts.NetLimit.Enabled ||
|
||||
len(opts.NetLimit.IPWhitelist) > 0 ||
|
||||
len(opts.NetLimit.IPBlacklist) > 0) {
|
||||
t.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger)
|
||||
}
|
||||
|
||||
return t, nil
|
||||
@ -193,8 +117,7 @@ func (t *TCPSink) Start(ctx context.Context) error {
|
||||
go func() {
|
||||
t.logger.Info("msg", "Starting TCP server",
|
||||
"component", "tcp_sink",
|
||||
"port", t.config.Port,
|
||||
"auth", t.authenticator != nil)
|
||||
"port", t.config.Port)
|
||||
|
||||
err := gnet.Run(t.server, addr, opts...)
|
||||
if err != nil {
|
||||
@ -282,7 +205,7 @@ func (t *TCPSink) broadcastLoop(ctx context.Context) {
|
||||
var tickerChan <-chan time.Time
|
||||
|
||||
if t.config.Heartbeat != nil && t.config.Heartbeat.Enabled {
|
||||
ticker = time.NewTicker(time.Duration(t.config.Heartbeat.IntervalSeconds) * time.Second)
|
||||
ticker = time.NewTicker(time.Duration(t.config.Heartbeat.Interval) * time.Second)
|
||||
tickerChan = ticker.C
|
||||
defer ticker.Stop()
|
||||
}
|
||||
@ -329,8 +252,7 @@ func (t *TCPSink) broadcastData(data []byte) {
|
||||
t.server.mu.RLock()
|
||||
defer t.server.mu.RUnlock()
|
||||
|
||||
for conn, client := range t.server.clients {
|
||||
if client.authenticated {
|
||||
for conn, _ := range t.server.clients {
|
||||
conn.AsyncWrite(data, func(c gnet.Conn, err error) error {
|
||||
if err != nil {
|
||||
t.writeErrors.Add(1)
|
||||
@ -344,7 +266,6 @@ func (t *TCPSink) broadcastData(data []byte) {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle write errors with threshold-based connection termination
|
||||
@ -410,7 +331,6 @@ func (t *TCPSink) GetActiveConnections() int64 {
|
||||
type tcpClient struct {
|
||||
conn gnet.Conn
|
||||
buffer bytes.Buffer
|
||||
authenticated bool
|
||||
authTimeout time.Time
|
||||
session *auth.Session
|
||||
}
|
||||
@ -439,7 +359,7 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
remoteAddr := c.RemoteAddr()
|
||||
s.sink.logger.Debug("msg", "TCP connection attempt", "remote_addr", remoteAddr)
|
||||
|
||||
// Reject IPv6 connections immediately
|
||||
// Reject IPv6 connections
|
||||
if tcpAddr, ok := remoteAddr.(*net.TCPAddr); ok {
|
||||
if tcpAddr.IP.To4() == nil {
|
||||
return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close
|
||||
@ -467,14 +387,10 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
s.sink.netLimiter.AddConnection(remoteStr)
|
||||
}
|
||||
|
||||
// Create client state without auth timeout initially
|
||||
// TCP Sink accepts all connections without authentication
|
||||
client := &tcpClient{
|
||||
conn: c,
|
||||
authenticated: s.sink.authenticator == nil,
|
||||
}
|
||||
|
||||
if s.sink.authenticator != nil {
|
||||
client.authTimeout = time.Now().Add(30 * time.Second)
|
||||
buffer: bytes.Buffer{},
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
@ -484,13 +400,7 @@ func (s *tcpServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
newCount := s.sink.activeConns.Add(1)
|
||||
s.sink.logger.Debug("msg", "TCP connection opened",
|
||||
"remote_addr", remoteAddr,
|
||||
"active_connections", newCount,
|
||||
"auth_enabled", s.sink.authenticator != nil)
|
||||
|
||||
// Send auth prompt if authentication is required
|
||||
if s.sink.authenticator != nil {
|
||||
return []byte("AUTH_REQUIRED\n"), gnet.None
|
||||
}
|
||||
"active_connections", newCount)
|
||||
|
||||
return nil, gnet.None
|
||||
}
|
||||
@ -522,96 +432,7 @@ func (s *tcpServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
}
|
||||
|
||||
func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
s.mu.RLock()
|
||||
client, exists := s.clients[c]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Authentication phase
|
||||
if !client.authenticated {
|
||||
// Check auth timeout
|
||||
if time.Now().After(client.authTimeout) {
|
||||
s.sink.logger.Warn("msg", "Authentication timeout",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Read auth data
|
||||
data, _ := c.Next(-1)
|
||||
if len(data) == 0 {
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
client.buffer.Write(data)
|
||||
|
||||
// Look for complete auth line
|
||||
if idx := bytes.IndexByte(client.buffer.Bytes(), '\n'); idx >= 0 {
|
||||
line := client.buffer.Bytes()[:idx]
|
||||
client.buffer.Next(idx + 1)
|
||||
|
||||
// Parse AUTH command: AUTH <method> <credentials>
|
||||
parts := strings.SplitN(string(line), " ", 3)
|
||||
if len(parts) != 3 || parts[0] != "AUTH" {
|
||||
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Authenticate
|
||||
session, err := s.sink.authenticator.AuthenticateTCP(parts[1], parts[2], c.RemoteAddr().String())
|
||||
if err != nil {
|
||||
s.sink.authFailures.Add(1)
|
||||
s.sink.logger.Warn("msg", "TCP authentication failed",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"method", parts[1],
|
||||
"error", err)
|
||||
c.AsyncWrite([]byte("AUTH_FAIL\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Authentication successful
|
||||
s.sink.authSuccesses.Add(1)
|
||||
s.mu.Lock()
|
||||
client.authenticated = true
|
||||
client.session = session
|
||||
s.mu.Unlock()
|
||||
|
||||
s.sink.logger.Info("msg", "TCP client authenticated",
|
||||
"component", "tcp_sink",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"username", session.Username,
|
||||
"method", session.Method)
|
||||
|
||||
c.AsyncWrite([]byte("AUTH_OK\n"), nil)
|
||||
client.buffer.Reset()
|
||||
}
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// Clients shouldn't send data, just discard
|
||||
// TCP Sink doesn't expect any data from clients, discard all
|
||||
c.Discard(-1)
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
// Configures tcp sink auth
|
||||
func (t *TCPSink) SetAuth(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type == "none" {
|
||||
return
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(authCfg, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "Failed to initialize authenticator for TCP sink",
|
||||
"component", "tcp_sink",
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
t.authenticator = authenticator
|
||||
|
||||
t.logger.Info("msg", "Authentication configured for TCP sink",
|
||||
"component", "tcp_sink",
|
||||
"auth_type", authCfg.Type)
|
||||
}
|
||||
@ -7,7 +7,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"logwisp/src/internal/auth"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -16,7 +18,6 @@ import (
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/format"
|
||||
"logwisp/src/internal/scram"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
)
|
||||
@ -24,7 +25,8 @@ import (
|
||||
// Forwards log entries to a remote TCP endpoint
|
||||
type TCPClientSink struct {
|
||||
input chan core.LogEntry
|
||||
config TCPClientConfig
|
||||
config *config.TCPClientSinkOptions
|
||||
address string
|
||||
conn net.Conn
|
||||
connMu sync.RWMutex
|
||||
done chan struct{}
|
||||
@ -46,101 +48,17 @@ type TCPClientSink struct {
|
||||
connectionUptime atomic.Value // time.Duration
|
||||
}
|
||||
|
||||
// Holds TCP client sink configuration
|
||||
type TCPClientConfig struct {
|
||||
Address string `toml:"address"`
|
||||
BufferSize int64 `toml:"buffer_size"`
|
||||
DialTimeout time.Duration `toml:"dial_timeout_seconds"`
|
||||
WriteTimeout time.Duration `toml:"write_timeout_seconds"`
|
||||
ReadTimeout time.Duration `toml:"read_timeout_seconds"`
|
||||
KeepAlive time.Duration `toml:"keep_alive_seconds"`
|
||||
|
||||
// Security
|
||||
AuthType string `toml:"auth_type"`
|
||||
Username string `toml:"username"`
|
||||
Password string `toml:"password"`
|
||||
|
||||
// Reconnection settings
|
||||
ReconnectDelay time.Duration `toml:"reconnect_delay_ms"`
|
||||
MaxReconnectDelay time.Duration `toml:"max_reconnect_delay_seconds"`
|
||||
ReconnectBackoff float64 `toml:"reconnect_backoff"`
|
||||
}
|
||||
|
||||
// Creates a new TCP client sink
|
||||
func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter format.Formatter) (*TCPClientSink, error) {
|
||||
cfg := TCPClientConfig{
|
||||
BufferSize: int64(1000),
|
||||
DialTimeout: 10 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
ReconnectDelay: time.Second,
|
||||
MaxReconnectDelay: 30 * time.Second,
|
||||
ReconnectBackoff: float64(1.5),
|
||||
}
|
||||
|
||||
// Extract address
|
||||
address, ok := options["address"].(string)
|
||||
if !ok || address == "" {
|
||||
return nil, fmt.Errorf("tcp_client sink requires 'address' option")
|
||||
}
|
||||
|
||||
// Validate address format
|
||||
_, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid address format (expected host:port): %w", err)
|
||||
}
|
||||
cfg.Address = address
|
||||
|
||||
// Extract other options
|
||||
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
|
||||
cfg.BufferSize = bufSize
|
||||
}
|
||||
if dialTimeout, ok := options["dial_timeout_seconds"].(int64); ok && dialTimeout > 0 {
|
||||
cfg.DialTimeout = time.Duration(dialTimeout) * time.Second
|
||||
}
|
||||
if writeTimeout, ok := options["write_timeout_seconds"].(int64); ok && writeTimeout > 0 {
|
||||
cfg.WriteTimeout = time.Duration(writeTimeout) * time.Second
|
||||
}
|
||||
if readTimeout, ok := options["read_timeout_seconds"].(int64); ok && readTimeout > 0 {
|
||||
cfg.ReadTimeout = time.Duration(readTimeout) * time.Second
|
||||
}
|
||||
if keepAlive, ok := options["keep_alive_seconds"].(int64); ok && keepAlive > 0 {
|
||||
cfg.KeepAlive = time.Duration(keepAlive) * time.Second
|
||||
}
|
||||
if reconnectDelay, ok := options["reconnect_delay_ms"].(int64); ok && reconnectDelay > 0 {
|
||||
cfg.ReconnectDelay = time.Duration(reconnectDelay) * time.Millisecond
|
||||
}
|
||||
if maxReconnectDelay, ok := options["max_reconnect_delay_seconds"].(int64); ok && maxReconnectDelay > 0 {
|
||||
cfg.MaxReconnectDelay = time.Duration(maxReconnectDelay) * time.Second
|
||||
}
|
||||
if backoff, ok := options["reconnect_backoff"].(float64); ok && backoff >= 1.0 {
|
||||
cfg.ReconnectBackoff = backoff
|
||||
}
|
||||
if authType, ok := options["auth_type"].(string); ok {
|
||||
switch authType {
|
||||
case "none":
|
||||
cfg.AuthType = authType
|
||||
case "scram":
|
||||
cfg.AuthType = authType
|
||||
if username, ok := options["username"].(string); ok && username != "" {
|
||||
cfg.Username = username
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid scram username")
|
||||
}
|
||||
if password, ok := options["password"].(string); ok && password != "" {
|
||||
cfg.Password = password
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid scram password")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("tcp_client sink: invalid auth_type '%s' (must be 'none' or 'scram')", authType)
|
||||
}
|
||||
func NewTCPClientSink(opts *config.TCPClientSinkOptions, logger *log.Logger, formatter format.Formatter) (*TCPClientSink, error) {
|
||||
// Validation and defaults are handled in config package
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("TCP client sink options cannot be nil")
|
||||
}
|
||||
|
||||
t := &TCPClientSink{
|
||||
input: make(chan core.LogEntry, cfg.BufferSize),
|
||||
config: cfg,
|
||||
config: opts,
|
||||
address: opts.Host + ":" + strconv.Itoa(int(opts.Port)),
|
||||
input: make(chan core.LogEntry, opts.BufferSize),
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
@ -167,7 +85,8 @@ func (t *TCPClientSink) Start(ctx context.Context) error {
|
||||
|
||||
t.logger.Info("msg", "TCP client sink started",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address)
|
||||
"host", t.config.Host,
|
||||
"port", t.config.Port)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -209,7 +128,7 @@ func (t *TCPClientSink) GetStats() SinkStats {
|
||||
StartTime: t.startTime,
|
||||
LastProcessed: lastProc,
|
||||
Details: map[string]any{
|
||||
"address": t.config.Address,
|
||||
"address": t.address,
|
||||
"connected": connected,
|
||||
"reconnecting": t.reconnecting.Load(),
|
||||
"total_failed": t.totalFailed.Load(),
|
||||
@ -223,7 +142,7 @@ func (t *TCPClientSink) GetStats() SinkStats {
|
||||
func (t *TCPClientSink) connectionManager(ctx context.Context) {
|
||||
defer t.wg.Done()
|
||||
|
||||
reconnectDelay := t.config.ReconnectDelay
|
||||
reconnectDelay := time.Duration(t.config.ReconnectDelayMS) * time.Millisecond
|
||||
|
||||
for {
|
||||
select {
|
||||
@ -243,9 +162,9 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
|
||||
t.lastConnectErr = err
|
||||
t.logger.Warn("msg", "Failed to connect to TCP server",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address,
|
||||
"address", t.address,
|
||||
"error", err,
|
||||
"retry_delay", reconnectDelay)
|
||||
"retry_delay_ms", reconnectDelay)
|
||||
|
||||
// Wait before retry
|
||||
select {
|
||||
@ -258,15 +177,15 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
|
||||
|
||||
// Exponential backoff
|
||||
reconnectDelay = time.Duration(float64(reconnectDelay) * t.config.ReconnectBackoff)
|
||||
if reconnectDelay > t.config.MaxReconnectDelay {
|
||||
reconnectDelay = t.config.MaxReconnectDelay
|
||||
if reconnectDelay > time.Duration(t.config.MaxReconnectDelayMS)*time.Millisecond {
|
||||
reconnectDelay = time.Duration(t.config.MaxReconnectDelayMS)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Connection successful
|
||||
t.lastConnectErr = nil
|
||||
reconnectDelay = t.config.ReconnectDelay // Reset backoff
|
||||
reconnectDelay = time.Duration(t.config.ReconnectDelayMS) * time.Millisecond // Reset backoff
|
||||
t.connectTime = time.Now()
|
||||
t.totalReconnects.Add(1)
|
||||
|
||||
@ -276,7 +195,7 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
|
||||
|
||||
t.logger.Info("msg", "Connected to TCP server",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address,
|
||||
"address", t.address,
|
||||
"local_addr", conn.LocalAddr())
|
||||
|
||||
// Monitor connection
|
||||
@ -293,18 +212,18 @@ func (t *TCPClientSink) connectionManager(ctx context.Context) {
|
||||
|
||||
t.logger.Warn("msg", "Lost connection to TCP server",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address,
|
||||
"address", t.address,
|
||||
"uptime", uptime)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TCPClientSink) connect() (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: t.config.DialTimeout,
|
||||
KeepAlive: t.config.KeepAlive,
|
||||
Timeout: time.Duration(t.config.DialTimeout) * time.Second,
|
||||
KeepAlive: time.Duration(t.config.KeepAlive) * time.Second,
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("tcp", t.config.Address)
|
||||
conn, err := dialer.Dial("tcp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -312,18 +231,18 @@ func (t *TCPClientSink) connect() (net.Conn, error) {
|
||||
// Set TCP keep-alive
|
||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||
tcpConn.SetKeepAlive(true)
|
||||
tcpConn.SetKeepAlivePeriod(t.config.KeepAlive)
|
||||
tcpConn.SetKeepAlivePeriod(time.Duration(t.config.KeepAlive) * time.Second)
|
||||
}
|
||||
|
||||
// SCRAM authentication if credentials configured
|
||||
if t.config.AuthType == "scram" {
|
||||
if t.config.Auth != nil && t.config.Auth.Type == "scram" {
|
||||
if err := t.performSCRAMAuth(conn); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("SCRAM authentication failed: %w", err)
|
||||
}
|
||||
t.logger.Debug("msg", "SCRAM authentication completed",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address)
|
||||
"address", t.address)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
@ -333,7 +252,17 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
|
||||
reader := bufio.NewReader(conn)
|
||||
|
||||
// Create SCRAM client
|
||||
scramClient := scram.NewClient(t.config.Username, t.config.Password)
|
||||
scramClient := auth.NewScramClient(t.config.Auth.Username, t.config.Auth.Password)
|
||||
|
||||
// Wait for AUTH_REQUIRED from server
|
||||
authPrompt, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read auth prompt: %w", err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(authPrompt) != "AUTH_REQUIRED" {
|
||||
return fmt.Errorf("unexpected server greeting: %s", authPrompt)
|
||||
}
|
||||
|
||||
// Step 1: Send ClientFirst
|
||||
clientFirst, err := scramClient.StartAuthentication()
|
||||
@ -341,8 +270,10 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
|
||||
return fmt.Errorf("failed to start SCRAM: %w", err)
|
||||
}
|
||||
|
||||
clientFirstJSON, _ := json.Marshal(clientFirst)
|
||||
msg := fmt.Sprintf("SCRAM-FIRST %s\n", clientFirstJSON)
|
||||
msg, err := auth.FormatSCRAMRequest("SCRAM-FIRST", clientFirst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte(msg)); err != nil {
|
||||
return fmt.Errorf("failed to send SCRAM-FIRST: %w", err)
|
||||
@ -354,13 +285,17 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
|
||||
return fmt.Errorf("failed to read SCRAM challenge: %w", err)
|
||||
}
|
||||
|
||||
parts := strings.Fields(strings.TrimSpace(response))
|
||||
if len(parts) != 2 || parts[0] != "SCRAM-CHALLENGE" {
|
||||
return fmt.Errorf("unexpected server response: %s", response)
|
||||
command, data, err := auth.ParseSCRAMResponse(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var serverFirst scram.ServerFirst
|
||||
if err := json.Unmarshal([]byte(parts[1]), &serverFirst); err != nil {
|
||||
if command != "SCRAM-CHALLENGE" {
|
||||
return fmt.Errorf("unexpected server response: %s", command)
|
||||
}
|
||||
|
||||
var serverFirst auth.ServerFirst
|
||||
if err := json.Unmarshal([]byte(data), &serverFirst); err != nil {
|
||||
return fmt.Errorf("failed to parse server challenge: %w", err)
|
||||
}
|
||||
|
||||
@ -370,8 +305,10 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
|
||||
return fmt.Errorf("failed to process challenge: %w", err)
|
||||
}
|
||||
|
||||
clientFinalJSON, _ := json.Marshal(clientFinal)
|
||||
msg = fmt.Sprintf("SCRAM-PROOF %s\n", clientFinalJSON)
|
||||
msg, err = auth.FormatSCRAMRequest("SCRAM-PROOF", clientFinal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte(msg)); err != nil {
|
||||
return fmt.Errorf("failed to send SCRAM-PROOF: %w", err)
|
||||
@ -383,19 +320,15 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
|
||||
return fmt.Errorf("failed to read SCRAM result: %w", err)
|
||||
}
|
||||
|
||||
parts = strings.Fields(strings.TrimSpace(response))
|
||||
if len(parts) < 1 {
|
||||
return fmt.Errorf("empty server response")
|
||||
command, data, err = auth.ParseSCRAMResponse(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch parts[0] {
|
||||
switch command {
|
||||
case "SCRAM-OK":
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("invalid SCRAM-OK response")
|
||||
}
|
||||
|
||||
var serverFinal scram.ServerFinal
|
||||
if err := json.Unmarshal([]byte(parts[1]), &serverFinal); err != nil {
|
||||
var serverFinal auth.ServerFinal
|
||||
if err := json.Unmarshal([]byte(data), &serverFinal); err != nil {
|
||||
return fmt.Errorf("failed to parse server signature: %w", err)
|
||||
}
|
||||
|
||||
@ -406,21 +339,21 @@ func (t *TCPClientSink) performSCRAMAuth(conn net.Conn) error {
|
||||
|
||||
t.logger.Info("msg", "SCRAM authentication successful",
|
||||
"component", "tcp_client_sink",
|
||||
"address", t.config.Address,
|
||||
"username", t.config.Username,
|
||||
"address", t.address,
|
||||
"username", t.config.Auth.Username,
|
||||
"session_id", serverFinal.SessionID)
|
||||
|
||||
return nil
|
||||
|
||||
case "SCRAM-FAIL":
|
||||
reason := "unknown"
|
||||
if len(parts) > 1 {
|
||||
reason = strings.Join(parts[1:], " ")
|
||||
reason := data
|
||||
if reason == "" {
|
||||
reason = "unknown"
|
||||
}
|
||||
return fmt.Errorf("authentication failed: %s", reason)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unexpected response: %s", response)
|
||||
return fmt.Errorf("unexpected response: %s", command)
|
||||
}
|
||||
}
|
||||
|
||||
@ -436,7 +369,7 @@ func (t *TCPClientSink) monitorConnection(conn net.Conn) {
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Set read deadline
|
||||
if err := conn.SetReadDeadline(time.Now().Add(t.config.ReadTimeout)); err != nil {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(time.Duration(t.config.ReadTimeout) * time.Second)); err != nil {
|
||||
t.logger.Debug("msg", "Failed to set read deadline", "error", err)
|
||||
return
|
||||
}
|
||||
@ -502,7 +435,7 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
|
||||
}
|
||||
|
||||
// Set write deadline
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(t.config.WriteTimeout)); err != nil {
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(time.Duration(t.config.WriteTimeout) * time.Second)); err != nil {
|
||||
return fmt.Errorf("failed to set write deadline: %w", err)
|
||||
}
|
||||
|
||||
@ -519,9 +452,3 @@ func (t *TCPClientSink) sendEntry(entry core.LogEntry) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Not applicable, Clients authenticate to remote servers using Username/Password in config
|
||||
func (h *TCPClientSink) SetAuth(authCfg *config.AuthConfig) {
|
||||
// No-op: client sinks don't validate incoming connections
|
||||
// They authenticate to remote servers using Username/Password fields
|
||||
}
|
||||
@ -21,9 +21,7 @@ import (
|
||||
|
||||
// Monitors a directory for log files
|
||||
type DirectorySource struct {
|
||||
path string
|
||||
pattern string
|
||||
checkInterval time.Duration
|
||||
config *config.DirectorySourceOptions
|
||||
subscribers []chan core.LogEntry
|
||||
watchers map[string]*fileWatcher
|
||||
mu sync.RWMutex
|
||||
@ -38,31 +36,13 @@ type DirectorySource struct {
|
||||
}
|
||||
|
||||
// Creates a new directory monitoring source
|
||||
func NewDirectorySource(options map[string]any, logger *log.Logger) (*DirectorySource, error) {
|
||||
path, ok := options["path"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("directory source requires 'path' option")
|
||||
}
|
||||
|
||||
pattern, _ := options["pattern"].(string)
|
||||
if pattern == "" {
|
||||
pattern = "*"
|
||||
}
|
||||
|
||||
checkInterval := 100 * time.Millisecond
|
||||
if ms, ok := options["check_interval_ms"].(int64); ok && ms > 0 {
|
||||
checkInterval = time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid path %s: %w", path, err)
|
||||
func NewDirectorySource(opts *config.DirectorySourceOptions, logger *log.Logger) (*DirectorySource, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("directory source options cannot be nil")
|
||||
}
|
||||
|
||||
ds := &DirectorySource{
|
||||
path: absPath,
|
||||
pattern: pattern,
|
||||
checkInterval: checkInterval,
|
||||
config: opts,
|
||||
watchers: make(map[string]*fileWatcher),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
@ -88,9 +68,9 @@ func (ds *DirectorySource) Start() error {
|
||||
|
||||
ds.logger.Info("msg", "Directory source started",
|
||||
"component", "directory_source",
|
||||
"path", ds.path,
|
||||
"pattern", ds.pattern,
|
||||
"check_interval_ms", ds.checkInterval.Milliseconds())
|
||||
"path", ds.config.Path,
|
||||
"pattern", ds.config.Pattern,
|
||||
"check_interval_ms", ds.config.CheckIntervalMS)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -111,7 +91,7 @@ func (ds *DirectorySource) Stop() {
|
||||
|
||||
ds.logger.Info("msg", "Directory source stopped",
|
||||
"component", "directory_source",
|
||||
"path", ds.path)
|
||||
"path", ds.config.Path)
|
||||
}
|
||||
|
||||
func (ds *DirectorySource) GetStats() SourceStats {
|
||||
@ -171,7 +151,7 @@ func (ds *DirectorySource) monitorLoop() {
|
||||
|
||||
ds.checkTargets()
|
||||
|
||||
ticker := time.NewTicker(ds.checkInterval)
|
||||
ticker := time.NewTicker(time.Duration(ds.config.CheckIntervalMS) * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
@ -189,8 +169,8 @@ func (ds *DirectorySource) checkTargets() {
|
||||
if err != nil {
|
||||
ds.logger.Warn("msg", "Failed to scan directory",
|
||||
"component", "directory_source",
|
||||
"path", ds.path,
|
||||
"pattern", ds.pattern,
|
||||
"path", ds.config.Path,
|
||||
"pattern", ds.config.Pattern,
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
@ -203,13 +183,13 @@ func (ds *DirectorySource) checkTargets() {
|
||||
}
|
||||
|
||||
func (ds *DirectorySource) scanDirectory() ([]string, error) {
|
||||
entries, err := os.ReadDir(ds.path)
|
||||
entries, err := os.ReadDir(ds.config.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert glob pattern to regex
|
||||
regexPattern := globToRegex(ds.pattern)
|
||||
regexPattern := globToRegex(ds.config.Pattern)
|
||||
re, err := regexp.Compile(regexPattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid pattern regex: %w", err)
|
||||
@ -223,7 +203,7 @@ func (ds *DirectorySource) scanDirectory() ([]string, error) {
|
||||
|
||||
name := entry.Name()
|
||||
if re.MatchString(name) {
|
||||
files = append(files, filepath.Join(ds.path, name))
|
||||
files = append(files, filepath.Join(ds.config.Path, name))
|
||||
}
|
||||
}
|
||||
|
||||
@ -288,7 +268,3 @@ func globToRegex(glob string) string {
|
||||
regex = strings.ReplaceAll(regex, `\?`, `.`)
|
||||
return "^" + regex + "$"
|
||||
}
|
||||
|
||||
func (ds *DirectorySource) SetAuth(auth *config.AuthConfig) {
|
||||
// Authentication does not apply to directory source
|
||||
}
|
||||
@ -14,7 +14,6 @@ import (
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/tls"
|
||||
"logwisp/src/internal/version"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
@ -22,12 +21,7 @@ import (
|
||||
|
||||
// Receives log entries via HTTP POST requests
|
||||
type HTTPSource struct {
|
||||
// Config
|
||||
host string
|
||||
port int64
|
||||
path string
|
||||
bufferSize int64
|
||||
maxRequestBodySize int64
|
||||
config *config.HTTPSourceOptions
|
||||
|
||||
// Application
|
||||
server *fasthttp.Server
|
||||
@ -42,11 +36,9 @@ type HTTPSource struct {
|
||||
|
||||
// Security
|
||||
authenticator *auth.Authenticator
|
||||
authConfig *config.AuthConfig
|
||||
authFailures atomic.Uint64
|
||||
authSuccesses atomic.Uint64
|
||||
tlsManager *tls.Manager
|
||||
tlsConfig *config.TLSConfig
|
||||
|
||||
// Statistics
|
||||
totalEntries atomic.Uint64
|
||||
@ -57,38 +49,14 @@ type HTTPSource struct {
|
||||
}
|
||||
|
||||
// Creates a new HTTP server source
|
||||
func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, error) {
|
||||
host := "0.0.0.0"
|
||||
if h, ok := options["host"].(string); ok && h != "" {
|
||||
host = h
|
||||
}
|
||||
|
||||
port, ok := options["port"].(int64)
|
||||
if !ok || port < 1 || port > 65535 {
|
||||
return nil, fmt.Errorf("http source requires valid 'port' option")
|
||||
}
|
||||
|
||||
ingestPath := "/ingest"
|
||||
if path, ok := options["path"].(string); ok && path != "" {
|
||||
ingestPath = path
|
||||
}
|
||||
|
||||
bufferSize := int64(1000)
|
||||
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
|
||||
bufferSize = bufSize
|
||||
}
|
||||
|
||||
maxRequestBodySize := int64(10 * 1024 * 1024) // fasthttp default 10MB
|
||||
if maxBodySize, ok := options["max_body_size"].(int64); ok && maxBodySize > 0 && maxBodySize < maxRequestBodySize {
|
||||
maxRequestBodySize = maxBodySize
|
||||
func NewHTTPSource(opts *config.HTTPSourceOptions, logger *log.Logger) (*HTTPSource, error) {
|
||||
// Validation done in config package
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("HTTP source options cannot be nil")
|
||||
}
|
||||
|
||||
h := &HTTPSource{
|
||||
host: host,
|
||||
port: port,
|
||||
path: ingestPath,
|
||||
bufferSize: bufferSize,
|
||||
maxRequestBodySize: maxRequestBodySize,
|
||||
config: opts,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
@ -96,69 +64,37 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err
|
||||
h.lastEntryTime.Store(time.Time{})
|
||||
|
||||
// Initialize net limiter if configured
|
||||
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||
if enabled, _ := nl["enabled"].(bool); enabled {
|
||||
cfg := config.NetLimitConfig{
|
||||
Enabled: true,
|
||||
if opts.NetLimit != nil && (opts.NetLimit.Enabled ||
|
||||
len(opts.NetLimit.IPWhitelist) > 0 ||
|
||||
len(opts.NetLimit.IPBlacklist) > 0) {
|
||||
h.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger)
|
||||
}
|
||||
|
||||
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||
cfg.RequestsPerSecond = rps
|
||||
}
|
||||
if burst, ok := nl["burst_size"].(int64); ok {
|
||||
cfg.BurstSize = burst
|
||||
}
|
||||
if respCode, ok := nl["response_code"].(int64); ok {
|
||||
cfg.ResponseCode = respCode
|
||||
}
|
||||
if msg, ok := nl["response_message"].(string); ok {
|
||||
cfg.ResponseMessage = msg
|
||||
}
|
||||
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||
cfg.MaxConnectionsPerIP = maxPerIP
|
||||
}
|
||||
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||
cfg.MaxConnectionsTotal = maxTotal
|
||||
}
|
||||
|
||||
h.netLimiter = limit.NewNetLimiter(cfg, logger)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract TLS config after existing options
|
||||
if tc, ok := options["tls"].(map[string]any); ok {
|
||||
h.tlsConfig = &config.TLSConfig{}
|
||||
h.tlsConfig.Enabled, _ = tc["enabled"].(bool)
|
||||
if certFile, ok := tc["cert_file"].(string); ok {
|
||||
h.tlsConfig.CertFile = certFile
|
||||
}
|
||||
if keyFile, ok := tc["key_file"].(string); ok {
|
||||
h.tlsConfig.KeyFile = keyFile
|
||||
}
|
||||
h.tlsConfig.ClientAuth, _ = tc["client_auth"].(bool)
|
||||
if caFile, ok := tc["client_ca_file"].(string); ok {
|
||||
h.tlsConfig.ClientCAFile = caFile
|
||||
}
|
||||
h.tlsConfig.VerifyClientCert, _ = tc["verify_client_cert"].(bool)
|
||||
h.tlsConfig.InsecureSkipVerify, _ = tc["insecure_skip_verify"].(bool)
|
||||
if minVer, ok := tc["min_version"].(string); ok {
|
||||
h.tlsConfig.MinVersion = minVer
|
||||
}
|
||||
if maxVer, ok := tc["max_version"].(string); ok {
|
||||
h.tlsConfig.MaxVersion = maxVer
|
||||
}
|
||||
if ciphers, ok := tc["cipher_suites"].(string); ok {
|
||||
h.tlsConfig.CipherSuites = ciphers
|
||||
}
|
||||
|
||||
// Create TLS manager
|
||||
if h.tlsConfig.Enabled {
|
||||
tlsManager, err := tls.NewManager(h.tlsConfig, logger)
|
||||
// Initialize TLS manager if configured
|
||||
if opts.TLS != nil && opts.TLS.Enabled {
|
||||
tlsManager, err := tls.NewManager(opts.TLS, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TLS manager: %w", err)
|
||||
}
|
||||
h.tlsManager = tlsManager
|
||||
}
|
||||
|
||||
// Initialize authenticator if configured
|
||||
if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" {
|
||||
// Verify TLS is enabled for auth (validation should have caught this)
|
||||
if h.tlsManager == nil {
|
||||
return nil, fmt.Errorf("authentication requires TLS to be enabled")
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(opts.Auth, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create authenticator: %w", err)
|
||||
}
|
||||
h.authenticator = authenticator
|
||||
|
||||
logger.Info("msg", "Authentication configured for HTTP source",
|
||||
"component", "http_source",
|
||||
"auth_type", opts.Auth.Type)
|
||||
}
|
||||
|
||||
return h, nil
|
||||
@ -168,23 +104,24 @@ func (h *HTTPSource) Subscribe() <-chan core.LogEntry {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
ch := make(chan core.LogEntry, h.bufferSize)
|
||||
ch := make(chan core.LogEntry, h.config.BufferSize)
|
||||
h.subscribers = append(h.subscribers, ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
func (h *HTTPSource) Start() error {
|
||||
h.server = &fasthttp.Server{
|
||||
Name: fmt.Sprintf("LogWisp/%s", version.Short()),
|
||||
Handler: h.requestHandler,
|
||||
DisableKeepalive: false,
|
||||
StreamRequestBody: true,
|
||||
CloseOnShutdown: true,
|
||||
MaxRequestBodySize: int(h.maxRequestBodySize),
|
||||
ReadTimeout: time.Duration(h.config.ReadTimeout) * time.Millisecond,
|
||||
WriteTimeout: time.Duration(h.config.WriteTimeout) * time.Millisecond,
|
||||
MaxRequestBodySize: int(h.config.MaxRequestBodySize),
|
||||
}
|
||||
|
||||
// Use configured host and port
|
||||
addr := fmt.Sprintf("%s:%d", h.host, h.port)
|
||||
addr := fmt.Sprintf("%s:%d", h.config.Host, h.config.Port)
|
||||
|
||||
// Start server in background
|
||||
h.wg.Add(1)
|
||||
@ -193,35 +130,35 @@ func (h *HTTPSource) Start() error {
|
||||
defer h.wg.Done()
|
||||
h.logger.Info("msg", "HTTP source server starting",
|
||||
"component", "http_source",
|
||||
"port", h.port,
|
||||
"path", h.path,
|
||||
"tls_enabled", h.tlsManager != nil)
|
||||
"port", h.config.Port,
|
||||
"ingest_path", h.config.IngestPath,
|
||||
"tls_enabled", h.tlsManager != nil,
|
||||
"auth_enabled", h.authenticator != nil)
|
||||
|
||||
var err error
|
||||
// Check for TLS manager and start the appropriate server type
|
||||
if h.tlsManager != nil {
|
||||
// HTTPS server
|
||||
h.server.TLSConfig = h.tlsManager.GetHTTPConfig()
|
||||
err = h.server.ListenAndServeTLS(addr, h.tlsConfig.CertFile, h.tlsConfig.KeyFile)
|
||||
err = h.server.ListenAndServeTLS(addr, h.config.TLS.CertFile, h.config.TLS.KeyFile)
|
||||
} else {
|
||||
// HTTP server
|
||||
err = h.server.ListenAndServe(addr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("msg", "HTTP source server failed",
|
||||
"component", "http_source",
|
||||
"port", h.port,
|
||||
"port", h.config.Port,
|
||||
"error", err)
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// Robust server startup check with timeout
|
||||
// Wait briefly for server startup
|
||||
select {
|
||||
case err := <-errChan:
|
||||
// Server failed to start
|
||||
return fmt.Errorf("HTTP server failed to start: %w", err)
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
// Server started successfully (no immediate error)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@ -263,6 +200,21 @@ func (h *HTTPSource) GetStats() SourceStats {
|
||||
netLimitStats = h.netLimiter.GetStats()
|
||||
}
|
||||
|
||||
var authStats map[string]any
|
||||
if h.authenticator != nil {
|
||||
authStats = map[string]any{
|
||||
"enabled": true,
|
||||
"type": h.config.Auth.Type,
|
||||
"failures": h.authFailures.Load(),
|
||||
"successes": h.authSuccesses.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
var tlsStats map[string]any
|
||||
if h.tlsManager != nil {
|
||||
tlsStats = h.tlsManager.GetStats()
|
||||
}
|
||||
|
||||
return SourceStats{
|
||||
Type: "http",
|
||||
TotalEntries: h.totalEntries.Load(),
|
||||
@ -270,10 +222,13 @@ func (h *HTTPSource) GetStats() SourceStats {
|
||||
StartTime: h.startTime,
|
||||
LastEntryTime: lastEntry,
|
||||
Details: map[string]any{
|
||||
"port": h.port,
|
||||
"path": h.path,
|
||||
"host": h.config.Host,
|
||||
"port": h.config.Port,
|
||||
"path": h.config.IngestPath,
|
||||
"invalid_entries": h.invalidEntries.Load(),
|
||||
"net_limit": netLimitStats,
|
||||
"auth": authStats,
|
||||
"tls": tlsStats,
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -307,17 +262,10 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
// 2.5. Check TLS requirement for auth (early reject)
|
||||
if h.authenticator != nil && h.authConfig.Type != "none" {
|
||||
// Check if connection is TLS
|
||||
// 3. Check TLS requirement for auth
|
||||
if h.authenticator != nil {
|
||||
isTLS := ctx.IsTLS() || h.tlsManager != nil
|
||||
|
||||
if !isTLS {
|
||||
h.logger.Error("msg", "Authentication configured but connection is not TLS",
|
||||
"component", "http_source",
|
||||
"remote_addr", remoteAddr,
|
||||
"auth_type", h.authConfig.Type)
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
@ -326,21 +274,45 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Path check (only process ingest path)
|
||||
path := string(ctx.Path())
|
||||
if path != h.path {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
// Authenticate request
|
||||
authHeader := string(ctx.Request.Header.Peek("Authorization"))
|
||||
session, err := h.authenticator.AuthenticateHTTP(authHeader, remoteAddr)
|
||||
if err != nil {
|
||||
h.authFailures.Add(1)
|
||||
h.logger.Warn("msg", "Authentication failed",
|
||||
"component", "http_source",
|
||||
"remote_addr", remoteAddr,
|
||||
"error", err)
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
|
||||
if h.config.Auth.Type == "basic" && h.config.Auth.Basic != nil && h.config.Auth.Basic.Realm != "" {
|
||||
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, h.config.Auth.Basic.Realm))
|
||||
}
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Not Found",
|
||||
"hint": fmt.Sprintf("POST logs to %s", h.path),
|
||||
"error": "Authentication failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Method check (only accept POST)
|
||||
h.authSuccesses.Add(1)
|
||||
_ = session // Session can be used for audit logging
|
||||
}
|
||||
|
||||
// 4. Path check
|
||||
path := string(ctx.Path())
|
||||
if path != h.config.IngestPath {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Not Found",
|
||||
"hint": fmt.Sprintf("POST logs to %s", h.config.IngestPath),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 5. Method check (only accepts POST)
|
||||
if string(ctx.Method()) != "POST" {
|
||||
ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
|
||||
ctx.SetContentType("application/json")
|
||||
@ -352,43 +324,10 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
return
|
||||
}
|
||||
|
||||
// 5. Authentication check (if configured)
|
||||
if h.authenticator != nil {
|
||||
authHeader := string(ctx.Request.Header.Peek("Authorization"))
|
||||
session, err := h.authenticator.AuthenticateHTTP(authHeader, remoteAddr)
|
||||
if err != nil {
|
||||
h.authFailures.Add(1)
|
||||
h.logger.Warn("msg", "Authentication failed",
|
||||
"component", "http_source",
|
||||
"remote_addr", remoteAddr,
|
||||
"error", err)
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
|
||||
if h.authConfig.Type == "basic" && h.authConfig.BasicAuth != nil {
|
||||
realm := h.authConfig.BasicAuth.Realm
|
||||
if realm == "" {
|
||||
realm = "Restricted"
|
||||
}
|
||||
ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm))
|
||||
} else if h.authConfig.Type == "bearer" {
|
||||
ctx.Response.Header.Set("WWW-Authenticate", "Bearer")
|
||||
}
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
h.authSuccesses.Add(1)
|
||||
h.logger.Debug("msg", "Request authenticated",
|
||||
"component", "http_source",
|
||||
"remote_addr", remoteAddr,
|
||||
"username", session.Username)
|
||||
}
|
||||
|
||||
// 6. Process request body
|
||||
// 6. Process log entry
|
||||
body := ctx.PostBody()
|
||||
if len(body) == 0 {
|
||||
h.invalidEntries.Add(1)
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
@ -397,32 +336,34 @@ func (h *HTTPSource) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
return
|
||||
}
|
||||
|
||||
// 7. Parse log entries
|
||||
entries, err := h.parseEntries(body)
|
||||
if err != nil {
|
||||
var entry core.LogEntry
|
||||
if err := json.Unmarshal(body, &entry); err != nil {
|
||||
h.invalidEntries.Add(1)
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"error": fmt.Sprintf("Invalid log format: %v", err),
|
||||
"error": fmt.Sprintf("Invalid JSON: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 8. Publish entries to subscribers
|
||||
accepted := 0
|
||||
for _, entry := range entries {
|
||||
if h.publish(entry) {
|
||||
accepted++
|
||||
// Set defaults
|
||||
if entry.Time.IsZero() {
|
||||
entry.Time = time.Now()
|
||||
}
|
||||
if entry.Source == "" {
|
||||
entry.Source = "http"
|
||||
}
|
||||
entry.RawSize = int64(len(body))
|
||||
|
||||
// 9. Return success response
|
||||
// Publish to subscribers
|
||||
h.publish(entry)
|
||||
|
||||
// Success response
|
||||
ctx.SetStatusCode(fasthttp.StatusAccepted)
|
||||
ctx.SetContentType("application/json")
|
||||
json.NewEncoder(ctx).Encode(map[string]any{
|
||||
"accepted": accepted,
|
||||
"total": len(entries),
|
||||
json.NewEncoder(ctx).Encode(map[string]string{
|
||||
"status": "accepted",
|
||||
})
|
||||
}
|
||||
|
||||
@ -501,29 +442,22 @@ func (h *HTTPSource) parseEntries(body []byte) ([]core.LogEntry, error) {
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (h *HTTPSource) publish(entry core.LogEntry) bool {
|
||||
func (h *HTTPSource) publish(entry core.LogEntry) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
h.totalEntries.Add(1)
|
||||
h.lastEntryTime.Store(entry.Time)
|
||||
|
||||
dropped := false
|
||||
for _, ch := range h.subscribers {
|
||||
select {
|
||||
case ch <- entry:
|
||||
default:
|
||||
dropped = true
|
||||
h.droppedEntries.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
if dropped {
|
||||
h.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
|
||||
"component", "http_source")
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Splits bytes into lines, handling both \n and \r\n
|
||||
@ -550,24 +484,3 @@ func splitLines(data []byte) [][]byte {
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
// Configure HTTP source auth
|
||||
func (h *HTTPSource) SetAuth(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type == "none" {
|
||||
return
|
||||
}
|
||||
|
||||
h.authConfig = authCfg
|
||||
authenticator, err := auth.NewAuthenticator(authCfg, h.logger)
|
||||
if err != nil {
|
||||
h.logger.Error("msg", "Failed to initialize authenticator for HTTP source",
|
||||
"component", "http_source",
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
h.authenticator = authenticator
|
||||
|
||||
h.logger.Info("msg", "Authentication configured for HTTP source",
|
||||
"component", "http_source",
|
||||
"auth_type", authCfg.Type)
|
||||
}
|
||||
@ -4,7 +4,6 @@ package source
|
||||
import (
|
||||
"time"
|
||||
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
)
|
||||
|
||||
@ -21,9 +20,6 @@ type Source interface {
|
||||
|
||||
// Returns source statistics
|
||||
GetStats() SourceStats
|
||||
|
||||
// Configure authentication
|
||||
SetAuth(auth *config.AuthConfig)
|
||||
}
|
||||
|
||||
// Contains statistics about a source
|
||||
|
||||
@ -15,24 +15,25 @@ import (
|
||||
|
||||
// Reads log entries from standard input
|
||||
type StdinSource struct {
|
||||
config *config.StdinSourceOptions
|
||||
subscribers []chan core.LogEntry
|
||||
done chan struct{}
|
||||
totalEntries atomic.Uint64
|
||||
droppedEntries atomic.Uint64
|
||||
bufferSize int64
|
||||
startTime time.Time
|
||||
lastEntryTime atomic.Value // time.Time
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func NewStdinSource(options map[string]any, logger *log.Logger) (*StdinSource, error) {
|
||||
bufferSize := int64(1000) // default
|
||||
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
|
||||
bufferSize = bufSize
|
||||
func NewStdinSource(opts *config.StdinSourceOptions, logger *log.Logger) (*StdinSource, error) {
|
||||
if opts == nil {
|
||||
opts = &config.StdinSourceOptions{
|
||||
BufferSize: 1000, // Default
|
||||
}
|
||||
}
|
||||
|
||||
source := &StdinSource{
|
||||
bufferSize: bufferSize,
|
||||
config: opts,
|
||||
subscribers: make([]chan core.LogEntry, 0),
|
||||
done: make(chan struct{}),
|
||||
logger: logger,
|
||||
@ -43,7 +44,7 @@ func NewStdinSource(options map[string]any, logger *log.Logger) (*StdinSource, e
|
||||
}
|
||||
|
||||
func (s *StdinSource) Subscribe() <-chan core.LogEntry {
|
||||
ch := make(chan core.LogEntry, s.bufferSize)
|
||||
ch := make(chan core.LogEntry, s.config.BufferSize)
|
||||
s.subscribers = append(s.subscribers, ch)
|
||||
return ch
|
||||
}
|
||||
@ -120,7 +121,3 @@ func (s *StdinSource) publish(entry core.LogEntry) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdinSource) SetAuth(auth *config.AuthConfig) {
|
||||
// Authentication does not apply to stdin source
|
||||
}
|
||||
@ -4,7 +4,6 @@ package source
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -17,7 +16,6 @@ import (
|
||||
"logwisp/src/internal/config"
|
||||
"logwisp/src/internal/core"
|
||||
"logwisp/src/internal/limit"
|
||||
"logwisp/src/internal/scram"
|
||||
|
||||
"github.com/lixenwraith/log"
|
||||
"github.com/lixenwraith/log/compat"
|
||||
@ -31,9 +29,7 @@ const (
|
||||
|
||||
// Receives log entries via TCP connections
|
||||
type TCPSource struct {
|
||||
host string
|
||||
port int64
|
||||
bufferSize int64
|
||||
config *config.TCPSourceOptions
|
||||
server *tcpSourceServer
|
||||
subscribers []chan core.LogEntry
|
||||
mu sync.RWMutex
|
||||
@ -41,9 +37,11 @@ type TCPSource struct {
|
||||
engine *gnet.Engine
|
||||
engineMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
authenticator *auth.Authenticator
|
||||
netLimiter *limit.NetLimiter
|
||||
logger *log.Logger
|
||||
scramManager *scram.ScramManager
|
||||
scramManager *auth.ScramManager
|
||||
scramProtocolHandler *auth.ScramProtocolHandler
|
||||
|
||||
// Statistics
|
||||
totalEntries atomic.Uint64
|
||||
@ -57,26 +55,14 @@ type TCPSource struct {
|
||||
}
|
||||
|
||||
// Creates a new TCP server source
|
||||
func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error) {
|
||||
host := "0.0.0.0"
|
||||
if h, ok := options["host"].(string); ok && h != "" {
|
||||
host = h
|
||||
}
|
||||
|
||||
port, ok := options["port"].(int64)
|
||||
if !ok || port < 1 || port > 65535 {
|
||||
return nil, fmt.Errorf("tcp source requires valid 'port' option")
|
||||
}
|
||||
|
||||
bufferSize := int64(1000)
|
||||
if bufSize, ok := options["buffer_size"].(int64); ok && bufSize > 0 {
|
||||
bufferSize = bufSize
|
||||
func NewTCPSource(opts *config.TCPSourceOptions, logger *log.Logger) (*TCPSource, error) {
|
||||
// Accept typed config - validation done in config package
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("TCP source options cannot be nil")
|
||||
}
|
||||
|
||||
t := &TCPSource{
|
||||
host: host,
|
||||
port: port,
|
||||
bufferSize: bufferSize,
|
||||
config: opts,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
logger: logger,
|
||||
@ -84,33 +70,21 @@ func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error
|
||||
t.lastEntryTime.Store(time.Time{})
|
||||
|
||||
// Initialize net limiter if configured
|
||||
if nl, ok := options["net_limit"].(map[string]any); ok {
|
||||
if enabled, _ := nl["enabled"].(bool); enabled {
|
||||
cfg := config.NetLimitConfig{
|
||||
Enabled: true,
|
||||
if opts.NetLimit != nil && (opts.NetLimit.Enabled ||
|
||||
len(opts.NetLimit.IPWhitelist) > 0 ||
|
||||
len(opts.NetLimit.IPBlacklist) > 0) {
|
||||
t.netLimiter = limit.NewNetLimiter(opts.NetLimit, logger)
|
||||
}
|
||||
|
||||
if rps, ok := nl["requests_per_second"].(float64); ok {
|
||||
cfg.RequestsPerSecond = rps
|
||||
}
|
||||
if burst, ok := nl["burst_size"].(int64); ok {
|
||||
cfg.BurstSize = burst
|
||||
}
|
||||
if maxPerIP, ok := nl["max_connections_per_ip"].(int64); ok {
|
||||
cfg.MaxConnectionsPerIP = maxPerIP
|
||||
}
|
||||
if maxPerUser, ok := nl["max_connections_per_user"].(int64); ok {
|
||||
cfg.MaxConnectionsPerUser = maxPerUser
|
||||
}
|
||||
if maxPerToken, ok := nl["max_connections_per_token"].(int64); ok {
|
||||
cfg.MaxConnectionsPerToken = maxPerToken
|
||||
}
|
||||
if maxTotal, ok := nl["max_connections_total"].(int64); ok {
|
||||
cfg.MaxConnectionsTotal = maxTotal
|
||||
}
|
||||
|
||||
t.netLimiter = limit.NewNetLimiter(cfg, logger)
|
||||
}
|
||||
// Initialize SCRAM
|
||||
if opts.Auth != nil && opts.Auth.Type == "scram" && opts.Auth.Scram != nil {
|
||||
t.scramManager = auth.NewScramManager(opts.Auth.Scram)
|
||||
t.scramProtocolHandler = auth.NewScramProtocolHandler(t.scramManager, logger)
|
||||
logger.Info("msg", "SCRAM authentication configured for TCP source",
|
||||
"component", "tcp_source",
|
||||
"users", len(opts.Auth.Scram.Users))
|
||||
} else if opts.Auth != nil && opts.Auth.Type != "none" && opts.Auth.Type != "" {
|
||||
return nil, fmt.Errorf("TCP source only supports 'none' or 'scram' auth")
|
||||
}
|
||||
|
||||
return t, nil
|
||||
@ -120,7 +94,7 @@ func (t *TCPSource) Subscribe() <-chan core.LogEntry {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
ch := make(chan core.LogEntry, t.bufferSize)
|
||||
ch := make(chan core.LogEntry, t.config.BufferSize)
|
||||
t.subscribers = append(t.subscribers, ch)
|
||||
return ch
|
||||
}
|
||||
@ -132,7 +106,7 @@ func (t *TCPSource) Start() error {
|
||||
}
|
||||
|
||||
// Use configured host and port
|
||||
addr := fmt.Sprintf("tcp://%s:%d", t.host, t.port)
|
||||
addr := fmt.Sprintf("tcp://%s:%d", t.config.Host, t.config.Port)
|
||||
|
||||
// Create a gnet adapter using the existing logger instance
|
||||
gnetLogger := compat.NewGnetAdapter(t.logger)
|
||||
@ -144,17 +118,19 @@ func (t *TCPSource) Start() error {
|
||||
defer t.wg.Done()
|
||||
t.logger.Info("msg", "TCP source server starting",
|
||||
"component", "tcp_source",
|
||||
"port", t.port)
|
||||
"port", t.config.Port,
|
||||
"auth_enabled", t.authenticator != nil)
|
||||
|
||||
err := gnet.Run(t.server, addr,
|
||||
gnet.WithLogger(gnetLogger),
|
||||
gnet.WithMulticore(true),
|
||||
gnet.WithReusePort(true),
|
||||
gnet.WithTCPKeepAlive(time.Duration(t.config.KeepAlivePeriod)*time.Millisecond),
|
||||
)
|
||||
if err != nil {
|
||||
t.logger.Error("msg", "TCP source server failed",
|
||||
"component", "tcp_source",
|
||||
"port", t.port,
|
||||
"port", t.config.Port,
|
||||
"error", err)
|
||||
}
|
||||
errChan <- err
|
||||
@ -169,7 +145,7 @@ func (t *TCPSource) Start() error {
|
||||
return err
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Server started successfully
|
||||
t.logger.Info("msg", "TCP server started", "port", t.port)
|
||||
t.logger.Info("msg", "TCP server started", "port", t.config.Port)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@ -214,6 +190,16 @@ func (t *TCPSource) GetStats() SourceStats {
|
||||
netLimitStats = t.netLimiter.GetStats()
|
||||
}
|
||||
|
||||
var authStats map[string]any
|
||||
if t.authenticator != nil {
|
||||
authStats = map[string]any{
|
||||
"enabled": true,
|
||||
"type": t.config.Auth.Type,
|
||||
"failures": t.authFailures.Load(),
|
||||
"successes": t.authSuccesses.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
return SourceStats{
|
||||
Type: "tcp",
|
||||
TotalEntries: t.totalEntries.Load(),
|
||||
@ -221,37 +207,31 @@ func (t *TCPSource) GetStats() SourceStats {
|
||||
StartTime: t.startTime,
|
||||
LastEntryTime: lastEntry,
|
||||
Details: map[string]any{
|
||||
"port": t.port,
|
||||
"port": t.config.Port,
|
||||
"active_connections": t.activeConns.Load(),
|
||||
"invalid_entries": t.invalidEntries.Load(),
|
||||
"net_limit": netLimitStats,
|
||||
"auth": authStats,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TCPSource) publish(entry core.LogEntry) bool {
|
||||
func (t *TCPSource) publish(entry core.LogEntry) {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
t.totalEntries.Add(1)
|
||||
t.lastEntryTime.Store(entry.Time)
|
||||
|
||||
dropped := false
|
||||
for _, ch := range t.subscribers {
|
||||
select {
|
||||
case ch <- entry:
|
||||
default:
|
||||
dropped = true
|
||||
t.droppedEntries.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
if dropped {
|
||||
t.logger.Debug("msg", "Dropped log entry - subscriber buffer full",
|
||||
"component", "tcp_source")
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Represents a connected TCP client
|
||||
@ -262,8 +242,6 @@ type tcpClient struct {
|
||||
authTimeout time.Time
|
||||
session *auth.Session
|
||||
maxBufferSeen int
|
||||
cumulativeEncrypted int64
|
||||
scramState *scram.HandshakeState
|
||||
}
|
||||
|
||||
// Handles gnet events
|
||||
@ -282,7 +260,7 @@ func (s *tcpSourceServer) OnBoot(eng gnet.Engine) gnet.Action {
|
||||
|
||||
s.source.logger.Debug("msg", "TCP source server booted",
|
||||
"component", "tcp_source",
|
||||
"port", s.source.port)
|
||||
"port", s.source.config.Port)
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
@ -303,6 +281,16 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
return nil, gnet.Close
|
||||
}
|
||||
|
||||
// Check if connection is allowed
|
||||
ip := tcpAddr.IP
|
||||
if ip.To4() == nil {
|
||||
// Reject IPv6
|
||||
s.source.logger.Warn("msg", "IPv6 connection rejected",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", remoteAddr)
|
||||
return []byte("IPv4-only (IPv6 not supported)\n"), gnet.Close
|
||||
}
|
||||
|
||||
if !s.source.netLimiter.CheckTCP(tcpAddr) {
|
||||
s.source.logger.Warn("msg", "TCP connection net limited",
|
||||
"component", "tcp_source",
|
||||
@ -311,49 +299,66 @@ func (s *tcpSourceServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
}
|
||||
|
||||
// Track connection
|
||||
s.source.netLimiter.AddConnection(remoteAddr)
|
||||
// s.source.netLimiter.AddConnection(remoteAddr)
|
||||
if !s.source.netLimiter.TrackConnection(ip.String(), "", "") {
|
||||
s.source.logger.Warn("msg", "TCP connection limit exceeded",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", remoteAddr)
|
||||
return nil, gnet.Close
|
||||
}
|
||||
}
|
||||
|
||||
// Create client state
|
||||
client := &tcpClient{
|
||||
conn: c,
|
||||
buffer: bytes.NewBuffer(nil),
|
||||
authTimeout: time.Now().Add(30 * time.Second),
|
||||
authenticated: s.source.scramManager == nil,
|
||||
authenticated: s.source.authenticator == nil, // No auth = auto authenticated
|
||||
}
|
||||
|
||||
if s.source.authenticator != nil {
|
||||
// Set auth timeout
|
||||
client.authTimeout = time.Now().Add(10 * time.Second)
|
||||
|
||||
// Send auth challenge for SCRAM
|
||||
if s.source.config.Auth.Type == "scram" {
|
||||
out = []byte("AUTH_REQUIRED\n")
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.clients[c] = client
|
||||
s.mu.Unlock()
|
||||
|
||||
newCount := s.source.activeConns.Add(1)
|
||||
s.source.activeConns.Add(1)
|
||||
s.source.logger.Debug("msg", "TCP connection opened",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", remoteAddr,
|
||||
"active_connections", newCount,
|
||||
"requires_auth", s.source.scramManager != nil)
|
||||
"auth_enabled", s.source.authenticator != nil)
|
||||
|
||||
return nil, gnet.None
|
||||
return out, gnet.None
|
||||
}
|
||||
|
||||
func (s *tcpSourceServer) OnClose(c gnet.Conn, err error) gnet.Action {
|
||||
remoteAddr := c.RemoteAddr().String()
|
||||
|
||||
// Untrack connection
|
||||
if s.source.netLimiter != nil {
|
||||
if tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr); err == nil {
|
||||
s.source.netLimiter.ReleaseConnection(tcpAddr.IP.String(), "", "")
|
||||
// s.source.netLimiter.RemoveConnection(remoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove client state
|
||||
s.mu.Lock()
|
||||
delete(s.clients, c)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Remove connection tracking
|
||||
if s.source.netLimiter != nil {
|
||||
s.source.netLimiter.RemoveConnection(remoteAddr)
|
||||
}
|
||||
|
||||
newCount := s.source.activeConns.Add(-1)
|
||||
newConnectionCount := s.source.activeConns.Add(-1)
|
||||
s.source.logger.Debug("msg", "TCP connection closed",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", remoteAddr,
|
||||
"active_connections", newCount,
|
||||
"active_connections", newConnectionCount,
|
||||
"error", err)
|
||||
return gnet.None
|
||||
}
|
||||
@ -383,6 +388,8 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
s.source.logger.Warn("msg", "Authentication timeout",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String())
|
||||
s.source.authFailures.Add(1)
|
||||
c.AsyncWrite([]byte("AUTH_TIMEOUT\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
@ -392,7 +399,12 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
|
||||
client.buffer.Write(data)
|
||||
|
||||
// Look for complete line
|
||||
// Use centralized SCRAM protocol handler
|
||||
if s.source.scramProtocolHandler == nil {
|
||||
s.source.scramProtocolHandler = auth.NewScramProtocolHandler(s.source.scramManager, s.source.logger)
|
||||
}
|
||||
|
||||
// Look for complete auth line
|
||||
for {
|
||||
idx := bytes.IndexByte(client.buffer.Bytes(), '\n')
|
||||
if idx < 0 {
|
||||
@ -402,85 +414,44 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
line := client.buffer.Bytes()[:idx]
|
||||
client.buffer.Next(idx + 1)
|
||||
|
||||
// Parse SCRAM messages
|
||||
parts := strings.Fields(string(line))
|
||||
if len(parts) < 2 {
|
||||
c.AsyncWrite([]byte("SCRAM-FAIL Invalid message format\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
switch parts[0] {
|
||||
case "SCRAM-FIRST":
|
||||
// Parse ClientFirst JSON
|
||||
var clientFirst scram.ClientFirst
|
||||
if err := json.Unmarshal([]byte(parts[1]), &clientFirst); err != nil {
|
||||
c.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Process with SCRAM server
|
||||
serverFirst, err := s.source.scramManager.HandleClientFirst(&clientFirst)
|
||||
if err != nil {
|
||||
// Still send challenge to prevent user enumeration
|
||||
response, _ := json.Marshal(serverFirst)
|
||||
c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Send ServerFirst challenge
|
||||
response, _ := json.Marshal(serverFirst)
|
||||
c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-CHALLENGE %s\n", response)), nil)
|
||||
|
||||
case "SCRAM-PROOF":
|
||||
// Parse ClientFinal JSON
|
||||
var clientFinal scram.ClientFinal
|
||||
if err := json.Unmarshal([]byte(parts[1]), &clientFinal); err != nil {
|
||||
c.AsyncWrite([]byte("SCRAM-FAIL Invalid JSON\n"), nil)
|
||||
return gnet.Close
|
||||
}
|
||||
|
||||
// Verify proof
|
||||
serverFinal, err := s.source.scramManager.HandleClientFinal(&clientFinal)
|
||||
// Process auth message through handler
|
||||
authenticated, session, err := s.source.scramProtocolHandler.HandleAuthMessage(line, c)
|
||||
if err != nil {
|
||||
s.source.logger.Warn("msg", "SCRAM authentication failed",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"error", err)
|
||||
c.AsyncWrite([]byte("SCRAM-FAIL Authentication failed\n"), nil)
|
||||
|
||||
if strings.Contains(err.Error(), "unknown command") {
|
||||
return gnet.Close
|
||||
}
|
||||
// Continue for other errors (might be multi-step auth)
|
||||
}
|
||||
|
||||
if authenticated && session != nil {
|
||||
// Authentication successful
|
||||
s.mu.Lock()
|
||||
client.authenticated = true
|
||||
client.session = &auth.Session{
|
||||
ID: serverFinal.SessionID,
|
||||
Method: "scram-sha-256",
|
||||
RemoteAddr: c.RemoteAddr().String(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
client.session = session
|
||||
s.mu.Unlock()
|
||||
|
||||
// Send ServerFinal with signature
|
||||
response, _ := json.Marshal(serverFinal)
|
||||
c.AsyncWrite([]byte(fmt.Sprintf("SCRAM-OK %s\n", response)), nil)
|
||||
|
||||
s.source.logger.Info("msg", "Client authenticated via SCRAM",
|
||||
"component", "tcp_source",
|
||||
"remote_addr", c.RemoteAddr().String(),
|
||||
"session_id", serverFinal.SessionID)
|
||||
"session_id", session.ID)
|
||||
|
||||
// Clear auth buffer
|
||||
client.buffer.Reset()
|
||||
|
||||
default:
|
||||
c.AsyncWrite([]byte("SCRAM-FAIL Unknown command\n"), nil)
|
||||
return gnet.Close
|
||||
break
|
||||
}
|
||||
}
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
return s.processLogData(c, client, data)
|
||||
}
|
||||
|
||||
func (s *tcpSourceServer) processLogData(c gnet.Conn, client *tcpClient, data []byte) gnet.Action {
|
||||
// Check if appending the new data would exceed the client buffer limit.
|
||||
if client.buffer.Len()+len(data) > maxClientBufferSize {
|
||||
s.source.logger.Warn("msg", "Client buffer limit exceeded, closing connection.",
|
||||
@ -572,47 +543,3 @@ func (s *tcpSourceServer) OnTraffic(c gnet.Conn) gnet.Action {
|
||||
|
||||
return gnet.None
|
||||
}
|
||||
|
||||
func (t *TCPSource) InitSCRAMManager(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type != "scram" || authCfg.ScramAuth == nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.scramManager = scram.NewScramManager()
|
||||
|
||||
// Load users from SCRAM config
|
||||
for _, user := range authCfg.ScramAuth.Users {
|
||||
storedKey, _ := base64.StdEncoding.DecodeString(user.StoredKey)
|
||||
serverKey, _ := base64.StdEncoding.DecodeString(user.ServerKey)
|
||||
salt, _ := base64.StdEncoding.DecodeString(user.Salt)
|
||||
|
||||
cred := &scram.Credential{
|
||||
Username: user.Username,
|
||||
StoredKey: storedKey,
|
||||
ServerKey: serverKey,
|
||||
Salt: salt,
|
||||
ArgonTime: user.ArgonTime,
|
||||
ArgonMemory: user.ArgonMemory,
|
||||
ArgonThreads: user.ArgonThreads,
|
||||
}
|
||||
t.scramManager.AddCredential(cred)
|
||||
}
|
||||
|
||||
t.logger.Info("msg", "SCRAM authentication configured",
|
||||
"component", "tcp_source",
|
||||
"users", len(authCfg.ScramAuth.Users))
|
||||
}
|
||||
|
||||
// Configure TCP source auth
|
||||
func (t *TCPSource) SetAuth(authCfg *config.AuthConfig) {
|
||||
if authCfg == nil || authCfg.Type == "none" {
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize SCRAM manager
|
||||
if authCfg.Type == "scram" {
|
||||
t.InitSCRAMManager(authCfg)
|
||||
t.logger.Info("msg", "SCRAM authentication configured for TCP source",
|
||||
"component", "tcp_source")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user