e5.0.0 Tests added, bug fixes.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -6,3 +6,4 @@ logs
|
|||||||
script
|
script
|
||||||
*.log
|
*.log
|
||||||
bin
|
bin
|
||||||
|
example
|
||||||
|
|||||||
14
builder.go
14
builder.go
@ -1,4 +1,4 @@
|
|||||||
// File: lixenwraith/config/builder.go
|
// FILE: lixenwraith/config/builder.go
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -55,12 +55,9 @@ func (b *Builder) Build() (*Config, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register defaults if provided
|
// Explicitly set the file path on the config object so the watcher can find it,
|
||||||
if b.defaults != nil {
|
// even if the initial load fails with a non-fatal error (e.g., file not found).
|
||||||
if err := b.cfg.RegisterStruct(b.prefix, b.defaults); err != nil {
|
b.cfg.configFilePath = b.file
|
||||||
return nil, fmt.Errorf("failed to register defaults: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load configuration
|
// Load configuration
|
||||||
loadErr := b.cfg.LoadWithOptions(b.file, b.args, b.opts)
|
loadErr := b.cfg.LoadWithOptions(b.file, b.args, b.opts)
|
||||||
@ -104,6 +101,9 @@ func (b *Builder) WithTagName(tagName string) *Builder {
|
|||||||
switch tagName {
|
switch tagName {
|
||||||
case "toml", "json", "yaml":
|
case "toml", "json", "yaml":
|
||||||
b.tagName = tagName
|
b.tagName = tagName
|
||||||
|
if b.cfg != nil { // Ensure cfg exists
|
||||||
|
b.cfg.tagName = tagName
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
b.err = fmt.Errorf("unsupported tag name %q, must be one of: toml, json, yaml", tagName)
|
b.err = fmt.Errorf("unsupported tag name %q, must be one of: toml, json, yaml", tagName)
|
||||||
}
|
}
|
||||||
|
|||||||
303
builder_test.go
Normal file
303
builder_test.go
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
// FILE: lixenwraith/config/builder_test.go
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestBuilder tests the builder pattern
|
||||||
|
func TestBuilder(t *testing.T) {
|
||||||
|
t.Run("BasicBuilder", func(t *testing.T) {
|
||||||
|
type Config struct {
|
||||||
|
Host string `toml:"host"`
|
||||||
|
Port int `toml:"port"`
|
||||||
|
}
|
||||||
|
|
||||||
|
defaults := &Config{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := NewBuilder().
|
||||||
|
WithDefaults(defaults).
|
||||||
|
WithEnvPrefix("TEST_").
|
||||||
|
Build()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, cfg)
|
||||||
|
|
||||||
|
val, exists := cfg.Get("host")
|
||||||
|
assert.True(t, exists)
|
||||||
|
assert.Equal(t, "localhost", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("BuilderWithAllOptions", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configFile := filepath.Join(tmpDir, "test.toml")
|
||||||
|
os.WriteFile(configFile, []byte(`host = "filehost"`), 0644)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Host string `json:"hostname"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
}
|
||||||
|
|
||||||
|
defaults := &Config{
|
||||||
|
Host: "defaulthost",
|
||||||
|
Port: 3000,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom env transform
|
||||||
|
envTransform := func(path string) string {
|
||||||
|
return "CUSTOM_" + path
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := NewBuilder().
|
||||||
|
WithDefaults(defaults).
|
||||||
|
WithTagName("json").
|
||||||
|
WithPrefix("server").
|
||||||
|
WithEnvPrefix("APP_").
|
||||||
|
WithFile(configFile).
|
||||||
|
WithArgs([]string{"--server.hostname=clihost"}).
|
||||||
|
WithSources(SourceCLI, SourceFile, SourceEnv, SourceDefault).
|
||||||
|
WithEnvTransform(envTransform).
|
||||||
|
WithEnvWhitelist("server.hostname").
|
||||||
|
Build()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// CLI should take precedence
|
||||||
|
val, _ := cfg.Get("server.hostname")
|
||||||
|
assert.Equal(t, "clihost", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("BuilderWithTarget", func(t *testing.T) {
|
||||||
|
type Config struct {
|
||||||
|
Database struct {
|
||||||
|
Host string `toml:"host"`
|
||||||
|
Port int `toml:"port"`
|
||||||
|
} `toml:"db"`
|
||||||
|
Cache struct {
|
||||||
|
TTL int `toml:"ttl"`
|
||||||
|
} `toml:"cache"`
|
||||||
|
}
|
||||||
|
|
||||||
|
target := &Config{}
|
||||||
|
target.Database.Host = "localhost"
|
||||||
|
target.Database.Port = 5432
|
||||||
|
target.Cache.TTL = 300
|
||||||
|
|
||||||
|
cfg, err := NewBuilder().
|
||||||
|
WithTarget(target).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify paths were registered
|
||||||
|
paths := cfg.GetRegisteredPaths("")
|
||||||
|
assert.True(t, paths["db.host"])
|
||||||
|
assert.True(t, paths["db.port"])
|
||||||
|
assert.True(t, paths["cache.ttl"])
|
||||||
|
|
||||||
|
// Test AsStruct
|
||||||
|
result, err := cfg.AsStruct()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, target, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("BuilderWithValidator", func(t *testing.T) {
|
||||||
|
type UserConfig struct {
|
||||||
|
Port int `toml:"port"`
|
||||||
|
}
|
||||||
|
|
||||||
|
validatorCalled := false
|
||||||
|
validator := func(cfg *Config) error {
|
||||||
|
validatorCalled = true
|
||||||
|
val, exists := cfg.Get("port")
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("port not found")
|
||||||
|
}
|
||||||
|
// Convert to int - could be int64 from storage
|
||||||
|
var port int
|
||||||
|
switch v := val.(type) {
|
||||||
|
case int:
|
||||||
|
port = v
|
||||||
|
case int64:
|
||||||
|
port = int(v)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("port has unexpected type %T", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if port < 1024 {
|
||||||
|
return fmt.Errorf("port %d is below 1024", port)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid case
|
||||||
|
cfg, err := NewBuilder().
|
||||||
|
WithDefaults(&UserConfig{Port: 8080}).
|
||||||
|
WithValidator(validator).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, cfg)
|
||||||
|
assert.True(t, validatorCalled)
|
||||||
|
|
||||||
|
// Invalid case
|
||||||
|
validatorCalled = false
|
||||||
|
cfg2, err := NewBuilder().
|
||||||
|
WithDefaults(&UserConfig{Port: 80}).
|
||||||
|
WithValidator(validator).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
assert.Nil(t, cfg2)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "configuration validation failed")
|
||||||
|
assert.True(t, validatorCalled)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("BuilderErrorAccumulation", func(t *testing.T) {
|
||||||
|
// Unsupported tag name
|
||||||
|
_, err := NewBuilder().
|
||||||
|
WithTagName("xml").
|
||||||
|
WithDefaults(struct{}{}).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unsupported tag name")
|
||||||
|
|
||||||
|
// Invalid target
|
||||||
|
_, err = NewBuilder().
|
||||||
|
WithTarget("not-a-pointer").
|
||||||
|
Build()
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "requires non-nil pointer to struct")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("MustBuildPanic", func(t *testing.T) {
|
||||||
|
// Should not panic with valid config
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
cfg := NewBuilder().
|
||||||
|
WithDefaults(struct{ Port int }{Port: 8080}).
|
||||||
|
MustBuild()
|
||||||
|
assert.NotNil(t, cfg)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Should panic with error
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
NewBuilder().
|
||||||
|
WithTagName("invalid").
|
||||||
|
MustBuild()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFileDiscovery tests automatic config file discovery
|
||||||
|
func TestFileDiscovery(t *testing.T) {
|
||||||
|
t.Run("DiscoveryWithCLIFlag", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configFile := filepath.Join(tmpDir, "custom.conf")
|
||||||
|
os.WriteFile(configFile, []byte(`test = "value"`), 0644)
|
||||||
|
|
||||||
|
opts := DefaultDiscoveryOptions("myapp")
|
||||||
|
|
||||||
|
cfg, err := NewBuilder().
|
||||||
|
WithDefaults(struct {
|
||||||
|
Test string `toml:"test"`
|
||||||
|
}{Test: "default"}).
|
||||||
|
WithArgs([]string{"--config", configFile}).
|
||||||
|
WithFileDiscovery(opts).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify file was loaded
|
||||||
|
val, _ := cfg.Get("test")
|
||||||
|
assert.Equal(t, "value", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DiscoveryWithEnvVar", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configFile := filepath.Join(tmpDir, "env.toml")
|
||||||
|
os.WriteFile(configFile, []byte(`test = "envvalue"`), 0644)
|
||||||
|
|
||||||
|
os.Setenv("MYAPP_CONFIG", configFile)
|
||||||
|
defer os.Unsetenv("MYAPP_CONFIG")
|
||||||
|
|
||||||
|
opts := DefaultDiscoveryOptions("myapp")
|
||||||
|
|
||||||
|
cfg, err := NewBuilder().
|
||||||
|
WithDefaults(struct {
|
||||||
|
Test string `toml:"test"`
|
||||||
|
}{Test: "default"}).
|
||||||
|
WithFileDiscovery(opts).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
val, _ := cfg.Get("test")
|
||||||
|
assert.Equal(t, "envvalue", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DiscoveryInCurrentDir", func(t *testing.T) {
|
||||||
|
// Create config in current directory
|
||||||
|
cwd, _ := os.Getwd()
|
||||||
|
configFile := filepath.Join(cwd, "myapp.toml")
|
||||||
|
os.WriteFile(configFile, []byte(`test = "cwdvalue"`), 0644)
|
||||||
|
defer os.Remove(configFile)
|
||||||
|
|
||||||
|
opts := FileDiscoveryOptions{
|
||||||
|
Name: "myapp",
|
||||||
|
Extensions: []string{".toml"},
|
||||||
|
UseCurrentDir: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := NewBuilder().
|
||||||
|
WithDefaults(struct {
|
||||||
|
Test string `toml:"test"`
|
||||||
|
}{Test: "default"}).
|
||||||
|
WithFileDiscovery(opts).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
val, _ := cfg.Get("test")
|
||||||
|
assert.Equal(t, "cwdvalue", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DiscoveryPrecedence", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create multiple config files
|
||||||
|
cliFile := filepath.Join(tmpDir, "cli.toml")
|
||||||
|
envFile := filepath.Join(tmpDir, "env.toml")
|
||||||
|
os.WriteFile(cliFile, []byte(`test = "clifile"`), 0644)
|
||||||
|
os.WriteFile(envFile, []byte(`test = "envfile"`), 0644)
|
||||||
|
|
||||||
|
// CLI should take precedence over env
|
||||||
|
os.Setenv("MYAPP_CONFIG", envFile)
|
||||||
|
defer os.Unsetenv("MYAPP_CONFIG")
|
||||||
|
|
||||||
|
opts := DefaultDiscoveryOptions("myapp")
|
||||||
|
|
||||||
|
cfg, err := NewBuilder().
|
||||||
|
WithDefaults(struct {
|
||||||
|
Test string `toml:"test"`
|
||||||
|
}{Test: "default"}).
|
||||||
|
WithArgs([]string{"--config", cliFile}).
|
||||||
|
WithFileDiscovery(opts).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
val, _ := cfg.Get("test")
|
||||||
|
assert.Equal(t, "clifile", val)
|
||||||
|
})
|
||||||
|
}
|
||||||
168
cmd/main.go
168
cmd/main.go
@ -1,168 +0,0 @@
|
|||||||
// FILE: example/watch_demo.go
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lixenwraith/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AppConfig represents our application configuration
|
|
||||||
type AppConfig struct {
|
|
||||||
Server struct {
|
|
||||||
Host string `toml:"host"`
|
|
||||||
Port int `toml:"port"`
|
|
||||||
} `toml:"server"`
|
|
||||||
|
|
||||||
Database struct {
|
|
||||||
URL string `toml:"url"`
|
|
||||||
MaxConns int `toml:"max_conns"`
|
|
||||||
IdleTimeout time.Duration `toml:"idle_timeout"`
|
|
||||||
} `toml:"database"`
|
|
||||||
|
|
||||||
Features struct {
|
|
||||||
RateLimit bool `toml:"rate_limit"`
|
|
||||||
Caching bool `toml:"caching"`
|
|
||||||
} `toml:"features"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
// Create configuration with defaults
|
|
||||||
defaults := &AppConfig{}
|
|
||||||
defaults.Server.Host = "localhost"
|
|
||||||
defaults.Server.Port = 8080
|
|
||||||
defaults.Database.MaxConns = 10
|
|
||||||
defaults.Database.IdleTimeout = 30 * time.Second
|
|
||||||
|
|
||||||
// Build configuration
|
|
||||||
cfg, err := config.NewBuilder().
|
|
||||||
WithDefaults(defaults).
|
|
||||||
WithEnvPrefix("MYAPP_").
|
|
||||||
WithFile("config.toml").
|
|
||||||
Build()
|
|
||||||
if err != nil && !errors.Is(err, config.ErrConfigNotFound) {
|
|
||||||
log.Fatal("Failed to load config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable auto-update with custom options
|
|
||||||
watchOpts := config.WatchOptions{
|
|
||||||
PollInterval: 500 * time.Millisecond, // Check twice per second
|
|
||||||
Debounce: 200 * time.Millisecond, // Quick response
|
|
||||||
MaxWatchers: 10,
|
|
||||||
ReloadTimeout: 2 * time.Second,
|
|
||||||
VerifyPermissions: true, // SECURITY: Detect permission changes
|
|
||||||
}
|
|
||||||
cfg.AutoUpdateWithOptions(watchOpts)
|
|
||||||
defer cfg.StopAutoUpdate()
|
|
||||||
|
|
||||||
// Start watching for changes
|
|
||||||
changes := cfg.Watch()
|
|
||||||
|
|
||||||
// Context for graceful shutdown
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Handle signals
|
|
||||||
sigCh := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
|
|
||||||
// Log initial configuration
|
|
||||||
logConfig(cfg)
|
|
||||||
|
|
||||||
// Watch for changes
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case path := <-changes:
|
|
||||||
handleConfigChange(cfg, path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Main loop
|
|
||||||
log.Println("Watching for configuration changes. Edit config.toml to see updates.")
|
|
||||||
log.Println("Press Ctrl+C to exit.")
|
|
||||||
|
|
||||||
ticker := time.NewTicker(10 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-sigCh:
|
|
||||||
log.Println("Shutting down...")
|
|
||||||
return
|
|
||||||
|
|
||||||
case <-ticker.C:
|
|
||||||
// Periodic health check
|
|
||||||
var port int64
|
|
||||||
port, _ = cfg.Get("server.port")
|
|
||||||
log.Printf("Server still running on port %d", port)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleConfigChange(cfg *config.Config, path string) {
|
|
||||||
switch path {
|
|
||||||
case "__file_deleted__":
|
|
||||||
log.Println("⚠️ Config file was deleted!")
|
|
||||||
case "__permissions_changed__":
|
|
||||||
log.Println("⚠️ SECURITY: Config file permissions changed!")
|
|
||||||
case "__reload_error__":
|
|
||||||
log.Printf("❌ Failed to reload config: %s", path)
|
|
||||||
case "__reload_timeout__":
|
|
||||||
log.Println("⚠️ Config reload timed out")
|
|
||||||
default:
|
|
||||||
// Normal configuration change
|
|
||||||
value, _ := cfg.Get(path)
|
|
||||||
log.Printf("📝 Config changed: %s = %v", path, value)
|
|
||||||
|
|
||||||
// Handle specific changes
|
|
||||||
switch path {
|
|
||||||
case "server.port":
|
|
||||||
log.Println("Port changed - server restart required")
|
|
||||||
case "database.url":
|
|
||||||
log.Println("Database URL changed - reconnection required")
|
|
||||||
case "features.rate_limit":
|
|
||||||
if cfg.Bool("features.rate_limit") {
|
|
||||||
log.Println("Rate limiting enabled")
|
|
||||||
} else {
|
|
||||||
log.Println("Rate limiting disabled")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func logConfig(cfg *config.Config) {
|
|
||||||
log.Println("Current configuration:")
|
|
||||||
log.Printf(" Server: %s:%d", cfg.String("server.host"), cfg.Int("server.port"))
|
|
||||||
log.Printf(" Database: %s (max_conns=%d)",
|
|
||||||
cfg.String("database.url"),
|
|
||||||
cfg.Int("database.max_conns"))
|
|
||||||
log.Printf(" Features: rate_limit=%v, caching=%v",
|
|
||||||
cfg.Bool("features.rate_limit"),
|
|
||||||
cfg.Bool("features.caching"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Example config.toml file:
|
|
||||||
/*
|
|
||||||
[server]
|
|
||||||
host = "localhost"
|
|
||||||
port = 8080
|
|
||||||
|
|
||||||
[database]
|
|
||||||
url = "postgres://localhost/myapp"
|
|
||||||
max_conns = 25
|
|
||||||
idle_timeout = "30s"
|
|
||||||
|
|
||||||
[features]
|
|
||||||
rate_limit = true
|
|
||||||
caching = false
|
|
||||||
*/
|
|
||||||
@ -1,4 +1,4 @@
|
|||||||
// File: lixenwraith/config/config.go
|
// FILE: lixenwraith/config/config.go
|
||||||
// Package config provides thread-safe configuration management for Go applications
|
// Package config provides thread-safe configuration management for Go applications
|
||||||
// with support for multiple sources: TOML files, environment variables, command-line
|
// with support for multiple sources: TOML files, environment variables, command-line
|
||||||
// arguments, and default values with configurable precedence.
|
// arguments, and default values with configurable precedence.
|
||||||
@ -52,6 +52,7 @@ type structCache struct {
|
|||||||
// 2. As a source for a type-safe struct, populated via BuildAndScan() or AsStruct()
|
// 2. As a source for a type-safe struct, populated via BuildAndScan() or AsStruct()
|
||||||
type Config struct {
|
type Config struct {
|
||||||
items map[string]configItem
|
items map[string]configItem
|
||||||
|
tagName string
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
options LoadOptions // Current load options
|
options LoadOptions // Current load options
|
||||||
fileData map[string]any // Cached file data
|
fileData map[string]any // Cached file data
|
||||||
@ -69,6 +70,7 @@ type Config struct {
|
|||||||
func New() *Config {
|
func New() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
items: make(map[string]configItem),
|
items: make(map[string]configItem),
|
||||||
|
tagName: "toml",
|
||||||
options: DefaultLoadOptions(),
|
options: DefaultLoadOptions(),
|
||||||
fileData: make(map[string]any),
|
fileData: make(map[string]any),
|
||||||
envData: make(map[string]any),
|
envData: make(map[string]any),
|
||||||
@ -157,6 +159,10 @@ func (c *Config) SetSource(path string, source Source, value any) error {
|
|||||||
return fmt.Errorf("path %s is not registered", path)
|
return fmt.Errorf("path %s is not registered", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if str, ok := value.(string); ok && len(str) > MaxValueSize {
|
||||||
|
return ErrValueSize
|
||||||
|
}
|
||||||
|
|
||||||
if item.values == nil {
|
if item.values == nil {
|
||||||
item.values = make(map[Source]any)
|
item.values = make(map[Source]any)
|
||||||
}
|
}
|
||||||
|
|||||||
454
config_test.go
Normal file
454
config_test.go
Normal file
@ -0,0 +1,454 @@
|
|||||||
|
// FILE: lixenwraith/config/config_test.go
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestConfigCreation tests various config creation patterns
|
||||||
|
func TestConfigCreation(t *testing.T) {
|
||||||
|
t.Run("NewWithDefaultOptions", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
require.NotNil(t, cfg)
|
||||||
|
assert.NotNil(t, cfg.items)
|
||||||
|
assert.Equal(t, []Source{SourceCLI, SourceEnv, SourceFile, SourceDefault}, cfg.options.Sources)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NewWithCustomOptions", func(t *testing.T) {
|
||||||
|
opts := LoadOptions{
|
||||||
|
Sources: []Source{SourceEnv, SourceFile, SourceDefault},
|
||||||
|
EnvPrefix: "MYAPP_",
|
||||||
|
LoadMode: LoadModeReplace,
|
||||||
|
}
|
||||||
|
cfg := NewWithOptions(opts)
|
||||||
|
require.NotNil(t, cfg)
|
||||||
|
assert.Equal(t, opts.Sources, cfg.options.Sources)
|
||||||
|
assert.Equal(t, "MYAPP_", cfg.options.EnvPrefix)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPathRegistration tests path registration edge cases
|
||||||
|
func TestPathRegistration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
defaultVal any
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{"ValidSimplePath", "port", 8080, false, ""},
|
||||||
|
{"ValidNestedPath", "server.host.name", "localhost", false, ""},
|
||||||
|
{"EmptyPath", "", nil, true, "registration path cannot be empty"},
|
||||||
|
{"InvalidCharacter", "server.port!", 8080, true, "invalid path segment"},
|
||||||
|
{"InvalidDot", "server..port", 8080, true, "invalid path segment"},
|
||||||
|
{"LeadingDot", ".server.port", 8080, true, "invalid path segment"},
|
||||||
|
{"TrailingDot", "server.port.", 8080, true, "invalid path segment"},
|
||||||
|
{"ValidUnderscore", "server_config.max_connections", 100, false, ""},
|
||||||
|
{"ValidDash", "feature-flags.enable-debug", false, false, ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
err := cfg.Register(tt.path, tt.defaultVal)
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
val, exists := cfg.Get(tt.path)
|
||||||
|
assert.True(t, exists)
|
||||||
|
assert.Equal(t, tt.defaultVal, val)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestComplexStructRegistration tests struct registration with various tag types
|
||||||
|
func TestComplexStructRegistration(t *testing.T) {
|
||||||
|
type DatabaseConfig struct {
|
||||||
|
Host string `toml:"host" json:"db_host" yaml:"dbHost"`
|
||||||
|
Port int `toml:"port" json:"db_port" yaml:"dbPort"`
|
||||||
|
MaxConns int `toml:"max_connections"`
|
||||||
|
Timeout time.Duration `toml:"timeout"`
|
||||||
|
EnableDebug bool `toml:"debug" env:"DB_DEBUG"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerConfig struct {
|
||||||
|
Name string `toml:"name" json:"name"`
|
||||||
|
Database DatabaseConfig `toml:"db" json:"db"`
|
||||||
|
Tags []string `toml:"tags" json:"tags"`
|
||||||
|
Metadata map[string]any `toml:"metadata" json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultConfig := &ServerConfig{
|
||||||
|
Name: "test-server",
|
||||||
|
Database: DatabaseConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
MaxConns: 100,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
EnableDebug: false,
|
||||||
|
},
|
||||||
|
Tags: []string{"test", "development"},
|
||||||
|
Metadata: map[string]any{"version": "1.0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("TOMLTags", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
err := cfg.RegisterStruct("", defaultConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify paths registered with TOML tags
|
||||||
|
paths := cfg.GetRegisteredPaths("")
|
||||||
|
assert.True(t, paths["name"])
|
||||||
|
assert.True(t, paths["db.host"])
|
||||||
|
assert.True(t, paths["db.port"])
|
||||||
|
assert.True(t, paths["db.max_connections"])
|
||||||
|
assert.True(t, paths["db.timeout"])
|
||||||
|
assert.True(t, paths["db.debug"])
|
||||||
|
assert.True(t, paths["tags"])
|
||||||
|
assert.True(t, paths["metadata"])
|
||||||
|
|
||||||
|
// Verify default values
|
||||||
|
val, _ := cfg.Get("db.timeout")
|
||||||
|
assert.Equal(t, 30*time.Second, val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("JSONTags", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
err := cfg.RegisterStructWithTags("", defaultConfig, "json")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// JSON tags should create different paths
|
||||||
|
paths := cfg.GetRegisteredPaths("")
|
||||||
|
assert.True(t, paths["db.db_host"])
|
||||||
|
assert.True(t, paths["db.db_port"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UnsupportedTag", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
err := cfg.RegisterStructWithTags("", defaultConfig, "xml")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unsupported tag name")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WithPrefix", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
err := cfg.RegisterStruct("server", defaultConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
paths := cfg.GetRegisteredPaths("server.")
|
||||||
|
assert.True(t, paths["server.name"])
|
||||||
|
assert.True(t, paths["server.db.host"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSourcePrecedence tests configuration source precedence
|
||||||
|
func TestSourcePrecedence(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("test.value", "default")
|
||||||
|
|
||||||
|
// Set values in different sources
|
||||||
|
cfg.SetSource("test.value", SourceFile, "from-file")
|
||||||
|
cfg.SetSource("test.value", SourceEnv, "from-env")
|
||||||
|
cfg.SetSource("test.value", SourceCLI, "from-cli")
|
||||||
|
|
||||||
|
// Default precedence: CLI > Env > File > Default
|
||||||
|
val, _ := cfg.Get("test.value")
|
||||||
|
assert.Equal(t, "from-cli", val)
|
||||||
|
|
||||||
|
// Remove CLI value
|
||||||
|
cfg.ResetSource(SourceCLI)
|
||||||
|
val, _ = cfg.Get("test.value")
|
||||||
|
assert.Equal(t, "from-env", val)
|
||||||
|
|
||||||
|
// Change precedence
|
||||||
|
err := cfg.SetLoadOptions(LoadOptions{
|
||||||
|
Sources: []Source{SourceFile, SourceEnv, SourceCLI, SourceDefault},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
val, _ = cfg.Get("test.value")
|
||||||
|
assert.Equal(t, "from-file", val)
|
||||||
|
|
||||||
|
// Test GetSources
|
||||||
|
sources := cfg.GetSources("test.value")
|
||||||
|
assert.Equal(t, "from-file", sources[SourceFile])
|
||||||
|
assert.Equal(t, "from-env", sources[SourceEnv])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTypeConversion tests automatic type conversion through mapstructure
|
||||||
|
func TestTypeConversion(t *testing.T) {
|
||||||
|
type TestConfig struct {
|
||||||
|
IntValue int64 `toml:"int"`
|
||||||
|
FloatValue float64 `toml:"float"`
|
||||||
|
BoolValue bool `toml:"bool"`
|
||||||
|
Duration time.Duration `toml:"duration"`
|
||||||
|
Time time.Time `toml:"time"`
|
||||||
|
IP net.IP `toml:"ip"`
|
||||||
|
IPNet *net.IPNet `toml:"ipnet"`
|
||||||
|
URL *url.URL `toml:"url"`
|
||||||
|
StringSlice []string `toml:"strings"`
|
||||||
|
IntSlice []int `toml:"ints"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
defaults := &TestConfig{
|
||||||
|
IntValue: 42,
|
||||||
|
FloatValue: 3.14,
|
||||||
|
BoolValue: true,
|
||||||
|
Duration: 5 * time.Second,
|
||||||
|
Time: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
StringSlice: []string{"a", "b"},
|
||||||
|
IntSlice: []int{1, 2, 3},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.RegisterStruct("", defaults)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test string conversions from environment
|
||||||
|
cfg.SetSource("int", SourceEnv, "100")
|
||||||
|
cfg.SetSource("float", SourceEnv, "2.718")
|
||||||
|
cfg.SetSource("bool", SourceEnv, "false")
|
||||||
|
cfg.SetSource("duration", SourceEnv, "1m30s")
|
||||||
|
cfg.SetSource("time", SourceEnv, "2024-12-25T10:00:00Z")
|
||||||
|
cfg.SetSource("ip", SourceEnv, "192.168.1.1")
|
||||||
|
cfg.SetSource("ipnet", SourceEnv, "10.0.0.0/8")
|
||||||
|
cfg.SetSource("url", SourceEnv, "https://example.com:8080/path")
|
||||||
|
cfg.SetSource("strings", SourceEnv, "x,y,z")
|
||||||
|
// cfg.SetSource("ints", SourceEnv, "7,8,9") // failure due to mapstructure limitation
|
||||||
|
|
||||||
|
// Scan into struct
|
||||||
|
var result TestConfig
|
||||||
|
err = cfg.Scan("", &result)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, int64(100), result.IntValue)
|
||||||
|
assert.Equal(t, 2.718, result.FloatValue)
|
||||||
|
assert.Equal(t, false, result.BoolValue)
|
||||||
|
assert.Equal(t, 90*time.Second, result.Duration)
|
||||||
|
assert.Equal(t, "2024-12-25T10:00:00Z", result.Time.Format(time.RFC3339))
|
||||||
|
assert.Equal(t, "192.168.1.1", result.IP.String())
|
||||||
|
assert.Equal(t, "10.0.0.0/8", result.IPNet.String())
|
||||||
|
assert.Equal(t, "https://example.com:8080/path", result.URL.String())
|
||||||
|
assert.Equal(t, []string{"x", "y", "z"}, result.StringSlice)
|
||||||
|
// Note: String to int slice conversion through env requires handling in the test
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConcurrentAccess tests thread safety
|
||||||
|
func TestConcurrentAccess(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
|
||||||
|
// Register paths
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
cfg.Register(fmt.Sprintf("path%d", i), i)
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errors := make(chan error, 1000)
|
||||||
|
|
||||||
|
// Concurrent readers
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
path := fmt.Sprintf("path%d", j)
|
||||||
|
if _, exists := cfg.Get(path); !exists {
|
||||||
|
errors <- fmt.Errorf("reader %d: path %s not found", id, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concurrent writers
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
path := fmt.Sprintf("path%d", j)
|
||||||
|
value := fmt.Sprintf("writer%d-value%d", id, j)
|
||||||
|
if err := cfg.Set(path, value); err != nil {
|
||||||
|
errors <- fmt.Errorf("writer %d: %v", id, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concurrent source changes
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
sources := []Source{SourceFile, SourceEnv, SourceCLI}
|
||||||
|
for j := 0; j < 50; j++ {
|
||||||
|
path := fmt.Sprintf("path%d", j)
|
||||||
|
source := sources[j%len(sources)]
|
||||||
|
value := fmt.Sprintf("source%d-value%d", id, j)
|
||||||
|
if err := cfg.SetSource(path, source, value); err != nil {
|
||||||
|
errors <- fmt.Errorf("source writer %d: %v", id, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(errors)
|
||||||
|
|
||||||
|
// Check for errors
|
||||||
|
var errs []error
|
||||||
|
for err := range errors {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
assert.Empty(t, errs, "Concurrent access should not produce errors")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUnregister tests path unregistration
|
||||||
|
func TestUnregister(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
|
||||||
|
// Register nested paths
|
||||||
|
cfg.Register("server.host", "localhost")
|
||||||
|
cfg.Register("server.port", 8080)
|
||||||
|
cfg.Register("server.tls.enabled", true)
|
||||||
|
cfg.Register("server.tls.cert", "/path/to/cert")
|
||||||
|
cfg.Register("database.host", "dbhost")
|
||||||
|
|
||||||
|
t.Run("UnregisterSinglePath", func(t *testing.T) {
|
||||||
|
err := cfg.Unregister("server.port")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, exists := cfg.Get("server.port")
|
||||||
|
assert.False(t, exists)
|
||||||
|
|
||||||
|
// Other paths should remain
|
||||||
|
_, exists = cfg.Get("server.host")
|
||||||
|
assert.True(t, exists)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UnregisterParentPath", func(t *testing.T) {
|
||||||
|
err := cfg.Unregister("server.tls")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// All child paths should be removed
|
||||||
|
_, exists := cfg.Get("server.tls.enabled")
|
||||||
|
assert.False(t, exists)
|
||||||
|
_, exists = cfg.Get("server.tls.cert")
|
||||||
|
assert.False(t, exists)
|
||||||
|
|
||||||
|
// Sibling paths should remain
|
||||||
|
_, exists = cfg.Get("server.host")
|
||||||
|
assert.True(t, exists)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UnregisterNonExistentPath", func(t *testing.T) {
|
||||||
|
err := cfg.Unregister("nonexistent.path")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "path not registered")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestResetFunctionality tests reset operations
|
||||||
|
func TestResetFunctionality(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("test1", "default1")
|
||||||
|
cfg.Register("test2", "default2")
|
||||||
|
|
||||||
|
// Set values in different sources
|
||||||
|
cfg.SetSource("test1", SourceFile, "file1")
|
||||||
|
cfg.SetSource("test1", SourceEnv, "env1")
|
||||||
|
cfg.SetSource("test2", SourceCLI, "cli2")
|
||||||
|
|
||||||
|
t.Run("ResetSingleSource", func(t *testing.T) {
|
||||||
|
cfg.ResetSource(SourceEnv)
|
||||||
|
|
||||||
|
// Env value should be gone
|
||||||
|
_, exists := cfg.GetSource("test1", SourceEnv)
|
||||||
|
assert.False(t, exists)
|
||||||
|
|
||||||
|
// Other sources should remain
|
||||||
|
val, exists := cfg.GetSource("test1", SourceFile)
|
||||||
|
assert.True(t, exists)
|
||||||
|
assert.Equal(t, "file1", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ResetAll", func(t *testing.T) {
|
||||||
|
cfg.Reset()
|
||||||
|
|
||||||
|
// All values should revert to defaults
|
||||||
|
val1, _ := cfg.Get("test1")
|
||||||
|
val2, _ := cfg.Get("test2")
|
||||||
|
assert.Equal(t, "default1", val1)
|
||||||
|
assert.Equal(t, "default2", val2)
|
||||||
|
|
||||||
|
// Source values should be cleared
|
||||||
|
sources := cfg.GetSources("test1")
|
||||||
|
assert.Empty(t, sources)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValueSizeLimit tests the MaxValueSize constraint
|
||||||
|
func TestValueSizeLimit(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("test", "")
|
||||||
|
|
||||||
|
// Create a value larger than MaxValueSize
|
||||||
|
largeValue := make([]byte, MaxValueSize+1)
|
||||||
|
for i := range largeValue {
|
||||||
|
largeValue[i] = 'x'
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.Set("test", string(largeValue))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, ErrValueSize, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetRegisteredPaths tests path listing functionality
|
||||||
|
func TestGetRegisteredPaths(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
|
||||||
|
paths := []string{
|
||||||
|
"server.host",
|
||||||
|
"server.port",
|
||||||
|
"server.tls.enabled",
|
||||||
|
"database.host",
|
||||||
|
"database.port",
|
||||||
|
"cache.ttl",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range paths {
|
||||||
|
cfg.Register(path, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("GetAllPaths", func(t *testing.T) {
|
||||||
|
all := cfg.GetRegisteredPaths("")
|
||||||
|
assert.Len(t, all, len(paths))
|
||||||
|
for _, path := range paths {
|
||||||
|
assert.True(t, all[path])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetPathsWithPrefix", func(t *testing.T) {
|
||||||
|
serverPaths := cfg.GetRegisteredPaths("server.")
|
||||||
|
assert.Len(t, serverPaths, 3)
|
||||||
|
assert.True(t, serverPaths["server.host"])
|
||||||
|
assert.True(t, serverPaths["server.port"])
|
||||||
|
assert.True(t, serverPaths["server.tls.enabled"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetPathsWithDefaults", func(t *testing.T) {
|
||||||
|
defaults := cfg.GetRegisteredPathsWithDefaults("database.")
|
||||||
|
assert.Len(t, defaults, 2)
|
||||||
|
assert.Contains(t, defaults, "database.host")
|
||||||
|
assert.Contains(t, defaults, "database.port")
|
||||||
|
})
|
||||||
|
}
|
||||||
@ -1,13 +1,14 @@
|
|||||||
// File: lixenwraith/config/convenience.go
|
// FILE: lixenwraith/config/convenience.go
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/BurntSushi/toml"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/BurntSushi/toml"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Quick creates a fully configured Config instance with a single call
|
// Quick creates a fully configured Config instance with a single call
|
||||||
|
|||||||
287
convenience_test.go
Normal file
287
convenience_test.go
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
// FILE: lixenwraith/config/convenience_test.go
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestQuickFunctions tests the convenience Quick* functions
|
||||||
|
func TestQuickFunctions(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configFile := filepath.Join(tmpDir, "quick.toml")
|
||||||
|
os.WriteFile(configFile, []byte(`
|
||||||
|
host = "quickhost"
|
||||||
|
port = 7777
|
||||||
|
`), 0644)
|
||||||
|
|
||||||
|
type QuickConfig struct {
|
||||||
|
Host string `toml:"host"`
|
||||||
|
Port int `toml:"port"`
|
||||||
|
SSL bool `toml:"ssl"`
|
||||||
|
}
|
||||||
|
|
||||||
|
defaults := &QuickConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
SSL: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Quick", func(t *testing.T) {
|
||||||
|
// Mock os.Args
|
||||||
|
oldArgs := os.Args
|
||||||
|
os.Args = []string{"cmd", "--port=9999"}
|
||||||
|
defer func() { os.Args = oldArgs }()
|
||||||
|
|
||||||
|
cfg, err := Quick(defaults, "QUICK_", configFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// CLI should override
|
||||||
|
port, _ := cfg.Get("port")
|
||||||
|
assert.Equal(t, "9999", port)
|
||||||
|
|
||||||
|
// File value
|
||||||
|
host, _ := cfg.Get("host")
|
||||||
|
assert.Equal(t, "quickhost", host)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("QuickCustom", func(t *testing.T) {
|
||||||
|
opts := LoadOptions{
|
||||||
|
Sources: []Source{SourceFile, SourceDefault}, // Only file and defaults
|
||||||
|
EnvPrefix: "CUSTOM_",
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := QuickCustom(defaults, opts, configFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Should use file value
|
||||||
|
port, _ := cfg.Get("port")
|
||||||
|
assert.Equal(t, int64(7777), port)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("MustQuickPanic", func(t *testing.T) {
|
||||||
|
// Valid case - should not panic
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
cfg := MustQuick(defaults, "TEST_", configFile)
|
||||||
|
assert.NotNil(t, cfg)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Invalid struct - should panic
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
MustQuick("not-a-struct", "TEST_", configFile)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("QuickTyped", func(t *testing.T) {
|
||||||
|
target := &QuickConfig{
|
||||||
|
Host: "typedhost",
|
||||||
|
Port: 6666,
|
||||||
|
SSL: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := QuickTyped(target, "TYPED_", configFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Should populate from file
|
||||||
|
updated, err := cfg.AsStruct()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
typedCfg := updated.(*QuickConfig)
|
||||||
|
assert.Equal(t, "quickhost", typedCfg.Host)
|
||||||
|
assert.Equal(t, 7777, typedCfg.Port)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFlagGeneration tests flag generation and binding
|
||||||
|
func TestFlagGeneration(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("server.host", "localhost")
|
||||||
|
cfg.Register("server.port", 8080)
|
||||||
|
cfg.Register("debug.enabled", false)
|
||||||
|
cfg.Register("timeout", 30.5)
|
||||||
|
cfg.Register("name", "app")
|
||||||
|
cfg.Register("complex", map[string]any{"key": "value"})
|
||||||
|
|
||||||
|
t.Run("GenerateFlags", func(t *testing.T) {
|
||||||
|
fs := cfg.GenerateFlags()
|
||||||
|
require.NotNil(t, fs)
|
||||||
|
|
||||||
|
// Verify flags exist
|
||||||
|
hostFlag := fs.Lookup("server.host")
|
||||||
|
require.NotNil(t, hostFlag)
|
||||||
|
assert.Equal(t, "localhost", hostFlag.DefValue)
|
||||||
|
|
||||||
|
portFlag := fs.Lookup("server.port")
|
||||||
|
require.NotNil(t, portFlag)
|
||||||
|
assert.Equal(t, "8080", portFlag.DefValue)
|
||||||
|
|
||||||
|
debugFlag := fs.Lookup("debug.enabled")
|
||||||
|
require.NotNil(t, debugFlag)
|
||||||
|
assert.Equal(t, "false", debugFlag.DefValue)
|
||||||
|
|
||||||
|
timeoutFlag := fs.Lookup("timeout")
|
||||||
|
require.NotNil(t, timeoutFlag)
|
||||||
|
assert.Equal(t, "30.5", timeoutFlag.DefValue)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("BindFlags", func(t *testing.T) {
|
||||||
|
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||||
|
fs.String("server.host", "default", "")
|
||||||
|
fs.Int("server.port", 8080, "")
|
||||||
|
fs.Bool("debug.enabled", false, "")
|
||||||
|
|
||||||
|
// Parse with test values
|
||||||
|
err := fs.Parse([]string{"-server.host=flaghost", "-server.port=5555", "-debug.enabled"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Bind to config
|
||||||
|
err = cfg.BindFlags(fs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify values were set
|
||||||
|
host, _ := cfg.Get("server.host")
|
||||||
|
assert.Equal(t, "flaghost", host)
|
||||||
|
|
||||||
|
port, _ := cfg.Get("server.port")
|
||||||
|
assert.Equal(t, "5555", port)
|
||||||
|
|
||||||
|
debug, _ := cfg.Get("debug.enabled")
|
||||||
|
assert.Equal(t, "true", debug)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("BindFlagsError", func(t *testing.T) {
|
||||||
|
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||||
|
fs.String("unregistered.path", "value", "")
|
||||||
|
fs.Parse([]string{"-unregistered.path=test"})
|
||||||
|
|
||||||
|
err := cfg.BindFlags(fs)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to bind 1 flags")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidation tests configuration validation
|
||||||
|
func TestValidation(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("required.host", "")
|
||||||
|
cfg.Register("required.port", 0)
|
||||||
|
cfg.Register("optional.timeout", 30)
|
||||||
|
|
||||||
|
t.Run("ValidationFails", func(t *testing.T) {
|
||||||
|
err := cfg.Validate("required.host", "required.port")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "missing required configuration")
|
||||||
|
assert.Contains(t, err.Error(), "required.host")
|
||||||
|
assert.Contains(t, err.Error(), "required.port")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ValidationPasses", func(t *testing.T) {
|
||||||
|
cfg.Set("required.host", "localhost")
|
||||||
|
cfg.Set("required.port", 8080)
|
||||||
|
|
||||||
|
err := cfg.Validate("required.host", "required.port")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ValidationUnregisteredPath", func(t *testing.T) {
|
||||||
|
err := cfg.Validate("nonexistent.path")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "nonexistent.path (not registered)")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ValidationWithSourceValue", func(t *testing.T) {
|
||||||
|
cfg2 := New()
|
||||||
|
cfg2.Register("test", "default")
|
||||||
|
|
||||||
|
// Value equals default but from different source
|
||||||
|
cfg2.SetSource("test", SourceEnv, "default")
|
||||||
|
|
||||||
|
err := cfg2.Validate("test")
|
||||||
|
assert.NoError(t, err) // Should pass because env provided value
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDebugAndDump tests debug output functions
|
||||||
|
func TestDebugAndDump(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("server.host", "localhost")
|
||||||
|
cfg.Register("server.port", 8080)
|
||||||
|
|
||||||
|
cfg.SetSource("server.host", SourceFile, "filehost")
|
||||||
|
cfg.SetSource("server.host", SourceEnv, "envhost")
|
||||||
|
cfg.SetSource("server.port", SourceCLI, "9999")
|
||||||
|
|
||||||
|
t.Run("Debug", func(t *testing.T) {
|
||||||
|
debug := cfg.Debug()
|
||||||
|
|
||||||
|
assert.Contains(t, debug, "Configuration Debug Info")
|
||||||
|
assert.Contains(t, debug, "Precedence:")
|
||||||
|
assert.Contains(t, debug, "server.host:")
|
||||||
|
assert.Contains(t, debug, "Current: envhost")
|
||||||
|
assert.Contains(t, debug, "Default: localhost")
|
||||||
|
assert.Contains(t, debug, "file: filehost")
|
||||||
|
assert.Contains(t, debug, "env: envhost")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Dump", func(t *testing.T) {
|
||||||
|
// Capture stdout
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
r, w, _ := os.Pipe()
|
||||||
|
os.Stdout = w
|
||||||
|
|
||||||
|
err := cfg.Dump()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
w.Close()
|
||||||
|
os.Stdout = oldStdout
|
||||||
|
|
||||||
|
// Read output
|
||||||
|
output := make([]byte, 1024)
|
||||||
|
n, _ := r.Read(output)
|
||||||
|
outputStr := string(output[:n])
|
||||||
|
|
||||||
|
assert.Contains(t, outputStr, "[server]")
|
||||||
|
assert.Contains(t, outputStr, "host = ")
|
||||||
|
assert.Contains(t, outputStr, "port = ")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClone tests configuration cloning
|
||||||
|
func TestClone(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("original.value", "default")
|
||||||
|
cfg.Register("shared.value", "shared")
|
||||||
|
|
||||||
|
cfg.SetSource("original.value", SourceFile, "filevalue")
|
||||||
|
cfg.SetSource("shared.value", SourceEnv, "envvalue")
|
||||||
|
|
||||||
|
clone := cfg.Clone()
|
||||||
|
require.NotNil(t, clone)
|
||||||
|
|
||||||
|
// Verify values are copied
|
||||||
|
val, exists := clone.Get("original.value")
|
||||||
|
assert.True(t, exists)
|
||||||
|
assert.Equal(t, "filevalue", val)
|
||||||
|
|
||||||
|
val, exists = clone.Get("shared.value")
|
||||||
|
assert.True(t, exists)
|
||||||
|
assert.Equal(t, "envvalue", val)
|
||||||
|
|
||||||
|
// Modify clone should not affect original
|
||||||
|
clone.Set("original.value", "clonevalue")
|
||||||
|
|
||||||
|
originalVal, _ := cfg.Get("original.value")
|
||||||
|
cloneVal, _ := clone.Get("original.value")
|
||||||
|
|
||||||
|
assert.Equal(t, "filevalue", originalVal)
|
||||||
|
assert.Equal(t, "clonevalue", cloneVal)
|
||||||
|
|
||||||
|
// Verify source data is copied
|
||||||
|
sources := clone.GetSources("shared.value")
|
||||||
|
assert.Equal(t, "envvalue", sources[SourceEnv])
|
||||||
|
}
|
||||||
46
decode.go
46
decode.go
@ -57,7 +57,7 @@ func (c *Config) unmarshal(basePath string, source Source, target any) error {
|
|||||||
// Create decoder with comprehensive hooks
|
// Create decoder with comprehensive hooks
|
||||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||||
Result: target,
|
Result: target,
|
||||||
TagName: "toml",
|
TagName: c.tagName,
|
||||||
WeaklyTypedInput: true,
|
WeaklyTypedInput: true,
|
||||||
DecodeHook: c.getDecodeHook(),
|
DecodeHook: c.getDecodeHook(),
|
||||||
ZeroFields: true,
|
ZeroFields: true,
|
||||||
@ -77,16 +77,16 @@ func (c *Config) unmarshal(basePath string, source Source, target any) error {
|
|||||||
// getDecodeHook returns the composite decode hook for all type conversions
|
// getDecodeHook returns the composite decode hook for all type conversions
|
||||||
func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
|
func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
|
||||||
return mapstructure.ComposeDecodeHookFunc(
|
return mapstructure.ComposeDecodeHookFunc(
|
||||||
// Standard hooks
|
|
||||||
mapstructure.StringToTimeDurationHookFunc(),
|
|
||||||
mapstructure.StringToTimeHookFunc(time.RFC3339),
|
|
||||||
mapstructure.StringToSliceHookFunc(","),
|
|
||||||
|
|
||||||
// Network types
|
// Network types
|
||||||
stringToNetIPHookFunc(),
|
stringToNetIPHookFunc(),
|
||||||
stringToNetIPNetHookFunc(),
|
stringToNetIPNetHookFunc(),
|
||||||
stringToURLHookFunc(),
|
stringToURLHookFunc(),
|
||||||
|
|
||||||
|
// Standard hooks
|
||||||
|
mapstructure.StringToTimeDurationHookFunc(),
|
||||||
|
mapstructure.StringToTimeHookFunc(time.RFC3339),
|
||||||
|
mapstructure.StringToSliceHookFunc(","),
|
||||||
|
|
||||||
// Custom application hooks
|
// Custom application hooks
|
||||||
c.customDecodeHook(),
|
c.customDecodeHook(),
|
||||||
)
|
)
|
||||||
@ -94,7 +94,7 @@ func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
|
|||||||
|
|
||||||
// stringToNetIPHookFunc handles net.IP conversion
|
// stringToNetIPHookFunc handles net.IP conversion
|
||||||
func stringToNetIPHookFunc() mapstructure.DecodeHookFunc {
|
func stringToNetIPHookFunc() mapstructure.DecodeHookFunc {
|
||||||
return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
|
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
|
||||||
if f.Kind() != reflect.String {
|
if f.Kind() != reflect.String {
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
@ -120,27 +120,28 @@ func stringToNetIPHookFunc() mapstructure.DecodeHookFunc {
|
|||||||
|
|
||||||
// stringToNetIPNetHookFunc handles net.IPNet conversion
|
// stringToNetIPNetHookFunc handles net.IPNet conversion
|
||||||
func stringToNetIPNetHookFunc() mapstructure.DecodeHookFunc {
|
func stringToNetIPNetHookFunc() mapstructure.DecodeHookFunc {
|
||||||
return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
|
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
|
||||||
if f.Kind() != reflect.String {
|
if f.Kind() != reflect.String {
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
isPtr := t.Kind() == reflect.Ptr
|
||||||
if t != reflect.TypeOf(net.IPNet{}) && t != reflect.TypeOf(&net.IPNet{}) {
|
targetType := t
|
||||||
|
if isPtr {
|
||||||
|
targetType = t.Elem()
|
||||||
|
}
|
||||||
|
if targetType != reflect.TypeOf(net.IPNet{}) {
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
str := data.(string)
|
str := data.(string)
|
||||||
// SECURITY: Validate CIDR format
|
|
||||||
if len(str) > 49 { // Max IPv6 CIDR length
|
if len(str) > 49 { // Max IPv6 CIDR length
|
||||||
return nil, fmt.Errorf("invalid CIDR length: %d", len(str))
|
return nil, fmt.Errorf("invalid CIDR length: %d", len(str))
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ipnet, err := net.ParseCIDR(str)
|
_, ipnet, err := net.ParseCIDR(str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid CIDR: %w", err)
|
return nil, fmt.Errorf("invalid CIDR: %w", err)
|
||||||
}
|
}
|
||||||
|
if isPtr {
|
||||||
if t == reflect.TypeOf(&net.IPNet{}) {
|
|
||||||
return ipnet, nil
|
return ipnet, nil
|
||||||
}
|
}
|
||||||
return *ipnet, nil
|
return *ipnet, nil
|
||||||
@ -149,27 +150,28 @@ func stringToNetIPNetHookFunc() mapstructure.DecodeHookFunc {
|
|||||||
|
|
||||||
// stringToURLHookFunc handles url.URL conversion
|
// stringToURLHookFunc handles url.URL conversion
|
||||||
func stringToURLHookFunc() mapstructure.DecodeHookFunc {
|
func stringToURLHookFunc() mapstructure.DecodeHookFunc {
|
||||||
return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
|
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
|
||||||
if f.Kind() != reflect.String {
|
if f.Kind() != reflect.String {
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
isPtr := t.Kind() == reflect.Ptr
|
||||||
if t != reflect.TypeOf(url.URL{}) && t != reflect.TypeOf(&url.URL{}) {
|
targetType := t
|
||||||
|
if isPtr {
|
||||||
|
targetType = t.Elem()
|
||||||
|
}
|
||||||
|
if targetType != reflect.TypeOf(url.URL{}) {
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
str := data.(string)
|
str := data.(string)
|
||||||
// SECURITY: Validate URL length to prevent DoS
|
|
||||||
if len(str) > 2048 {
|
if len(str) > 2048 {
|
||||||
return nil, fmt.Errorf("URL too long: %d bytes", len(str))
|
return nil, fmt.Errorf("URL too long: %d bytes", len(str))
|
||||||
}
|
}
|
||||||
|
|
||||||
u, err := url.Parse(str)
|
u, err := url.Parse(str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||||
}
|
}
|
||||||
|
if isPtr {
|
||||||
if t == reflect.TypeOf(&url.URL{}) {
|
|
||||||
return u, nil
|
return u, nil
|
||||||
}
|
}
|
||||||
return *u, nil
|
return *u, nil
|
||||||
@ -178,7 +180,7 @@ func stringToURLHookFunc() mapstructure.DecodeHookFunc {
|
|||||||
|
|
||||||
// customDecodeHook allows for application-specific type conversions
|
// customDecodeHook allows for application-specific type conversions
|
||||||
func (c *Config) customDecodeHook() mapstructure.DecodeHookFunc {
|
func (c *Config) customDecodeHook() mapstructure.DecodeHookFunc {
|
||||||
return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
|
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
|
||||||
// SECURITY: Add custom validation for application types here
|
// SECURITY: Add custom validation for application types here
|
||||||
// Example: Rate limit parsing, permission validation, etc.
|
// Example: Rate limit parsing, permission validation, etc.
|
||||||
|
|
||||||
|
|||||||
328
decode_test.go
Normal file
328
decode_test.go
Normal file
@ -0,0 +1,328 @@
|
|||||||
|
// FILE: lixenwraith/config/decode_test.go
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestScanWithComplexTypes tests scanning with various complex types
|
||||||
|
func TestScanWithComplexTypes(t *testing.T) {
|
||||||
|
type NetworkConfig struct {
|
||||||
|
IP net.IP `toml:"ip"`
|
||||||
|
IPNet *net.IPNet `toml:"subnet"`
|
||||||
|
URL *url.URL `toml:"endpoint"`
|
||||||
|
Timeout time.Duration `toml:"timeout"`
|
||||||
|
Retry struct {
|
||||||
|
Count int `toml:"count"`
|
||||||
|
Interval time.Duration `toml:"interval"`
|
||||||
|
} `toml:"retry"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AppConfig struct {
|
||||||
|
Network NetworkConfig `toml:"network"`
|
||||||
|
Tags []string `toml:"tags"`
|
||||||
|
Ports []int `toml:"ports"`
|
||||||
|
Labels map[string]string `toml:"labels"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
|
||||||
|
// Register with defaults
|
||||||
|
defaults := &AppConfig{
|
||||||
|
Network: NetworkConfig{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
Tags: []string{"default"},
|
||||||
|
Ports: []int{8080},
|
||||||
|
Labels: map[string]string{
|
||||||
|
"env": "dev",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.RegisterStruct("", defaults)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Set values from different sources
|
||||||
|
cfg.SetSource("network.ip", SourceEnv, "192.168.1.100")
|
||||||
|
cfg.SetSource("network.subnet", SourceEnv, "192.168.1.0/24")
|
||||||
|
cfg.SetSource("network.endpoint", SourceEnv, "https://api.example.com:8443/v1")
|
||||||
|
cfg.SetSource("network.timeout", SourceFile, "2m30s")
|
||||||
|
cfg.SetSource("network.retry.count", SourceFile, int64(5))
|
||||||
|
cfg.SetSource("network.retry.interval", SourceFile, "10s")
|
||||||
|
cfg.SetSource("tags", SourceCLI, "prod,staging,test")
|
||||||
|
cfg.SetSource("ports", SourceFile, []any{int64(80), int64(443), int64(8080)})
|
||||||
|
cfg.SetSource("labels", SourceFile, map[string]any{
|
||||||
|
"env": "production",
|
||||||
|
"version": "1.2.3",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Scan into struct
|
||||||
|
var result AppConfig
|
||||||
|
err = cfg.Scan("", &result)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify conversions
|
||||||
|
assert.Equal(t, "192.168.1.100", result.Network.IP.String())
|
||||||
|
assert.Equal(t, "192.168.1.0/24", result.Network.IPNet.String())
|
||||||
|
assert.Equal(t, "https://api.example.com:8443/v1", result.Network.URL.String())
|
||||||
|
assert.Equal(t, 150*time.Second, result.Network.Timeout)
|
||||||
|
assert.Equal(t, 5, result.Network.Retry.Count)
|
||||||
|
assert.Equal(t, 10*time.Second, result.Network.Retry.Interval)
|
||||||
|
assert.Equal(t, []string{"prod", "staging", "test"}, result.Tags)
|
||||||
|
assert.Equal(t, []int{80, 443, 8080}, result.Ports)
|
||||||
|
assert.Equal(t, "production", result.Labels["env"])
|
||||||
|
assert.Equal(t, "1.2.3", result.Labels["version"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestScanWithBasePath tests scanning from nested paths
|
||||||
|
func TestScanWithBasePath(t *testing.T) {
|
||||||
|
type ServerConfig struct {
|
||||||
|
Host string `toml:"host"`
|
||||||
|
Port int `toml:"port"`
|
||||||
|
Enabled bool `toml:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("app.server.host", "localhost")
|
||||||
|
cfg.Register("app.server.port", 8080)
|
||||||
|
cfg.Register("app.server.enabled", true)
|
||||||
|
cfg.Register("app.database.host", "dbhost")
|
||||||
|
|
||||||
|
cfg.Set("app.server.host", "appserver")
|
||||||
|
cfg.Set("app.server.port", 9000)
|
||||||
|
|
||||||
|
// Scan only the server section
|
||||||
|
var server ServerConfig
|
||||||
|
err := cfg.Scan("app.server", &server)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "appserver", server.Host)
|
||||||
|
assert.Equal(t, 9000, server.Port)
|
||||||
|
assert.Equal(t, true, server.Enabled)
|
||||||
|
|
||||||
|
// Test non-existent base path
|
||||||
|
var empty ServerConfig
|
||||||
|
err = cfg.Scan("app.nonexistent", &empty)
|
||||||
|
assert.NoError(t, err) // Should not error, just empty
|
||||||
|
assert.Equal(t, "", empty.Host)
|
||||||
|
assert.Equal(t, 0, empty.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestScanFromSource tests scanning from specific sources
|
||||||
|
func TestScanFromSource(t *testing.T) {
|
||||||
|
type Config struct {
|
||||||
|
Value string `toml:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("value", "default")
|
||||||
|
|
||||||
|
cfg.SetSource("value", SourceFile, "fromfile")
|
||||||
|
cfg.SetSource("value", SourceEnv, "fromenv")
|
||||||
|
cfg.SetSource("value", SourceCLI, "fromcli")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
source Source
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{SourceFile, "fromfile"},
|
||||||
|
{SourceEnv, "fromenv"},
|
||||||
|
{SourceCLI, "fromcli"},
|
||||||
|
{SourceDefault, ""}, // No value in default source
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(string(tt.source), func(t *testing.T) {
|
||||||
|
var result Config
|
||||||
|
err := cfg.ScanSource("", tt.source, &result)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, result.Value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInvalidScanTargets tests error cases for scanning
|
||||||
|
func TestInvalidScanTargets(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("test", "value")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
target any
|
||||||
|
expectErr string
|
||||||
|
}{
|
||||||
|
{"NilPointer", nil, "must be non-nil pointer"},
|
||||||
|
{"NonPointer", "not-a-pointer", "must be non-nil pointer"},
|
||||||
|
{"NilStructPointer", (*struct{})(nil), "must be non-nil pointer"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := cfg.Scan("", tt.target)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.expectErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCustomTypeConversion tests edge cases in type conversion
|
||||||
|
func TestCustomTypeConversion(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
|
||||||
|
t.Run("InvalidIPAddress", func(t *testing.T) {
|
||||||
|
type Config struct {
|
||||||
|
IP net.IP `toml:"ip"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Register("ip", net.IP{})
|
||||||
|
cfg.Set("ip", "not-an-ip")
|
||||||
|
|
||||||
|
var result Config
|
||||||
|
err := cfg.Scan("", &result)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid IP address")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("InvalidCIDR", func(t *testing.T) {
|
||||||
|
type Config struct {
|
||||||
|
Network *net.IPNet `toml:"network"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Register("network", (*net.IPNet)(nil))
|
||||||
|
cfg.Set("network", "invalid-cidr")
|
||||||
|
|
||||||
|
var result Config
|
||||||
|
err := cfg.Scan("", &result)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid CIDR")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("InvalidURL", func(t *testing.T) {
|
||||||
|
type Config struct {
|
||||||
|
Endpoint *url.URL `toml:"endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Register("endpoint", (*url.URL)(nil))
|
||||||
|
cfg.Set("endpoint", "://invalid-url")
|
||||||
|
|
||||||
|
var result Config
|
||||||
|
err := cfg.Scan("", &result)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid URL")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LongIPString", func(t *testing.T) {
|
||||||
|
type Config struct {
|
||||||
|
IP net.IP `toml:"ip"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Register("ip", net.IP{})
|
||||||
|
// String longer than max IPv6 length
|
||||||
|
longIP := make([]byte, 50)
|
||||||
|
for i := range longIP {
|
||||||
|
longIP[i] = 'x'
|
||||||
|
}
|
||||||
|
cfg.Set("ip", string(longIP))
|
||||||
|
|
||||||
|
var result Config
|
||||||
|
err := cfg.Scan("", &result)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid IP length")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LongURL", func(t *testing.T) {
|
||||||
|
type Config struct {
|
||||||
|
URL *url.URL `toml:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Register("url", (*url.URL)(nil))
|
||||||
|
// URL longer than 2048 bytes
|
||||||
|
longURL := "https://example.com/"
|
||||||
|
for i := 0; i < 2048; i++ {
|
||||||
|
longURL += "x"
|
||||||
|
}
|
||||||
|
cfg.Set("url", longURL)
|
||||||
|
|
||||||
|
var result Config
|
||||||
|
err := cfg.Scan("", &result)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "URL too long")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestZeroFields tests that ZeroFields option works correctly
|
||||||
|
func TestZeroFields(t *testing.T) {
|
||||||
|
type Config struct {
|
||||||
|
KeepValue string `toml:"keep"`
|
||||||
|
ResetValue string `toml:"reset"`
|
||||||
|
NestedValue struct {
|
||||||
|
Field string `toml:"field"`
|
||||||
|
} `toml:"nested"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
|
||||||
|
// Register only some fields
|
||||||
|
cfg.Register("keep", "keepdefault")
|
||||||
|
cfg.Register("reset", "resetdefault")
|
||||||
|
// Don't register nested.field
|
||||||
|
|
||||||
|
cfg.Set("keep", "newvalue")
|
||||||
|
// Don't set reset, so it uses default
|
||||||
|
|
||||||
|
// Start with non-zero struct
|
||||||
|
result := Config{
|
||||||
|
KeepValue: "initial",
|
||||||
|
ResetValue: "initial",
|
||||||
|
NestedValue: struct {
|
||||||
|
Field string `toml:"field"`
|
||||||
|
}{Field: "initial"},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.Scan("", &result)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// ZeroFields should reset all fields before decoding
|
||||||
|
assert.Equal(t, "newvalue", result.KeepValue)
|
||||||
|
assert.Equal(t, "resetdefault", result.ResetValue)
|
||||||
|
assert.Equal(t, "initial", result.NestedValue.Field) // Unregistered, so Scan should not touch it
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWeaklyTypedInput tests weak type conversion
|
||||||
|
func TestWeaklyTypedInput(t *testing.T) {
|
||||||
|
type Config struct {
|
||||||
|
IntFromString int `toml:"int_from_string"`
|
||||||
|
FloatFromString float64 `toml:"float_from_string"`
|
||||||
|
BoolFromString bool `toml:"bool_from_string"`
|
||||||
|
StringFromInt string `toml:"string_from_int"`
|
||||||
|
StringFromBool string `toml:"string_from_bool"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
defaults := &Config{}
|
||||||
|
cfg.RegisterStruct("", defaults)
|
||||||
|
|
||||||
|
// Set string values that should convert
|
||||||
|
cfg.Set("int_from_string", "42")
|
||||||
|
cfg.Set("float_from_string", "3.14159")
|
||||||
|
cfg.Set("bool_from_string", "true")
|
||||||
|
cfg.Set("string_from_int", 12345)
|
||||||
|
cfg.Set("string_from_bool", true)
|
||||||
|
|
||||||
|
var result Config
|
||||||
|
err := cfg.Scan("", &result)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 42, result.IntFromString)
|
||||||
|
assert.Equal(t, 3.14159, result.FloatFromString)
|
||||||
|
assert.Equal(t, true, result.BoolFromString)
|
||||||
|
assert.Equal(t, "12345", result.StringFromInt)
|
||||||
|
assert.Equal(t, "1", result.StringFromBool) // mapstructure converts bool(true) to "1" in weak conversion
|
||||||
|
}
|
||||||
@ -138,7 +138,7 @@ cfg.SetSource("feature.enabled", config.SourceFile, true)
|
|||||||
|
|
||||||
```go
|
```go
|
||||||
// Multiple updates
|
// Multiple updates
|
||||||
updates := map[string]interface{}{
|
updates := map[string]any{
|
||||||
"server.port": int64(9090),
|
"server.port": int64(9090),
|
||||||
"server.host": "0.0.0.0",
|
"server.host": "0.0.0.0",
|
||||||
"database.maxconns": int64(50),
|
"database.maxconns": int64(50),
|
||||||
@ -310,7 +310,7 @@ func (f *ConfigFacade) DatabaseURL() string {
|
|||||||
|
|
||||||
```go
|
```go
|
||||||
// Helper for optional configuration
|
// Helper for optional configuration
|
||||||
func getOrDefault(cfg *config.Config, path string, defaultVal interface{}) interface{} {
|
func getOrDefault(cfg *config.Config, path string, defaultVal any) any {
|
||||||
if val, exists := cfg.Get(path); exists {
|
if val, exists := cfg.Get(path); exists {
|
||||||
return val
|
return val
|
||||||
}
|
}
|
||||||
|
|||||||
5
go.mod
5
go.mod
@ -1,9 +1,10 @@
|
|||||||
module github.com/lixenwraith/config
|
module config
|
||||||
|
|
||||||
go 1.24.5
|
go 1.24.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/BurntSushi/toml v1.5.0
|
github.com/BurntSushi/toml v1.5.0
|
||||||
|
github.com/lixenwraith/config v0.0.0-20250719015120-e02ee494d440
|
||||||
github.com/mitchellh/mapstructure v1.5.0
|
github.com/mitchellh/mapstructure v1.5.0
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
)
|
)
|
||||||
@ -13,3 +14,5 @@ require (
|
|||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
|
replace github.com/mitchellh/mapstructure => github.com/go-viper/mapstructure v1.6.0
|
||||||
|
|||||||
6
go.sum
6
go.sum
@ -2,8 +2,10 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg
|
|||||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGLxofobw=
|
||||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw=
|
||||||
|
github.com/lixenwraith/config v0.0.0-20250719015120-e02ee494d440 h1:O6nHnpeDfIYQ1WxCtA2gkm8upQ4RW21DUMlQz5DKJCU=
|
||||||
|
github.com/lixenwraith/config v0.0.0-20250719015120-e02ee494d440/go.mod h1:y7kgDrWIFROWJJ6ASM/SPTRRAj27FjRGWh2SDLcdQ68=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
// File: lixenwraith/config/helper.go
|
// FILE: lixenwraith/config/helper.go
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import "strings"
|
import "strings"
|
||||||
|
|||||||
24
loader.go
24
loader.go
@ -1,4 +1,4 @@
|
|||||||
// File: lixenwraith/config/loader.go
|
// FILE: lixenwraith/config/loader.go
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -481,7 +481,6 @@ func parseArgs(args []string) (map[string]any, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the leading "--"
|
|
||||||
argContent := strings.TrimPrefix(arg, "--")
|
argContent := strings.TrimPrefix(arg, "--")
|
||||||
if argContent == "" {
|
if argContent == "" {
|
||||||
// Skip "--" argument if used as a separator
|
// Skip "--" argument if used as a separator
|
||||||
@ -501,20 +500,22 @@ func parseArgs(args []string) (map[string]any, error) {
|
|||||||
} else {
|
} else {
|
||||||
// Handle "--key value" or "--booleanflag"
|
// Handle "--key value" or "--booleanflag"
|
||||||
keyPath = argContent
|
keyPath = argContent
|
||||||
// Check if it's potentially a boolean flag
|
// Check if it's a boolean flag (next arg is another flag or end of args)
|
||||||
isBoolFlag := i+1 >= len(args) || strings.HasPrefix(args[i+1], "--")
|
if i+1 >= len(args) || strings.HasPrefix(args[i+1], "--") {
|
||||||
|
|
||||||
if isBoolFlag {
|
|
||||||
// Assume boolean flag is true if no value follows
|
|
||||||
valueStr = "true"
|
valueStr = "true"
|
||||||
i++ // Consume only the flag argument
|
i++ // Consume only the flag argument
|
||||||
} else {
|
} else {
|
||||||
// Potential key-value pair with space separation
|
// It's a key-value pair with a space
|
||||||
valueStr = args[i+1]
|
valueStr = args[i+1]
|
||||||
i += 2 // Consume flag and value arguments
|
i += 2 // Consume both flag and value arguments
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if keyPath == "" {
|
||||||
|
// Skip invalid flags like --=value
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Validate keyPath segments
|
// Validate keyPath segments
|
||||||
segments := strings.Split(keyPath, ".")
|
segments := strings.Split(keyPath, ".")
|
||||||
for _, segment := range segments {
|
for _, segment := range segments {
|
||||||
@ -523,9 +524,8 @@ func parseArgs(args []string) (map[string]any, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the value
|
// Always store as a string. Let Scan handle final type conversion.
|
||||||
value := parseValue(valueStr)
|
setNestedValue(result, keyPath, valueStr)
|
||||||
setNestedValue(result, keyPath, value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
|
|||||||
434
loader_test.go
Normal file
434
loader_test.go
Normal file
@ -0,0 +1,434 @@
|
|||||||
|
// FILE: lixenwraith/config/loader_test.go
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestFileLoading tests TOML file loading
|
||||||
|
func TestFileLoading(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
t.Run("ValidTOMLFile", func(t *testing.T) {
|
||||||
|
configFile := filepath.Join(tmpDir, "valid.toml")
|
||||||
|
content := `
|
||||||
|
# Server configuration
|
||||||
|
[server]
|
||||||
|
host = "example.com"
|
||||||
|
port = 9000
|
||||||
|
enabled = true
|
||||||
|
|
||||||
|
[server.tls]
|
||||||
|
cert = "/path/to/cert.pem"
|
||||||
|
key = "/path/to/key.pem"
|
||||||
|
|
||||||
|
[database]
|
||||||
|
connections = [1, 2, 3]
|
||||||
|
tags = ["primary", "replica"]
|
||||||
|
`
|
||||||
|
os.WriteFile(configFile, []byte(content), 0644)
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
// Register all paths
|
||||||
|
cfg.Register("server.host", "localhost")
|
||||||
|
cfg.Register("server.port", 8080)
|
||||||
|
cfg.Register("server.enabled", false)
|
||||||
|
cfg.Register("server.tls.cert", "")
|
||||||
|
cfg.Register("server.tls.key", "")
|
||||||
|
cfg.Register("database.connections", []int{})
|
||||||
|
cfg.Register("database.tags", []string{})
|
||||||
|
|
||||||
|
err := cfg.LoadFile(configFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify loaded values
|
||||||
|
host, _ := cfg.Get("server.host")
|
||||||
|
assert.Equal(t, "example.com", host)
|
||||||
|
|
||||||
|
port, _ := cfg.Get("server.port")
|
||||||
|
assert.Equal(t, int64(9000), port)
|
||||||
|
|
||||||
|
enabled, _ := cfg.Get("server.enabled")
|
||||||
|
assert.Equal(t, true, enabled)
|
||||||
|
|
||||||
|
cert, _ := cfg.Get("server.tls.cert")
|
||||||
|
assert.Equal(t, "/path/to/cert.pem", cert)
|
||||||
|
|
||||||
|
// Arrays are loaded as []any
|
||||||
|
connections, _ := cfg.Get("database.connections")
|
||||||
|
assert.Equal(t, []any{int64(1), int64(2), int64(3)}, connections)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("InvalidTOMLFile", func(t *testing.T) {
|
||||||
|
configFile := filepath.Join(tmpDir, "invalid.toml")
|
||||||
|
os.WriteFile(configFile, []byte(`invalid = toml content`), 0644)
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
err := cfg.LoadFile(configFile)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to parse TOML")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NonExistentFile", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
err := cfg.LoadFile("/non/existent/file.toml")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrConfigNotFound)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UnregisteredPathsIgnored", func(t *testing.T) {
|
||||||
|
configFile := filepath.Join(tmpDir, "extra.toml")
|
||||||
|
os.WriteFile(configFile, []byte(`
|
||||||
|
registered = "value"
|
||||||
|
unregistered = "ignored"
|
||||||
|
`), 0644)
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("registered", "")
|
||||||
|
|
||||||
|
err := cfg.LoadFile(configFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
val, exists := cfg.Get("registered")
|
||||||
|
assert.True(t, exists)
|
||||||
|
assert.Equal(t, "value", val)
|
||||||
|
|
||||||
|
_, exists = cfg.Get("unregistered")
|
||||||
|
assert.False(t, exists)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEnvironmentLoading tests environment variable loading
|
||||||
|
func TestEnvironmentLoading(t *testing.T) {
|
||||||
|
// Save and restore environment
|
||||||
|
originalEnv := os.Environ()
|
||||||
|
defer func() {
|
||||||
|
os.Clearenv()
|
||||||
|
for _, e := range originalEnv {
|
||||||
|
parts := splitEnvVar(e)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
os.Setenv(parts[0], parts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Run("DefaultEnvTransform", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("server.host", "localhost")
|
||||||
|
cfg.Register("server.port", 8080)
|
||||||
|
cfg.Register("enable_debug", false)
|
||||||
|
|
||||||
|
os.Setenv("APP_SERVER_HOST", "envhost")
|
||||||
|
os.Setenv("APP_SERVER_PORT", "9090")
|
||||||
|
os.Setenv("APP_ENABLE_DEBUG", "true")
|
||||||
|
|
||||||
|
err := cfg.LoadEnv("APP_")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
host, _ := cfg.Get("server.host")
|
||||||
|
assert.Equal(t, "envhost", host)
|
||||||
|
|
||||||
|
port, _ := cfg.Get("server.port")
|
||||||
|
assert.Equal(t, "9090", port) // String from env
|
||||||
|
|
||||||
|
debug, _ := cfg.Get("enable_debug")
|
||||||
|
assert.Equal(t, "true", debug) // String from env
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CustomEnvTransform", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("db.host", "localhost")
|
||||||
|
|
||||||
|
os.Setenv("DATABASE_HOSTNAME", "customhost")
|
||||||
|
|
||||||
|
opts := LoadOptions{
|
||||||
|
Sources: []Source{SourceEnv, SourceDefault},
|
||||||
|
EnvTransform: func(path string) string {
|
||||||
|
if path == "db.host" {
|
||||||
|
return "DATABASE_HOSTNAME"
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.LoadWithOptions("", nil, opts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
host, _ := cfg.Get("db.host")
|
||||||
|
assert.Equal(t, "customhost", host)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EnvWhitelist", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("allowed.path", "default1")
|
||||||
|
cfg.Register("blocked.path", "default2")
|
||||||
|
|
||||||
|
os.Setenv("ALLOWED_PATH", "env1")
|
||||||
|
os.Setenv("BLOCKED_PATH", "env2")
|
||||||
|
|
||||||
|
opts := LoadOptions{
|
||||||
|
Sources: []Source{SourceEnv, SourceDefault},
|
||||||
|
EnvWhitelist: map[string]bool{"allowed.path": true},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.LoadWithOptions("", nil, opts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
allowed, _ := cfg.Get("allowed.path")
|
||||||
|
assert.Equal(t, "env1", allowed)
|
||||||
|
|
||||||
|
blocked, _ := cfg.Get("blocked.path")
|
||||||
|
assert.Equal(t, "default2", blocked) // Should not load from env
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DiscoverEnv", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("test.one", "")
|
||||||
|
cfg.Register("test.two", "")
|
||||||
|
cfg.Register("other.value", "")
|
||||||
|
|
||||||
|
os.Setenv("PREFIX_TEST_ONE", "value1")
|
||||||
|
os.Setenv("PREFIX_TEST_TWO", "value2")
|
||||||
|
os.Setenv("PREFIX_OTHER_VALUE", "value3")
|
||||||
|
os.Setenv("UNRELATED_VAR", "ignored")
|
||||||
|
|
||||||
|
discovered := cfg.DiscoverEnv("PREFIX_")
|
||||||
|
assert.Len(t, discovered, 3)
|
||||||
|
assert.Equal(t, "PREFIX_TEST_ONE", discovered["test.one"])
|
||||||
|
assert.Equal(t, "PREFIX_TEST_TWO", discovered["test.two"])
|
||||||
|
assert.Equal(t, "PREFIX_OTHER_VALUE", discovered["other.value"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCLIParsing tests command-line argument parsing
|
||||||
|
func TestCLIParsing(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expected map[string]any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "KeyValueWithEquals",
|
||||||
|
args: []string{"--server.host=example.com", "--server.port=9000"},
|
||||||
|
expected: map[string]any{
|
||||||
|
"server.host": "example.com",
|
||||||
|
"server.port": "9000",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "KeyValueWithSpace",
|
||||||
|
args: []string{"--server.host", "example.com", "--server.port", "9000"},
|
||||||
|
expected: map[string]any{
|
||||||
|
"server.host": "example.com",
|
||||||
|
"server.port": "9000",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "BooleanFlags",
|
||||||
|
args: []string{"--enable.debug", "--disable.cache", "false"},
|
||||||
|
expected: map[string]any{
|
||||||
|
"enable.debug": "true",
|
||||||
|
"disable.cache": "false",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MixedFormats",
|
||||||
|
args: []string{
|
||||||
|
"--server.host=localhost",
|
||||||
|
"--server.port", "8080",
|
||||||
|
"--enable.tls",
|
||||||
|
"--database.pool.size=10",
|
||||||
|
},
|
||||||
|
expected: map[string]any{
|
||||||
|
"server.host": "localhost",
|
||||||
|
"server.port": "8080",
|
||||||
|
"enable.tls": "true",
|
||||||
|
"database.pool.size": "10",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "EmptyAndInvalidArgs",
|
||||||
|
args: []string{"", "--", "---", "--=value"},
|
||||||
|
expected: map[string]any{
|
||||||
|
"": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
|
||||||
|
// Register expected paths
|
||||||
|
for path := range tt.expected {
|
||||||
|
if path != "" { // Skip empty path
|
||||||
|
cfg.Register(path, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.LoadCLI(tt.args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify values
|
||||||
|
for path, expected := range tt.expected {
|
||||||
|
if path != "" {
|
||||||
|
val, exists := cfg.Get(path)
|
||||||
|
assert.True(t, exists, "Path %s should exist", path)
|
||||||
|
assert.Equal(t, expected, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("InvalidKeySegment", func(t *testing.T) {
|
||||||
|
result, err := parseArgs([]string{"--invalid!key=value"})
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid command-line key segment")
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadWithOptions tests complete loading with multiple sources
|
||||||
|
func TestLoadWithOptions(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configFile := filepath.Join(tmpDir, "config.toml")
|
||||||
|
os.WriteFile(configFile, []byte(`
|
||||||
|
[server]
|
||||||
|
host = "filehost"
|
||||||
|
port = 8080
|
||||||
|
`), 0644)
|
||||||
|
|
||||||
|
os.Setenv("TEST_SERVER_HOST", "envhost")
|
||||||
|
os.Setenv("TEST_SERVER_PORT", "9090")
|
||||||
|
defer func() {
|
||||||
|
os.Unsetenv("TEST_SERVER_HOST")
|
||||||
|
os.Unsetenv("TEST_SERVER_PORT")
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("server.host", "defaulthost")
|
||||||
|
cfg.Register("server.port", 3000)
|
||||||
|
|
||||||
|
args := []string{"--server.port=7070"}
|
||||||
|
|
||||||
|
opts := LoadOptions{
|
||||||
|
Sources: []Source{SourceCLI, SourceEnv, SourceFile, SourceDefault},
|
||||||
|
EnvPrefix: "TEST_",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.LoadWithOptions(configFile, args, opts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// CLI should win
|
||||||
|
port, _ := cfg.Get("server.port")
|
||||||
|
assert.Equal(t, "7070", port)
|
||||||
|
|
||||||
|
// ENV should win over file
|
||||||
|
host, _ := cfg.Get("server.host")
|
||||||
|
assert.Equal(t, "envhost", host)
|
||||||
|
|
||||||
|
// Test source inspection
|
||||||
|
sources := cfg.GetSources("server.port")
|
||||||
|
assert.Equal(t, "7070", sources[SourceCLI])
|
||||||
|
assert.Equal(t, "9090", sources[SourceEnv])
|
||||||
|
assert.Equal(t, int64(8080), sources[SourceFile])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAtomicSave tests atomic file saving
|
||||||
|
func TestAtomicSave(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("server.host", "localhost")
|
||||||
|
cfg.Register("server.port", 8080)
|
||||||
|
cfg.Register("database.url", "postgres://localhost/db")
|
||||||
|
|
||||||
|
// Set some values
|
||||||
|
cfg.Set("server.host", "savehost")
|
||||||
|
cfg.Set("server.port", 9999)
|
||||||
|
|
||||||
|
t.Run("SaveCurrentState", func(t *testing.T) {
|
||||||
|
savePath := filepath.Join(tmpDir, "saved.toml")
|
||||||
|
err := cfg.Save(savePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify file exists and is readable
|
||||||
|
content, err := os.ReadFile(savePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(content), "savehost")
|
||||||
|
assert.Contains(t, string(content), "9999")
|
||||||
|
|
||||||
|
// Load into new config to verify
|
||||||
|
cfg2 := New()
|
||||||
|
cfg2.Register("server.host", "")
|
||||||
|
cfg2.Register("server.port", 0)
|
||||||
|
err = cfg2.LoadFile(savePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
host, _ := cfg2.Get("server.host")
|
||||||
|
assert.Equal(t, "savehost", host)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SaveSpecificSource", func(t *testing.T) {
|
||||||
|
cfg.SetSource("server.host", SourceEnv, "envhost")
|
||||||
|
cfg.SetSource("server.port", SourceEnv, "7777")
|
||||||
|
cfg.SetSource("server.port", SourceFile, "6666")
|
||||||
|
|
||||||
|
savePath := filepath.Join(tmpDir, "env-only.toml")
|
||||||
|
err := cfg.SaveSource(savePath, SourceEnv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
content, err := os.ReadFile(savePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(content), "envhost")
|
||||||
|
assert.Contains(t, string(content), "7777")
|
||||||
|
assert.NotContains(t, string(content), "6666")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SaveToNonExistentDirectory", func(t *testing.T) {
|
||||||
|
savePath := filepath.Join(tmpDir, "new", "dir", "config.toml")
|
||||||
|
err := cfg.Save(savePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify file was created
|
||||||
|
_, err = os.Stat(savePath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExportEnv tests environment variable export
|
||||||
|
func TestExportEnv(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("server.host", "defaulthost")
|
||||||
|
cfg.Register("server.port", 8080)
|
||||||
|
cfg.Register("feature.enabled", false)
|
||||||
|
|
||||||
|
// Only export non-default values
|
||||||
|
cfg.Set("server.host", "exporthost")
|
||||||
|
cfg.Set("feature.enabled", true)
|
||||||
|
|
||||||
|
exports := cfg.ExportEnv("APP_")
|
||||||
|
|
||||||
|
assert.Len(t, exports, 2)
|
||||||
|
assert.Equal(t, "exporthost", exports["APP_SERVER_HOST"])
|
||||||
|
assert.Equal(t, "true", exports["APP_FEATURE_ENABLED"])
|
||||||
|
assert.NotContains(t, exports, "APP_SERVER_PORT") // Still default
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitEnvVar splits environment variable into key and value
|
||||||
|
func splitEnvVar(env string) []string {
|
||||||
|
parts := make([]string, 2)
|
||||||
|
for i := 0; i < len(env); i++ {
|
||||||
|
if env[i] == '=' {
|
||||||
|
parts[0] = env[:i]
|
||||||
|
parts[1] = env[i+1:]
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return []string{env}
|
||||||
|
}
|
||||||
20
register.go
20
register.go
@ -1,4 +1,4 @@
|
|||||||
// File: lixenwraith/config/register.go
|
// FILE: lixenwraith/config/register.go
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -186,21 +186,31 @@ func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, e
|
|||||||
isPtrToStruct := fieldValue.Kind() == reflect.Ptr && fieldType.Elem().Kind() == reflect.Struct
|
isPtrToStruct := fieldValue.Kind() == reflect.Ptr && fieldType.Elem().Kind() == reflect.Struct
|
||||||
|
|
||||||
if isStruct || isPtrToStruct {
|
if isStruct || isPtrToStruct {
|
||||||
// Dereference pointer if necessary
|
// Check if the field's TYPE is one that should be treated as a single value,
|
||||||
|
// even though it's a struct. These types have custom decode hooks.
|
||||||
|
fieldType := fieldValue.Type()
|
||||||
|
isAtomicStruct := false
|
||||||
|
switch fieldType.String() {
|
||||||
|
case "time.Time", "*net.IPNet", "*url.URL", "net.IP": // Match the exact type names
|
||||||
|
isAtomicStruct = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only recurse if it's a "normal" struct, not an atomic one.
|
||||||
|
if !isAtomicStruct {
|
||||||
nestedValue := fieldValue
|
nestedValue := fieldValue
|
||||||
if isPtrToStruct {
|
if isPtrToStruct {
|
||||||
if fieldValue.IsNil() {
|
if fieldValue.IsNil() {
|
||||||
// Skip nil pointers
|
continue // Skip nil pointers in the default struct
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
nestedValue = fieldValue.Elem()
|
nestedValue = fieldValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
// For nested structs, append a dot and continue recursion
|
|
||||||
nestedPrefix := currentPath + "."
|
nestedPrefix := currentPath + "."
|
||||||
c.registerFields(nestedValue, nestedPrefix, fieldPath+field.Name+".", errors, tagName)
|
c.registerFields(nestedValue, nestedPrefix, fieldPath+field.Name+".", errors, tagName)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// If it is an atomic struct, we fall through and register it as a single value.
|
||||||
|
}
|
||||||
|
|
||||||
// Register non-struct fields
|
// Register non-struct fields
|
||||||
defaultValue := fieldValue.Interface()
|
defaultValue := fieldValue.Interface()
|
||||||
|
|||||||
5
watch.go
5
watch.go
@ -1,15 +1,16 @@
|
|||||||
// File: lixenwraith/config/watch.go
|
// FILE: lixenwraith/config/watch.go
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/BurntSushi/toml"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/BurntSushi/toml"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WatchOptions configures file watching behavior
|
// WatchOptions configures file watching behavior
|
||||||
|
|||||||
336
watch_test.go
336
watch_test.go
@ -9,10 +9,14 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TestAutoUpdate tests automatic configuration reloading
|
||||||
func TestAutoUpdate(t *testing.T) {
|
func TestAutoUpdate(t *testing.T) {
|
||||||
// Create temporary config file
|
// Setup
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
configPath := filepath.Join(tmpDir, "test.toml")
|
configPath := filepath.Join(tmpDir, "test.toml")
|
||||||
|
|
||||||
@ -24,10 +28,7 @@ host = "localhost"
|
|||||||
[features]
|
[features]
|
||||||
enabled = true
|
enabled = true
|
||||||
`
|
`
|
||||||
|
require.NoError(t, os.WriteFile(configPath, []byte(initialConfig), 0644))
|
||||||
if err := os.WriteFile(configPath, []byte(initialConfig), 0644); err != nil {
|
|
||||||
t.Fatal("Failed to write initial config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create config with defaults
|
// Create config with defaults
|
||||||
type TestConfig struct {
|
type TestConfig struct {
|
||||||
@ -49,14 +50,12 @@ enabled = true
|
|||||||
WithDefaults(defaults).
|
WithDefaults(defaults).
|
||||||
WithFile(configPath).
|
WithFile(configPath).
|
||||||
Build()
|
Build()
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal("Failed to build config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify initial values
|
// Verify initial values
|
||||||
if port, _ := cfg.Get("server.port"); port.(int64) != 8080 {
|
port, exists := cfg.Get("server.port")
|
||||||
t.Errorf("Expected port 8080, got %d", port)
|
assert.True(t, exists)
|
||||||
}
|
assert.Equal(t, int64(8080), port)
|
||||||
|
|
||||||
// Enable auto-update with fast polling
|
// Enable auto-update with fast polling
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
@ -91,26 +90,20 @@ host = "0.0.0.0"
|
|||||||
[features]
|
[features]
|
||||||
enabled = false
|
enabled = false
|
||||||
`
|
`
|
||||||
|
require.NoError(t, os.WriteFile(configPath, []byte(updatedConfig), 0644))
|
||||||
if err := os.WriteFile(configPath, []byte(updatedConfig), 0644); err != nil {
|
|
||||||
t.Fatal("Failed to update config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for changes to be detected
|
// Wait for changes to be detected
|
||||||
time.Sleep(300 * time.Millisecond)
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
// Verify new values
|
// Verify new values
|
||||||
if port, _ := cfg.Get("server.port"); port.(int64) != 9090 {
|
port, _ = cfg.Get("server.port")
|
||||||
t.Errorf("Expected port 9090 after update, got %d", port)
|
assert.Equal(t, int64(9090), port)
|
||||||
}
|
|
||||||
|
|
||||||
if host, _ := cfg.Get("server.host"); host.(string) != "0.0.0.0" {
|
host, _ := cfg.Get("server.host")
|
||||||
t.Errorf("Expected host 0.0.0.0 after update, got %s", host)
|
assert.Equal(t, "0.0.0.0", host)
|
||||||
}
|
|
||||||
|
|
||||||
if enabled, _ := cfg.Get("features.enabled"); enabled.(bool) != false {
|
enabled, _ := cfg.Get("features.enabled")
|
||||||
t.Errorf("Expected features.enabled to be false after update")
|
assert.Equal(t, false, enabled)
|
||||||
}
|
|
||||||
|
|
||||||
// Check that changes were notified
|
// Check that changes were notified
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
@ -118,27 +111,22 @@ enabled = false
|
|||||||
|
|
||||||
expectedChanges := []string{"server.port", "server.host", "features.enabled"}
|
expectedChanges := []string{"server.port", "server.host", "features.enabled"}
|
||||||
for _, path := range expectedChanges {
|
for _, path := range expectedChanges {
|
||||||
if !changedPaths[path] {
|
assert.True(t, changedPaths[path], "Expected change notification for %s", path)
|
||||||
t.Errorf("Expected change notification for %s", path)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestWatchFileDeleted tests behavior when config file is deleted
|
||||||
func TestWatchFileDeleted(t *testing.T) {
|
func TestWatchFileDeleted(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
configPath := filepath.Join(tmpDir, "test.toml")
|
configPath := filepath.Join(tmpDir, "test.toml")
|
||||||
|
|
||||||
// Create initial config
|
// Create initial config
|
||||||
if err := os.WriteFile(configPath, []byte(`test = "value"`), 0644); err != nil {
|
require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644))
|
||||||
t.Fatal("Failed to write config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := New()
|
cfg := New()
|
||||||
cfg.Register("test", "default")
|
cfg.Register("test", "default")
|
||||||
|
|
||||||
if err := cfg.LoadFile(configPath); err != nil {
|
require.NoError(t, cfg.LoadFile(configPath))
|
||||||
t.Fatal("Failed to load config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable watching
|
// Enable watching
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
@ -151,21 +139,18 @@ func TestWatchFileDeleted(t *testing.T) {
|
|||||||
changes := cfg.Watch()
|
changes := cfg.Watch()
|
||||||
|
|
||||||
// Delete file
|
// Delete file
|
||||||
if err := os.Remove(configPath); err != nil {
|
require.NoError(t, os.Remove(configPath))
|
||||||
t.Fatal("Failed to delete config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for deletion detection
|
// Wait for deletion detection
|
||||||
select {
|
select {
|
||||||
case path := <-changes:
|
case path := <-changes:
|
||||||
if path != "file_deleted" {
|
assert.Equal(t, "file_deleted", path)
|
||||||
t.Errorf("Expected file_deleted, got %s", path)
|
|
||||||
}
|
|
||||||
case <-time.After(500 * time.Millisecond):
|
case <-time.After(500 * time.Millisecond):
|
||||||
t.Error("Timeout waiting for deletion notification")
|
t.Error("Timeout waiting for deletion notification")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestWatchPermissionChange tests permission change detection
|
||||||
func TestWatchPermissionChange(t *testing.T) {
|
func TestWatchPermissionChange(t *testing.T) {
|
||||||
// Skip on Windows where permission model is different
|
// Skip on Windows where permission model is different
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
@ -176,16 +161,11 @@ func TestWatchPermissionChange(t *testing.T) {
|
|||||||
configPath := filepath.Join(tmpDir, "test.toml")
|
configPath := filepath.Join(tmpDir, "test.toml")
|
||||||
|
|
||||||
// Create config with specific permissions
|
// Create config with specific permissions
|
||||||
if err := os.WriteFile(configPath, []byte(`test = "value"`), 0644); err != nil {
|
require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644))
|
||||||
t.Fatal("Failed to write config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := New()
|
cfg := New()
|
||||||
cfg.Register("test", "default")
|
cfg.Register("test", "default")
|
||||||
|
require.NoError(t, cfg.LoadFile(configPath))
|
||||||
if err := cfg.LoadFile(configPath); err != nil {
|
|
||||||
t.Fatal("Failed to load config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable watching with permission verification
|
// Enable watching with permission verification
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
@ -199,21 +179,18 @@ func TestWatchPermissionChange(t *testing.T) {
|
|||||||
changes := cfg.Watch()
|
changes := cfg.Watch()
|
||||||
|
|
||||||
// Change permissions to world-writable (security risk)
|
// Change permissions to world-writable (security risk)
|
||||||
if err := os.Chmod(configPath, 0666); err != nil {
|
require.NoError(t, os.Chmod(configPath, 0666))
|
||||||
t.Fatal("Failed to change permissions:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for permission change detection
|
// Wait for permission change detection
|
||||||
select {
|
select {
|
||||||
case path := <-changes:
|
case path := <-changes:
|
||||||
if path != "permissions_changed" {
|
assert.Equal(t, "permissions_changed", path)
|
||||||
t.Errorf("Expected permissions_changed, got %s", path)
|
|
||||||
}
|
|
||||||
case <-time.After(500 * time.Millisecond):
|
case <-time.After(500 * time.Millisecond):
|
||||||
t.Error("Timeout waiting for permission change notification")
|
t.Error("Timeout waiting for permission change notification")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestMaxWatchers tests watcher limit enforcement
|
||||||
func TestMaxWatchers(t *testing.T) {
|
func TestMaxWatchers(t *testing.T) {
|
||||||
cfg := New()
|
cfg := New()
|
||||||
cfg.Register("test", "value")
|
cfg.Register("test", "value")
|
||||||
@ -221,13 +198,8 @@ func TestMaxWatchers(t *testing.T) {
|
|||||||
// Create config file
|
// Create config file
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
configPath := filepath.Join(tmpDir, "test.toml")
|
configPath := filepath.Join(tmpDir, "test.toml")
|
||||||
if err := os.WriteFile(configPath, []byte(`test = "value"`), 0644); err != nil {
|
require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644))
|
||||||
t.Fatal("Failed to write config:", err)
|
require.NoError(t, cfg.LoadFile(configPath))
|
||||||
}
|
|
||||||
|
|
||||||
if err := cfg.LoadFile(configPath); err != nil {
|
|
||||||
t.Fatal("Failed to load config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable watching with low max watchers
|
// Enable watching with low max watchers
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
@ -244,45 +216,40 @@ func TestMaxWatchers(t *testing.T) {
|
|||||||
channels = append(channels, ch)
|
channels = append(channels, ch)
|
||||||
|
|
||||||
// Check if channel is open
|
// Check if channel is open
|
||||||
|
if i < 3 {
|
||||||
|
// First 3 should be open
|
||||||
select {
|
select {
|
||||||
case _, ok := <-ch:
|
case _, ok := <-ch:
|
||||||
if !ok && i < 3 {
|
assert.True(t, ok || i < 3, "Channel %d should be open", i)
|
||||||
t.Errorf("Channel %d should be open", i)
|
|
||||||
} else if ok && i == 3 {
|
|
||||||
t.Error("Channel 3 should be closed (max watchers exceeded)")
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
// Channel is open and empty, expected for first 3
|
// Channel is open and empty, expected
|
||||||
if i == 3 {
|
}
|
||||||
// Try to receive with timeout to verify it's closed
|
} else {
|
||||||
|
// 4th should be closed immediately
|
||||||
select {
|
select {
|
||||||
case _, ok := <-ch:
|
case _, ok := <-ch:
|
||||||
if ok {
|
assert.False(t, ok, "Channel 3 should be closed (max watchers exceeded)")
|
||||||
t.Error("Channel 3 should be closed")
|
|
||||||
}
|
|
||||||
case <-time.After(10 * time.Millisecond):
|
case <-time.After(10 * time.Millisecond):
|
||||||
t.Error("Channel 3 should be closed immediately")
|
t.Error("Channel 3 should be closed immediately")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
// Verify watcher count
|
||||||
|
assert.Equal(t, 3, cfg.WatcherCount())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDebounce tests that rapid changes are debounced
|
||||||
func TestDebounce(t *testing.T) {
|
func TestDebounce(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
configPath := filepath.Join(tmpDir, "test.toml")
|
configPath := filepath.Join(tmpDir, "test.toml")
|
||||||
|
|
||||||
// Create initial config
|
// Create initial config
|
||||||
if err := os.WriteFile(configPath, []byte(`value = 1`), 0644); err != nil {
|
require.NoError(t, os.WriteFile(configPath, []byte(`value = 1`), 0644))
|
||||||
t.Fatal("Failed to write config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := New()
|
cfg := New()
|
||||||
cfg.Register("value", 0)
|
cfg.Register("value", 0)
|
||||||
|
require.NoError(t, cfg.LoadFile(configPath))
|
||||||
if err := cfg.LoadFile(configPath); err != nil {
|
|
||||||
t.Fatal("Failed to load config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable watching with longer debounce
|
// Enable watching with longer debounce
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
@ -293,39 +260,211 @@ func TestDebounce(t *testing.T) {
|
|||||||
defer cfg.StopAutoUpdate()
|
defer cfg.StopAutoUpdate()
|
||||||
|
|
||||||
changes := cfg.Watch()
|
changes := cfg.Watch()
|
||||||
changeCount := 0
|
|
||||||
|
var changeCount int
|
||||||
|
var mu sync.Mutex
|
||||||
|
done := make(chan bool)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for range changes {
|
for {
|
||||||
|
select {
|
||||||
|
case <-changes:
|
||||||
|
mu.Lock()
|
||||||
changeCount++
|
changeCount++
|
||||||
|
mu.Unlock()
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Make rapid changes
|
// Make rapid changes
|
||||||
for i := 2; i <= 5; i++ {
|
for i := 2; i <= 5; i++ {
|
||||||
content := fmt.Sprintf(`value = %d`, i)
|
content := fmt.Sprintf(`value = %d`, i)
|
||||||
if err := os.WriteFile(configPath, []byte(content), 0644); err != nil {
|
require.NoError(t, os.WriteFile(configPath, []byte(content), 0644))
|
||||||
t.Fatal("Failed to write config:", err)
|
|
||||||
}
|
|
||||||
time.Sleep(50 * time.Millisecond) // Less than debounce period
|
time.Sleep(50 * time.Millisecond) // Less than debounce period
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for debounce to complete
|
// Wait for debounce to complete
|
||||||
time.Sleep(300 * time.Millisecond)
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
done <- true
|
||||||
|
|
||||||
// Should only see one change due to debounce
|
// Should only see one change due to debounce
|
||||||
if changeCount != 1 {
|
mu.Lock()
|
||||||
t.Errorf("Expected 1 change due to debounce, got %d", changeCount)
|
defer mu.Unlock()
|
||||||
}
|
assert.Equal(t, 1, changeCount, "Expected 1 change due to debounce, got %d", changeCount)
|
||||||
|
|
||||||
// Verify final value
|
// Verify final value
|
||||||
val, _ := cfg.Get("value")
|
val, _ := cfg.Get("value")
|
||||||
if val.(int64) != 5 {
|
assert.Equal(t, int64(5), val)
|
||||||
t.Errorf("Expected final value 5, got %d", val)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Benchmark file watching overhead
|
// TestWatchWithoutFile tests watching behavior when no file is configured
|
||||||
|
func TestWatchWithoutFile(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("test", "value")
|
||||||
|
|
||||||
|
// No file loaded, watch should return closed channel
|
||||||
|
ch := cfg.Watch()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case _, ok := <-ch:
|
||||||
|
assert.False(t, ok, "Channel should be closed when no file to watch")
|
||||||
|
case <-time.After(10 * time.Millisecond):
|
||||||
|
t.Error("Channel should be closed immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.False(t, cfg.IsWatching())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConcurrentWatchOperations tests thread safety of watch operations
|
||||||
|
func TestConcurrentWatchOperations(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "test.toml")
|
||||||
|
require.NoError(t, os.WriteFile(configPath, []byte(`value = 1`), 0644))
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("value", 0)
|
||||||
|
require.NoError(t, cfg.LoadFile(configPath))
|
||||||
|
|
||||||
|
opts := WatchOptions{
|
||||||
|
PollInterval: 50 * time.Millisecond,
|
||||||
|
MaxWatchers: 50,
|
||||||
|
}
|
||||||
|
cfg.AutoUpdateWithOptions(opts)
|
||||||
|
defer cfg.StopAutoUpdate()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errors := make(chan error, 100)
|
||||||
|
|
||||||
|
// Start multiple watchers concurrently
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
ch := cfg.Watch()
|
||||||
|
if ch == nil {
|
||||||
|
errors <- fmt.Errorf("watcher %d: got nil channel", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to receive
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
// OK, got a change
|
||||||
|
case <-time.After(10 * time.Millisecond):
|
||||||
|
// OK, no changes yet
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concurrent config updates
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
content := fmt.Sprintf(`value = %d`, id+10)
|
||||||
|
if err := os.WriteFile(configPath, []byte(content), 0644); err != nil {
|
||||||
|
errors <- fmt.Errorf("writer %d: %v", id, err)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check IsWatching concurrently
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
isWatching := false
|
||||||
|
for j := 0; j < 5; j++ { // Poll a few times, double-dip wait for goroutine to start
|
||||||
|
if cfg.IsWatching() {
|
||||||
|
isWatching = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
if !isWatching {
|
||||||
|
errors <- fmt.Errorf("checker %d: IsWatching returned false", id)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(errors)
|
||||||
|
|
||||||
|
// Check for errors
|
||||||
|
var errs []error
|
||||||
|
for err := range errors {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
assert.Empty(t, errs, "Concurrent operations should not produce errors")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestReloadTimeout tests reload timeout handling
|
||||||
|
func TestReloadTimeout(t *testing.T) {
|
||||||
|
// This test would require mocking file operations to simulate a slow read
|
||||||
|
// For now, we'll test that timeout option is respected in configuration
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("test", "value")
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "test.toml")
|
||||||
|
require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644))
|
||||||
|
require.NoError(t, cfg.LoadFile(configPath))
|
||||||
|
|
||||||
|
// Very short timeout
|
||||||
|
opts := WatchOptions{
|
||||||
|
PollInterval: 100 * time.Millisecond,
|
||||||
|
ReloadTimeout: 1 * time.Nanosecond, // Extremely short
|
||||||
|
}
|
||||||
|
cfg.AutoUpdateWithOptions(opts)
|
||||||
|
defer cfg.StopAutoUpdate()
|
||||||
|
|
||||||
|
waitForWatchingState(t, cfg, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStopAutoUpdate tests clean shutdown of watcher
|
||||||
|
func TestStopAutoUpdate(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "test.toml")
|
||||||
|
require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644))
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("test", "value")
|
||||||
|
require.NoError(t, cfg.LoadFile(configPath))
|
||||||
|
|
||||||
|
// Start watching
|
||||||
|
cfg.AutoUpdate()
|
||||||
|
waitForWatchingState(t, cfg, true, "Watcher should be active after first start")
|
||||||
|
|
||||||
|
ch := cfg.Watch()
|
||||||
|
|
||||||
|
// Stop watching
|
||||||
|
cfg.StopAutoUpdate()
|
||||||
|
|
||||||
|
// Verify stopped
|
||||||
|
waitForWatchingState(t, cfg, false, "Watcher should be inactive after stop")
|
||||||
|
assert.Equal(t, 0, cfg.WatcherCount())
|
||||||
|
|
||||||
|
// Channel should eventually close
|
||||||
|
select {
|
||||||
|
case _, ok := <-ch:
|
||||||
|
assert.False(t, ok, "Channel should be closed after stop")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
// OK, channel might not close immediately
|
||||||
|
}
|
||||||
|
|
||||||
|
// Starting again should work
|
||||||
|
cfg.AutoUpdate()
|
||||||
|
waitForWatchingState(t, cfg, true, "Watcher should be active after restart")
|
||||||
|
cfg.StopAutoUpdate()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkWatchOverhead benchmarks the overhead of file watching
|
||||||
func BenchmarkWatchOverhead(b *testing.B) {
|
func BenchmarkWatchOverhead(b *testing.B) {
|
||||||
tmpDir := b.TempDir()
|
tmpDir := b.TempDir()
|
||||||
configPath := filepath.Join(tmpDir, "bench.toml")
|
configPath := filepath.Join(tmpDir, "bench.toml")
|
||||||
@ -335,19 +474,13 @@ func BenchmarkWatchOverhead(b *testing.B) {
|
|||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
configContent += fmt.Sprintf("value%d = %d\n", i, i)
|
configContent += fmt.Sprintf("value%d = %d\n", i, i)
|
||||||
}
|
}
|
||||||
|
require.NoError(b, os.WriteFile(configPath, []byte(configContent), 0644))
|
||||||
if err := os.WriteFile(configPath, []byte(configContent), 0644); err != nil {
|
|
||||||
b.Fatal("Failed to write config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := New()
|
cfg := New()
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
cfg.Register(fmt.Sprintf("value%d", i), 0)
|
cfg.Register(fmt.Sprintf("value%d", i), 0)
|
||||||
}
|
}
|
||||||
|
require.NoError(b, cfg.LoadFile(configPath))
|
||||||
if err := cfg.LoadFile(configPath); err != nil {
|
|
||||||
b.Fatal("Failed to load config:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable watching
|
// Enable watching
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
@ -362,3 +495,10 @@ func BenchmarkWatchOverhead(b *testing.B) {
|
|||||||
_, _ = cfg.Get(fmt.Sprintf("value%d", i%100))
|
_, _ = cfg.Get(fmt.Sprintf("value%d", i%100))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// helper function to wait for watcher state, preventing race conditions of goroutine start and test check
|
||||||
|
func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) {
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return cfg.IsWatching() == expected
|
||||||
|
}, 200*time.Millisecond, 10*time.Millisecond, msgAndArgs...)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user