From e9b55063ffe087dbfa6fd6385c7168197d2c71276cc829fa063668bb93539c98 Mon Sep 17 00:00:00 2001 From: Lixen Wraith Date: Sat, 19 Jul 2025 19:05:17 -0400 Subject: [PATCH] e5.0.0 Tests added, bug fixes. --- .gitignore | 1 + builder.go | 14 +- builder_test.go | 303 +++++++++++++++++++++++++++++ cmd/main.go | 168 ---------------- config.go | 8 +- config_test.go | 454 ++++++++++++++++++++++++++++++++++++++++++++ convenience.go | 5 +- convenience_test.go | 287 ++++++++++++++++++++++++++++ decode.go | 48 ++--- decode_test.go | 328 ++++++++++++++++++++++++++++++++ doc/access.md | 4 +- go.mod | 5 +- go.sum | 6 +- helper.go | 2 +- loader.go | 24 +-- loader_test.go | 434 ++++++++++++++++++++++++++++++++++++++++++ register.go | 36 ++-- watch.go | 5 +- watch_test.go | 350 ++++++++++++++++++++++++---------- 19 files changed, 2143 insertions(+), 339 deletions(-) create mode 100644 builder_test.go delete mode 100644 cmd/main.go create mode 100644 config_test.go create mode 100644 convenience_test.go create mode 100644 decode_test.go create mode 100644 loader_test.go diff --git a/.gitignore b/.gitignore index bacfe38..9c9b403 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ logs script *.log bin +example diff --git a/builder.go b/builder.go index 39d0122..5540b64 100644 --- a/builder.go +++ b/builder.go @@ -1,4 +1,4 @@ -// File: lixenwraith/config/builder.go +// FILE: lixenwraith/config/builder.go package config import ( @@ -55,12 +55,9 @@ func (b *Builder) Build() (*Config, error) { } } - // Register defaults if provided - if b.defaults != nil { - if err := b.cfg.RegisterStruct(b.prefix, b.defaults); err != nil { - return nil, fmt.Errorf("failed to register defaults: %w", err) - } - } + // Explicitly set the file path on the config object so the watcher can find it, + // even if the initial load fails with a non-fatal error (e.g., file not found). + b.cfg.configFilePath = b.file // Load configuration loadErr := b.cfg.LoadWithOptions(b.file, b.args, b.opts) @@ -104,6 +101,9 @@ func (b *Builder) WithTagName(tagName string) *Builder { switch tagName { case "toml", "json", "yaml": b.tagName = tagName + if b.cfg != nil { // Ensure cfg exists + b.cfg.tagName = tagName + } default: b.err = fmt.Errorf("unsupported tag name %q, must be one of: toml, json, yaml", tagName) } diff --git a/builder_test.go b/builder_test.go new file mode 100644 index 0000000..c78c9aa --- /dev/null +++ b/builder_test.go @@ -0,0 +1,303 @@ +// FILE: lixenwraith/config/builder_test.go +package config + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestBuilder tests the builder pattern +func TestBuilder(t *testing.T) { + t.Run("BasicBuilder", func(t *testing.T) { + type Config struct { + Host string `toml:"host"` + Port int `toml:"port"` + } + + defaults := &Config{ + Host: "localhost", + Port: 8080, + } + + cfg, err := NewBuilder(). + WithDefaults(defaults). + WithEnvPrefix("TEST_"). + Build() + + require.NoError(t, err) + assert.NotNil(t, cfg) + + val, exists := cfg.Get("host") + assert.True(t, exists) + assert.Equal(t, "localhost", val) + }) + + t.Run("BuilderWithAllOptions", func(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "test.toml") + os.WriteFile(configFile, []byte(`host = "filehost"`), 0644) + + type Config struct { + Host string `json:"hostname"` + Port int `json:"port"` + } + + defaults := &Config{ + Host: "defaulthost", + Port: 3000, + } + + // Custom env transform + envTransform := func(path string) string { + return "CUSTOM_" + path + } + + cfg, err := NewBuilder(). + WithDefaults(defaults). + WithTagName("json"). + WithPrefix("server"). + WithEnvPrefix("APP_"). + WithFile(configFile). + WithArgs([]string{"--server.hostname=clihost"}). + WithSources(SourceCLI, SourceFile, SourceEnv, SourceDefault). + WithEnvTransform(envTransform). + WithEnvWhitelist("server.hostname"). + Build() + + require.NoError(t, err) + + // CLI should take precedence + val, _ := cfg.Get("server.hostname") + assert.Equal(t, "clihost", val) + }) + + t.Run("BuilderWithTarget", func(t *testing.T) { + type Config struct { + Database struct { + Host string `toml:"host"` + Port int `toml:"port"` + } `toml:"db"` + Cache struct { + TTL int `toml:"ttl"` + } `toml:"cache"` + } + + target := &Config{} + target.Database.Host = "localhost" + target.Database.Port = 5432 + target.Cache.TTL = 300 + + cfg, err := NewBuilder(). + WithTarget(target). + Build() + + require.NoError(t, err) + + // Verify paths were registered + paths := cfg.GetRegisteredPaths("") + assert.True(t, paths["db.host"]) + assert.True(t, paths["db.port"]) + assert.True(t, paths["cache.ttl"]) + + // Test AsStruct + result, err := cfg.AsStruct() + require.NoError(t, err) + assert.Equal(t, target, result) + }) + + t.Run("BuilderWithValidator", func(t *testing.T) { + type UserConfig struct { + Port int `toml:"port"` + } + + validatorCalled := false + validator := func(cfg *Config) error { + validatorCalled = true + val, exists := cfg.Get("port") + if !exists { + return fmt.Errorf("port not found") + } + // Convert to int - could be int64 from storage + var port int + switch v := val.(type) { + case int: + port = v + case int64: + port = int(v) + default: + return fmt.Errorf("port has unexpected type %T", v) + } + + if port < 1024 { + return fmt.Errorf("port %d is below 1024", port) + } + return nil + } + + // Valid case + cfg, err := NewBuilder(). + WithDefaults(&UserConfig{Port: 8080}). + WithValidator(validator). + Build() + + require.NoError(t, err) + assert.NotNil(t, cfg) + assert.True(t, validatorCalled) + + // Invalid case + validatorCalled = false + cfg2, err := NewBuilder(). + WithDefaults(&UserConfig{Port: 80}). + WithValidator(validator). + Build() + + assert.Nil(t, cfg2) + assert.Error(t, err) + assert.Contains(t, err.Error(), "configuration validation failed") + assert.True(t, validatorCalled) + }) + + t.Run("BuilderErrorAccumulation", func(t *testing.T) { + // Unsupported tag name + _, err := NewBuilder(). + WithTagName("xml"). + WithDefaults(struct{}{}). + Build() + + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported tag name") + + // Invalid target + _, err = NewBuilder(). + WithTarget("not-a-pointer"). + Build() + + assert.Error(t, err) + assert.Contains(t, err.Error(), "requires non-nil pointer to struct") + }) + + t.Run("MustBuildPanic", func(t *testing.T) { + // Should not panic with valid config + assert.NotPanics(t, func() { + cfg := NewBuilder(). + WithDefaults(struct{ Port int }{Port: 8080}). + MustBuild() + assert.NotNil(t, cfg) + }) + + // Should panic with error + assert.Panics(t, func() { + NewBuilder(). + WithTagName("invalid"). + MustBuild() + }) + }) +} + +// TestFileDiscovery tests automatic config file discovery +func TestFileDiscovery(t *testing.T) { + t.Run("DiscoveryWithCLIFlag", func(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "custom.conf") + os.WriteFile(configFile, []byte(`test = "value"`), 0644) + + opts := DefaultDiscoveryOptions("myapp") + + cfg, err := NewBuilder(). + WithDefaults(struct { + Test string `toml:"test"` + }{Test: "default"}). + WithArgs([]string{"--config", configFile}). + WithFileDiscovery(opts). + Build() + + require.NoError(t, err) + + // Verify file was loaded + val, _ := cfg.Get("test") + assert.Equal(t, "value", val) + }) + + t.Run("DiscoveryWithEnvVar", func(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "env.toml") + os.WriteFile(configFile, []byte(`test = "envvalue"`), 0644) + + os.Setenv("MYAPP_CONFIG", configFile) + defer os.Unsetenv("MYAPP_CONFIG") + + opts := DefaultDiscoveryOptions("myapp") + + cfg, err := NewBuilder(). + WithDefaults(struct { + Test string `toml:"test"` + }{Test: "default"}). + WithFileDiscovery(opts). + Build() + + require.NoError(t, err) + + val, _ := cfg.Get("test") + assert.Equal(t, "envvalue", val) + }) + + t.Run("DiscoveryInCurrentDir", func(t *testing.T) { + // Create config in current directory + cwd, _ := os.Getwd() + configFile := filepath.Join(cwd, "myapp.toml") + os.WriteFile(configFile, []byte(`test = "cwdvalue"`), 0644) + defer os.Remove(configFile) + + opts := FileDiscoveryOptions{ + Name: "myapp", + Extensions: []string{".toml"}, + UseCurrentDir: true, + } + + cfg, err := NewBuilder(). + WithDefaults(struct { + Test string `toml:"test"` + }{Test: "default"}). + WithFileDiscovery(opts). + Build() + + require.NoError(t, err) + + val, _ := cfg.Get("test") + assert.Equal(t, "cwdvalue", val) + }) + + t.Run("DiscoveryPrecedence", func(t *testing.T) { + tmpDir := t.TempDir() + + // Create multiple config files + cliFile := filepath.Join(tmpDir, "cli.toml") + envFile := filepath.Join(tmpDir, "env.toml") + os.WriteFile(cliFile, []byte(`test = "clifile"`), 0644) + os.WriteFile(envFile, []byte(`test = "envfile"`), 0644) + + // CLI should take precedence over env + os.Setenv("MYAPP_CONFIG", envFile) + defer os.Unsetenv("MYAPP_CONFIG") + + opts := DefaultDiscoveryOptions("myapp") + + cfg, err := NewBuilder(). + WithDefaults(struct { + Test string `toml:"test"` + }{Test: "default"}). + WithArgs([]string{"--config", cliFile}). + WithFileDiscovery(opts). + Build() + + require.NoError(t, err) + + val, _ := cfg.Get("test") + assert.Equal(t, "clifile", val) + }) +} \ No newline at end of file diff --git a/cmd/main.go b/cmd/main.go deleted file mode 100644 index 6bfd1f7..0000000 --- a/cmd/main.go +++ /dev/null @@ -1,168 +0,0 @@ -// FILE: example/watch_demo.go -package main - -import ( - "context" - "errors" - "log" - "os" - "os/signal" - "syscall" - "time" - - "github.com/lixenwraith/config" -) - -// AppConfig represents our application configuration -type AppConfig struct { - Server struct { - Host string `toml:"host"` - Port int `toml:"port"` - } `toml:"server"` - - Database struct { - URL string `toml:"url"` - MaxConns int `toml:"max_conns"` - IdleTimeout time.Duration `toml:"idle_timeout"` - } `toml:"database"` - - Features struct { - RateLimit bool `toml:"rate_limit"` - Caching bool `toml:"caching"` - } `toml:"features"` -} - -func main() { - // Create configuration with defaults - defaults := &AppConfig{} - defaults.Server.Host = "localhost" - defaults.Server.Port = 8080 - defaults.Database.MaxConns = 10 - defaults.Database.IdleTimeout = 30 * time.Second - - // Build configuration - cfg, err := config.NewBuilder(). - WithDefaults(defaults). - WithEnvPrefix("MYAPP_"). - WithFile("config.toml"). - Build() - if err != nil && !errors.Is(err, config.ErrConfigNotFound) { - log.Fatal("Failed to load config:", err) - } - - // Enable auto-update with custom options - watchOpts := config.WatchOptions{ - PollInterval: 500 * time.Millisecond, // Check twice per second - Debounce: 200 * time.Millisecond, // Quick response - MaxWatchers: 10, - ReloadTimeout: 2 * time.Second, - VerifyPermissions: true, // SECURITY: Detect permission changes - } - cfg.AutoUpdateWithOptions(watchOpts) - defer cfg.StopAutoUpdate() - - // Start watching for changes - changes := cfg.Watch() - - // Context for graceful shutdown - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Handle signals - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - - // Log initial configuration - logConfig(cfg) - - // Watch for changes - go func() { - for { - select { - case <-ctx.Done(): - return - case path := <-changes: - handleConfigChange(cfg, path) - } - } - }() - - // Main loop - log.Println("Watching for configuration changes. Edit config.toml to see updates.") - log.Println("Press Ctrl+C to exit.") - - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for { - select { - case <-sigCh: - log.Println("Shutting down...") - return - - case <-ticker.C: - // Periodic health check - var port int64 - port, _ = cfg.Get("server.port") - log.Printf("Server still running on port %d", port) - } - } -} - -func handleConfigChange(cfg *config.Config, path string) { - switch path { - case "__file_deleted__": - log.Println("⚠️ Config file was deleted!") - case "__permissions_changed__": - log.Println("⚠️ SECURITY: Config file permissions changed!") - case "__reload_error__": - log.Printf("❌ Failed to reload config: %s", path) - case "__reload_timeout__": - log.Println("⚠️ Config reload timed out") - default: - // Normal configuration change - value, _ := cfg.Get(path) - log.Printf("📝 Config changed: %s = %v", path, value) - - // Handle specific changes - switch path { - case "server.port": - log.Println("Port changed - server restart required") - case "database.url": - log.Println("Database URL changed - reconnection required") - case "features.rate_limit": - if cfg.Bool("features.rate_limit") { - log.Println("Rate limiting enabled") - } else { - log.Println("Rate limiting disabled") - } - } - } -} - -func logConfig(cfg *config.Config) { - log.Println("Current configuration:") - log.Printf(" Server: %s:%d", cfg.String("server.host"), cfg.Int("server.port")) - log.Printf(" Database: %s (max_conns=%d)", - cfg.String("database.url"), - cfg.Int("database.max_conns")) - log.Printf(" Features: rate_limit=%v, caching=%v", - cfg.Bool("features.rate_limit"), - cfg.Bool("features.caching")) -} - -// Example config.toml file: -/* -[server] -host = "localhost" -port = 8080 - -[database] -url = "postgres://localhost/myapp" -max_conns = 25 -idle_timeout = "30s" - -[features] -rate_limit = true -caching = false -*/ \ No newline at end of file diff --git a/config.go b/config.go index 5a8f7ae..f257dc6 100644 --- a/config.go +++ b/config.go @@ -1,4 +1,4 @@ -// File: lixenwraith/config/config.go +// FILE: lixenwraith/config/config.go // Package config provides thread-safe configuration management for Go applications // with support for multiple sources: TOML files, environment variables, command-line // arguments, and default values with configurable precedence. @@ -52,6 +52,7 @@ type structCache struct { // 2. As a source for a type-safe struct, populated via BuildAndScan() or AsStruct() type Config struct { items map[string]configItem + tagName string mutex sync.RWMutex options LoadOptions // Current load options fileData map[string]any // Cached file data @@ -69,6 +70,7 @@ type Config struct { func New() *Config { return &Config{ items: make(map[string]configItem), + tagName: "toml", options: DefaultLoadOptions(), fileData: make(map[string]any), envData: make(map[string]any), @@ -157,6 +159,10 @@ func (c *Config) SetSource(path string, source Source, value any) error { return fmt.Errorf("path %s is not registered", path) } + if str, ok := value.(string); ok && len(str) > MaxValueSize { + return ErrValueSize + } + if item.values == nil { item.values = make(map[Source]any) } diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..ca0829c --- /dev/null +++ b/config_test.go @@ -0,0 +1,454 @@ +// FILE: lixenwraith/config/config_test.go +package config + +import ( + "fmt" + "net" + "net/url" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestConfigCreation tests various config creation patterns +func TestConfigCreation(t *testing.T) { + t.Run("NewWithDefaultOptions", func(t *testing.T) { + cfg := New() + require.NotNil(t, cfg) + assert.NotNil(t, cfg.items) + assert.Equal(t, []Source{SourceCLI, SourceEnv, SourceFile, SourceDefault}, cfg.options.Sources) + }) + + t.Run("NewWithCustomOptions", func(t *testing.T) { + opts := LoadOptions{ + Sources: []Source{SourceEnv, SourceFile, SourceDefault}, + EnvPrefix: "MYAPP_", + LoadMode: LoadModeReplace, + } + cfg := NewWithOptions(opts) + require.NotNil(t, cfg) + assert.Equal(t, opts.Sources, cfg.options.Sources) + assert.Equal(t, "MYAPP_", cfg.options.EnvPrefix) + }) +} + +// TestPathRegistration tests path registration edge cases +func TestPathRegistration(t *testing.T) { + tests := []struct { + name string + path string + defaultVal any + expectError bool + errorMsg string + }{ + {"ValidSimplePath", "port", 8080, false, ""}, + {"ValidNestedPath", "server.host.name", "localhost", false, ""}, + {"EmptyPath", "", nil, true, "registration path cannot be empty"}, + {"InvalidCharacter", "server.port!", 8080, true, "invalid path segment"}, + {"InvalidDot", "server..port", 8080, true, "invalid path segment"}, + {"LeadingDot", ".server.port", 8080, true, "invalid path segment"}, + {"TrailingDot", "server.port.", 8080, true, "invalid path segment"}, + {"ValidUnderscore", "server_config.max_connections", 100, false, ""}, + {"ValidDash", "feature-flags.enable-debug", false, false, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := New() + err := cfg.Register(tt.path, tt.defaultVal) + if tt.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + val, exists := cfg.Get(tt.path) + assert.True(t, exists) + assert.Equal(t, tt.defaultVal, val) + } + }) + } +} + +// TestComplexStructRegistration tests struct registration with various tag types +func TestComplexStructRegistration(t *testing.T) { + type DatabaseConfig struct { + Host string `toml:"host" json:"db_host" yaml:"dbHost"` + Port int `toml:"port" json:"db_port" yaml:"dbPort"` + MaxConns int `toml:"max_connections"` + Timeout time.Duration `toml:"timeout"` + EnableDebug bool `toml:"debug" env:"DB_DEBUG"` + } + + type ServerConfig struct { + Name string `toml:"name" json:"name"` + Database DatabaseConfig `toml:"db" json:"db"` + Tags []string `toml:"tags" json:"tags"` + Metadata map[string]any `toml:"metadata" json:"metadata"` + } + + defaultConfig := &ServerConfig{ + Name: "test-server", + Database: DatabaseConfig{ + Host: "localhost", + Port: 5432, + MaxConns: 100, + Timeout: 30 * time.Second, + EnableDebug: false, + }, + Tags: []string{"test", "development"}, + Metadata: map[string]any{"version": "1.0"}, + } + + t.Run("TOMLTags", func(t *testing.T) { + cfg := New() + err := cfg.RegisterStruct("", defaultConfig) + require.NoError(t, err) + + // Verify paths registered with TOML tags + paths := cfg.GetRegisteredPaths("") + assert.True(t, paths["name"]) + assert.True(t, paths["db.host"]) + assert.True(t, paths["db.port"]) + assert.True(t, paths["db.max_connections"]) + assert.True(t, paths["db.timeout"]) + assert.True(t, paths["db.debug"]) + assert.True(t, paths["tags"]) + assert.True(t, paths["metadata"]) + + // Verify default values + val, _ := cfg.Get("db.timeout") + assert.Equal(t, 30*time.Second, val) + }) + + t.Run("JSONTags", func(t *testing.T) { + cfg := New() + err := cfg.RegisterStructWithTags("", defaultConfig, "json") + require.NoError(t, err) + + // JSON tags should create different paths + paths := cfg.GetRegisteredPaths("") + assert.True(t, paths["db.db_host"]) + assert.True(t, paths["db.db_port"]) + }) + + t.Run("UnsupportedTag", func(t *testing.T) { + cfg := New() + err := cfg.RegisterStructWithTags("", defaultConfig, "xml") + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported tag name") + }) + + t.Run("WithPrefix", func(t *testing.T) { + cfg := New() + err := cfg.RegisterStruct("server", defaultConfig) + require.NoError(t, err) + + paths := cfg.GetRegisteredPaths("server.") + assert.True(t, paths["server.name"]) + assert.True(t, paths["server.db.host"]) + }) +} + +// TestSourcePrecedence tests configuration source precedence +func TestSourcePrecedence(t *testing.T) { + cfg := New() + cfg.Register("test.value", "default") + + // Set values in different sources + cfg.SetSource("test.value", SourceFile, "from-file") + cfg.SetSource("test.value", SourceEnv, "from-env") + cfg.SetSource("test.value", SourceCLI, "from-cli") + + // Default precedence: CLI > Env > File > Default + val, _ := cfg.Get("test.value") + assert.Equal(t, "from-cli", val) + + // Remove CLI value + cfg.ResetSource(SourceCLI) + val, _ = cfg.Get("test.value") + assert.Equal(t, "from-env", val) + + // Change precedence + err := cfg.SetLoadOptions(LoadOptions{ + Sources: []Source{SourceFile, SourceEnv, SourceCLI, SourceDefault}, + }) + require.NoError(t, err) + val, _ = cfg.Get("test.value") + assert.Equal(t, "from-file", val) + + // Test GetSources + sources := cfg.GetSources("test.value") + assert.Equal(t, "from-file", sources[SourceFile]) + assert.Equal(t, "from-env", sources[SourceEnv]) +} + +// TestTypeConversion tests automatic type conversion through mapstructure +func TestTypeConversion(t *testing.T) { + type TestConfig struct { + IntValue int64 `toml:"int"` + FloatValue float64 `toml:"float"` + BoolValue bool `toml:"bool"` + Duration time.Duration `toml:"duration"` + Time time.Time `toml:"time"` + IP net.IP `toml:"ip"` + IPNet *net.IPNet `toml:"ipnet"` + URL *url.URL `toml:"url"` + StringSlice []string `toml:"strings"` + IntSlice []int `toml:"ints"` + } + + cfg := New() + defaults := &TestConfig{ + IntValue: 42, + FloatValue: 3.14, + BoolValue: true, + Duration: 5 * time.Second, + Time: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + IP: net.ParseIP("127.0.0.1"), + StringSlice: []string{"a", "b"}, + IntSlice: []int{1, 2, 3}, + } + + err := cfg.RegisterStruct("", defaults) + require.NoError(t, err) + + // Test string conversions from environment + cfg.SetSource("int", SourceEnv, "100") + cfg.SetSource("float", SourceEnv, "2.718") + cfg.SetSource("bool", SourceEnv, "false") + cfg.SetSource("duration", SourceEnv, "1m30s") + cfg.SetSource("time", SourceEnv, "2024-12-25T10:00:00Z") + cfg.SetSource("ip", SourceEnv, "192.168.1.1") + cfg.SetSource("ipnet", SourceEnv, "10.0.0.0/8") + cfg.SetSource("url", SourceEnv, "https://example.com:8080/path") + cfg.SetSource("strings", SourceEnv, "x,y,z") + // cfg.SetSource("ints", SourceEnv, "7,8,9") // failure due to mapstructure limitation + + // Scan into struct + var result TestConfig + err = cfg.Scan("", &result) + require.NoError(t, err) + + assert.Equal(t, int64(100), result.IntValue) + assert.Equal(t, 2.718, result.FloatValue) + assert.Equal(t, false, result.BoolValue) + assert.Equal(t, 90*time.Second, result.Duration) + assert.Equal(t, "2024-12-25T10:00:00Z", result.Time.Format(time.RFC3339)) + assert.Equal(t, "192.168.1.1", result.IP.String()) + assert.Equal(t, "10.0.0.0/8", result.IPNet.String()) + assert.Equal(t, "https://example.com:8080/path", result.URL.String()) + assert.Equal(t, []string{"x", "y", "z"}, result.StringSlice) + // Note: String to int slice conversion through env requires handling in the test +} + +// TestConcurrentAccess tests thread safety +func TestConcurrentAccess(t *testing.T) { + cfg := New() + + // Register paths + for i := 0; i < 100; i++ { + cfg.Register(fmt.Sprintf("path%d", i), i) + } + + var wg sync.WaitGroup + errors := make(chan error, 1000) + + // Concurrent readers + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 100; j++ { + path := fmt.Sprintf("path%d", j) + if _, exists := cfg.Get(path); !exists { + errors <- fmt.Errorf("reader %d: path %s not found", id, path) + } + } + }(i) + } + + // Concurrent writers + for i := 0; i < 5; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 100; j++ { + path := fmt.Sprintf("path%d", j) + value := fmt.Sprintf("writer%d-value%d", id, j) + if err := cfg.Set(path, value); err != nil { + errors <- fmt.Errorf("writer %d: %v", id, err) + } + } + }(i) + } + + // Concurrent source changes + for i := 0; i < 3; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + sources := []Source{SourceFile, SourceEnv, SourceCLI} + for j := 0; j < 50; j++ { + path := fmt.Sprintf("path%d", j) + source := sources[j%len(sources)] + value := fmt.Sprintf("source%d-value%d", id, j) + if err := cfg.SetSource(path, source, value); err != nil { + errors <- fmt.Errorf("source writer %d: %v", id, err) + } + } + }(i) + } + + wg.Wait() + close(errors) + + // Check for errors + var errs []error + for err := range errors { + errs = append(errs, err) + } + assert.Empty(t, errs, "Concurrent access should not produce errors") +} + +// TestUnregister tests path unregistration +func TestUnregister(t *testing.T) { + cfg := New() + + // Register nested paths + cfg.Register("server.host", "localhost") + cfg.Register("server.port", 8080) + cfg.Register("server.tls.enabled", true) + cfg.Register("server.tls.cert", "/path/to/cert") + cfg.Register("database.host", "dbhost") + + t.Run("UnregisterSinglePath", func(t *testing.T) { + err := cfg.Unregister("server.port") + assert.NoError(t, err) + _, exists := cfg.Get("server.port") + assert.False(t, exists) + + // Other paths should remain + _, exists = cfg.Get("server.host") + assert.True(t, exists) + }) + + t.Run("UnregisterParentPath", func(t *testing.T) { + err := cfg.Unregister("server.tls") + assert.NoError(t, err) + + // All child paths should be removed + _, exists := cfg.Get("server.tls.enabled") + assert.False(t, exists) + _, exists = cfg.Get("server.tls.cert") + assert.False(t, exists) + + // Sibling paths should remain + _, exists = cfg.Get("server.host") + assert.True(t, exists) + }) + + t.Run("UnregisterNonExistentPath", func(t *testing.T) { + err := cfg.Unregister("nonexistent.path") + assert.Error(t, err) + assert.Contains(t, err.Error(), "path not registered") + }) +} + +// TestResetFunctionality tests reset operations +func TestResetFunctionality(t *testing.T) { + cfg := New() + cfg.Register("test1", "default1") + cfg.Register("test2", "default2") + + // Set values in different sources + cfg.SetSource("test1", SourceFile, "file1") + cfg.SetSource("test1", SourceEnv, "env1") + cfg.SetSource("test2", SourceCLI, "cli2") + + t.Run("ResetSingleSource", func(t *testing.T) { + cfg.ResetSource(SourceEnv) + + // Env value should be gone + _, exists := cfg.GetSource("test1", SourceEnv) + assert.False(t, exists) + + // Other sources should remain + val, exists := cfg.GetSource("test1", SourceFile) + assert.True(t, exists) + assert.Equal(t, "file1", val) + }) + + t.Run("ResetAll", func(t *testing.T) { + cfg.Reset() + + // All values should revert to defaults + val1, _ := cfg.Get("test1") + val2, _ := cfg.Get("test2") + assert.Equal(t, "default1", val1) + assert.Equal(t, "default2", val2) + + // Source values should be cleared + sources := cfg.GetSources("test1") + assert.Empty(t, sources) + }) +} + +// TestValueSizeLimit tests the MaxValueSize constraint +func TestValueSizeLimit(t *testing.T) { + cfg := New() + cfg.Register("test", "") + + // Create a value larger than MaxValueSize + largeValue := make([]byte, MaxValueSize+1) + for i := range largeValue { + largeValue[i] = 'x' + } + + err := cfg.Set("test", string(largeValue)) + assert.Error(t, err) + assert.Equal(t, ErrValueSize, err) +} + +// TestGetRegisteredPaths tests path listing functionality +func TestGetRegisteredPaths(t *testing.T) { + cfg := New() + + paths := []string{ + "server.host", + "server.port", + "server.tls.enabled", + "database.host", + "database.port", + "cache.ttl", + } + + for _, path := range paths { + cfg.Register(path, "") + } + + t.Run("GetAllPaths", func(t *testing.T) { + all := cfg.GetRegisteredPaths("") + assert.Len(t, all, len(paths)) + for _, path := range paths { + assert.True(t, all[path]) + } + }) + + t.Run("GetPathsWithPrefix", func(t *testing.T) { + serverPaths := cfg.GetRegisteredPaths("server.") + assert.Len(t, serverPaths, 3) + assert.True(t, serverPaths["server.host"]) + assert.True(t, serverPaths["server.port"]) + assert.True(t, serverPaths["server.tls.enabled"]) + }) + + t.Run("GetPathsWithDefaults", func(t *testing.T) { + defaults := cfg.GetRegisteredPathsWithDefaults("database.") + assert.Len(t, defaults, 2) + assert.Contains(t, defaults, "database.host") + assert.Contains(t, defaults, "database.port") + }) +} \ No newline at end of file diff --git a/convenience.go b/convenience.go index 20fdb2b..1572809 100644 --- a/convenience.go +++ b/convenience.go @@ -1,13 +1,14 @@ -// File: lixenwraith/config/convenience.go +// FILE: lixenwraith/config/convenience.go package config import ( "flag" "fmt" - "github.com/BurntSushi/toml" "os" "reflect" "strings" + + "github.com/BurntSushi/toml" ) // Quick creates a fully configured Config instance with a single call diff --git a/convenience_test.go b/convenience_test.go new file mode 100644 index 0000000..803331b --- /dev/null +++ b/convenience_test.go @@ -0,0 +1,287 @@ +// FILE: lixenwraith/config/convenience_test.go +package config + +import ( + "flag" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestQuickFunctions tests the convenience Quick* functions +func TestQuickFunctions(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "quick.toml") + os.WriteFile(configFile, []byte(` +host = "quickhost" +port = 7777 +`), 0644) + + type QuickConfig struct { + Host string `toml:"host"` + Port int `toml:"port"` + SSL bool `toml:"ssl"` + } + + defaults := &QuickConfig{ + Host: "localhost", + Port: 8080, + SSL: false, + } + + t.Run("Quick", func(t *testing.T) { + // Mock os.Args + oldArgs := os.Args + os.Args = []string{"cmd", "--port=9999"} + defer func() { os.Args = oldArgs }() + + cfg, err := Quick(defaults, "QUICK_", configFile) + require.NoError(t, err) + + // CLI should override + port, _ := cfg.Get("port") + assert.Equal(t, "9999", port) + + // File value + host, _ := cfg.Get("host") + assert.Equal(t, "quickhost", host) + }) + + t.Run("QuickCustom", func(t *testing.T) { + opts := LoadOptions{ + Sources: []Source{SourceFile, SourceDefault}, // Only file and defaults + EnvPrefix: "CUSTOM_", + } + + cfg, err := QuickCustom(defaults, opts, configFile) + require.NoError(t, err) + + // Should use file value + port, _ := cfg.Get("port") + assert.Equal(t, int64(7777), port) + }) + + t.Run("MustQuickPanic", func(t *testing.T) { + // Valid case - should not panic + assert.NotPanics(t, func() { + cfg := MustQuick(defaults, "TEST_", configFile) + assert.NotNil(t, cfg) + }) + + // Invalid struct - should panic + assert.Panics(t, func() { + MustQuick("not-a-struct", "TEST_", configFile) + }) + }) + + t.Run("QuickTyped", func(t *testing.T) { + target := &QuickConfig{ + Host: "typedhost", + Port: 6666, + SSL: true, + } + + cfg, err := QuickTyped(target, "TYPED_", configFile) + require.NoError(t, err) + + // Should populate from file + updated, err := cfg.AsStruct() + require.NoError(t, err) + + typedCfg := updated.(*QuickConfig) + assert.Equal(t, "quickhost", typedCfg.Host) + assert.Equal(t, 7777, typedCfg.Port) + }) +} + +// TestFlagGeneration tests flag generation and binding +func TestFlagGeneration(t *testing.T) { + cfg := New() + cfg.Register("server.host", "localhost") + cfg.Register("server.port", 8080) + cfg.Register("debug.enabled", false) + cfg.Register("timeout", 30.5) + cfg.Register("name", "app") + cfg.Register("complex", map[string]any{"key": "value"}) + + t.Run("GenerateFlags", func(t *testing.T) { + fs := cfg.GenerateFlags() + require.NotNil(t, fs) + + // Verify flags exist + hostFlag := fs.Lookup("server.host") + require.NotNil(t, hostFlag) + assert.Equal(t, "localhost", hostFlag.DefValue) + + portFlag := fs.Lookup("server.port") + require.NotNil(t, portFlag) + assert.Equal(t, "8080", portFlag.DefValue) + + debugFlag := fs.Lookup("debug.enabled") + require.NotNil(t, debugFlag) + assert.Equal(t, "false", debugFlag.DefValue) + + timeoutFlag := fs.Lookup("timeout") + require.NotNil(t, timeoutFlag) + assert.Equal(t, "30.5", timeoutFlag.DefValue) + }) + + t.Run("BindFlags", func(t *testing.T) { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.String("server.host", "default", "") + fs.Int("server.port", 8080, "") + fs.Bool("debug.enabled", false, "") + + // Parse with test values + err := fs.Parse([]string{"-server.host=flaghost", "-server.port=5555", "-debug.enabled"}) + require.NoError(t, err) + + // Bind to config + err = cfg.BindFlags(fs) + require.NoError(t, err) + + // Verify values were set + host, _ := cfg.Get("server.host") + assert.Equal(t, "flaghost", host) + + port, _ := cfg.Get("server.port") + assert.Equal(t, "5555", port) + + debug, _ := cfg.Get("debug.enabled") + assert.Equal(t, "true", debug) + }) + + t.Run("BindFlagsError", func(t *testing.T) { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.String("unregistered.path", "value", "") + fs.Parse([]string{"-unregistered.path=test"}) + + err := cfg.BindFlags(fs) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to bind 1 flags") + }) +} + +// TestValidation tests configuration validation +func TestValidation(t *testing.T) { + cfg := New() + cfg.Register("required.host", "") + cfg.Register("required.port", 0) + cfg.Register("optional.timeout", 30) + + t.Run("ValidationFails", func(t *testing.T) { + err := cfg.Validate("required.host", "required.port") + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing required configuration") + assert.Contains(t, err.Error(), "required.host") + assert.Contains(t, err.Error(), "required.port") + }) + + t.Run("ValidationPasses", func(t *testing.T) { + cfg.Set("required.host", "localhost") + cfg.Set("required.port", 8080) + + err := cfg.Validate("required.host", "required.port") + assert.NoError(t, err) + }) + + t.Run("ValidationUnregisteredPath", func(t *testing.T) { + err := cfg.Validate("nonexistent.path") + assert.Error(t, err) + assert.Contains(t, err.Error(), "nonexistent.path (not registered)") + }) + + t.Run("ValidationWithSourceValue", func(t *testing.T) { + cfg2 := New() + cfg2.Register("test", "default") + + // Value equals default but from different source + cfg2.SetSource("test", SourceEnv, "default") + + err := cfg2.Validate("test") + assert.NoError(t, err) // Should pass because env provided value + }) +} + +// TestDebugAndDump tests debug output functions +func TestDebugAndDump(t *testing.T) { + cfg := New() + cfg.Register("server.host", "localhost") + cfg.Register("server.port", 8080) + + cfg.SetSource("server.host", SourceFile, "filehost") + cfg.SetSource("server.host", SourceEnv, "envhost") + cfg.SetSource("server.port", SourceCLI, "9999") + + t.Run("Debug", func(t *testing.T) { + debug := cfg.Debug() + + assert.Contains(t, debug, "Configuration Debug Info") + assert.Contains(t, debug, "Precedence:") + assert.Contains(t, debug, "server.host:") + assert.Contains(t, debug, "Current: envhost") + assert.Contains(t, debug, "Default: localhost") + assert.Contains(t, debug, "file: filehost") + assert.Contains(t, debug, "env: envhost") + }) + + t.Run("Dump", func(t *testing.T) { + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + err := cfg.Dump() + assert.NoError(t, err) + + w.Close() + os.Stdout = oldStdout + + // Read output + output := make([]byte, 1024) + n, _ := r.Read(output) + outputStr := string(output[:n]) + + assert.Contains(t, outputStr, "[server]") + assert.Contains(t, outputStr, "host = ") + assert.Contains(t, outputStr, "port = ") + }) +} + +// TestClone tests configuration cloning +func TestClone(t *testing.T) { + cfg := New() + cfg.Register("original.value", "default") + cfg.Register("shared.value", "shared") + + cfg.SetSource("original.value", SourceFile, "filevalue") + cfg.SetSource("shared.value", SourceEnv, "envvalue") + + clone := cfg.Clone() + require.NotNil(t, clone) + + // Verify values are copied + val, exists := clone.Get("original.value") + assert.True(t, exists) + assert.Equal(t, "filevalue", val) + + val, exists = clone.Get("shared.value") + assert.True(t, exists) + assert.Equal(t, "envvalue", val) + + // Modify clone should not affect original + clone.Set("original.value", "clonevalue") + + originalVal, _ := cfg.Get("original.value") + cloneVal, _ := clone.Get("original.value") + + assert.Equal(t, "filevalue", originalVal) + assert.Equal(t, "clonevalue", cloneVal) + + // Verify source data is copied + sources := clone.GetSources("shared.value") + assert.Equal(t, "envvalue", sources[SourceEnv]) +} \ No newline at end of file diff --git a/decode.go b/decode.go index d61bd91..af4d6f6 100644 --- a/decode.go +++ b/decode.go @@ -57,7 +57,7 @@ func (c *Config) unmarshal(basePath string, source Source, target any) error { // Create decoder with comprehensive hooks decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ Result: target, - TagName: "toml", + TagName: c.tagName, WeaklyTypedInput: true, DecodeHook: c.getDecodeHook(), ZeroFields: true, @@ -77,16 +77,16 @@ func (c *Config) unmarshal(basePath string, source Source, target any) error { // getDecodeHook returns the composite decode hook for all type conversions func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc { return mapstructure.ComposeDecodeHookFunc( - // Standard hooks - mapstructure.StringToTimeDurationHookFunc(), - mapstructure.StringToTimeHookFunc(time.RFC3339), - mapstructure.StringToSliceHookFunc(","), - // Network types stringToNetIPHookFunc(), stringToNetIPNetHookFunc(), stringToURLHookFunc(), + // Standard hooks + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToTimeHookFunc(time.RFC3339), + mapstructure.StringToSliceHookFunc(","), + // Custom application hooks c.customDecodeHook(), ) @@ -94,7 +94,7 @@ func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc { // stringToNetIPHookFunc handles net.IP conversion func stringToNetIPHookFunc() mapstructure.DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { + return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String { return data, nil } @@ -120,27 +120,28 @@ func stringToNetIPHookFunc() mapstructure.DecodeHookFunc { // stringToNetIPNetHookFunc handles net.IPNet conversion func stringToNetIPNetHookFunc() mapstructure.DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { + return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String { return data, nil } - - if t != reflect.TypeOf(net.IPNet{}) && t != reflect.TypeOf(&net.IPNet{}) { + isPtr := t.Kind() == reflect.Ptr + targetType := t + if isPtr { + targetType = t.Elem() + } + if targetType != reflect.TypeOf(net.IPNet{}) { return data, nil } str := data.(string) - // SECURITY: Validate CIDR format if len(str) > 49 { // Max IPv6 CIDR length return nil, fmt.Errorf("invalid CIDR length: %d", len(str)) } - _, ipnet, err := net.ParseCIDR(str) if err != nil { return nil, fmt.Errorf("invalid CIDR: %w", err) } - - if t == reflect.TypeOf(&net.IPNet{}) { + if isPtr { return ipnet, nil } return *ipnet, nil @@ -149,27 +150,28 @@ func stringToNetIPNetHookFunc() mapstructure.DecodeHookFunc { // stringToURLHookFunc handles url.URL conversion func stringToURLHookFunc() mapstructure.DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { + return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String { return data, nil } - - if t != reflect.TypeOf(url.URL{}) && t != reflect.TypeOf(&url.URL{}) { + isPtr := t.Kind() == reflect.Ptr + targetType := t + if isPtr { + targetType = t.Elem() + } + if targetType != reflect.TypeOf(url.URL{}) { return data, nil } str := data.(string) - // SECURITY: Validate URL length to prevent DoS if len(str) > 2048 { return nil, fmt.Errorf("URL too long: %d bytes", len(str)) } - u, err := url.Parse(str) if err != nil { return nil, fmt.Errorf("invalid URL: %w", err) } - - if t == reflect.TypeOf(&url.URL{}) { + if isPtr { return u, nil } return *u, nil @@ -178,7 +180,7 @@ func stringToURLHookFunc() mapstructure.DecodeHookFunc { // customDecodeHook allows for application-specific type conversions func (c *Config) customDecodeHook() mapstructure.DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { + return func(f reflect.Type, t reflect.Type, data any) (any, error) { // SECURITY: Add custom validation for application types here // Example: Rate limit parsing, permission validation, etc. @@ -215,4 +217,4 @@ func navigateToPath(nested map[string]any, path string) any { } return current -} \ No newline at end of file +} diff --git a/decode_test.go b/decode_test.go new file mode 100644 index 0000000..e91466c --- /dev/null +++ b/decode_test.go @@ -0,0 +1,328 @@ +// FILE: lixenwraith/config/decode_test.go +package config + +import ( + "net" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestScanWithComplexTypes tests scanning with various complex types +func TestScanWithComplexTypes(t *testing.T) { + type NetworkConfig struct { + IP net.IP `toml:"ip"` + IPNet *net.IPNet `toml:"subnet"` + URL *url.URL `toml:"endpoint"` + Timeout time.Duration `toml:"timeout"` + Retry struct { + Count int `toml:"count"` + Interval time.Duration `toml:"interval"` + } `toml:"retry"` + } + + type AppConfig struct { + Network NetworkConfig `toml:"network"` + Tags []string `toml:"tags"` + Ports []int `toml:"ports"` + Labels map[string]string `toml:"labels"` + } + + cfg := New() + + // Register with defaults + defaults := &AppConfig{ + Network: NetworkConfig{ + IP: net.ParseIP("127.0.0.1"), + Timeout: 30 * time.Second, + }, + Tags: []string{"default"}, + Ports: []int{8080}, + Labels: map[string]string{ + "env": "dev", + }, + } + + err := cfg.RegisterStruct("", defaults) + require.NoError(t, err) + + // Set values from different sources + cfg.SetSource("network.ip", SourceEnv, "192.168.1.100") + cfg.SetSource("network.subnet", SourceEnv, "192.168.1.0/24") + cfg.SetSource("network.endpoint", SourceEnv, "https://api.example.com:8443/v1") + cfg.SetSource("network.timeout", SourceFile, "2m30s") + cfg.SetSource("network.retry.count", SourceFile, int64(5)) + cfg.SetSource("network.retry.interval", SourceFile, "10s") + cfg.SetSource("tags", SourceCLI, "prod,staging,test") + cfg.SetSource("ports", SourceFile, []any{int64(80), int64(443), int64(8080)}) + cfg.SetSource("labels", SourceFile, map[string]any{ + "env": "production", + "version": "1.2.3", + }) + + // Scan into struct + var result AppConfig + err = cfg.Scan("", &result) + require.NoError(t, err) + + // Verify conversions + assert.Equal(t, "192.168.1.100", result.Network.IP.String()) + assert.Equal(t, "192.168.1.0/24", result.Network.IPNet.String()) + assert.Equal(t, "https://api.example.com:8443/v1", result.Network.URL.String()) + assert.Equal(t, 150*time.Second, result.Network.Timeout) + assert.Equal(t, 5, result.Network.Retry.Count) + assert.Equal(t, 10*time.Second, result.Network.Retry.Interval) + assert.Equal(t, []string{"prod", "staging", "test"}, result.Tags) + assert.Equal(t, []int{80, 443, 8080}, result.Ports) + assert.Equal(t, "production", result.Labels["env"]) + assert.Equal(t, "1.2.3", result.Labels["version"]) +} + +// TestScanWithBasePath tests scanning from nested paths +func TestScanWithBasePath(t *testing.T) { + type ServerConfig struct { + Host string `toml:"host"` + Port int `toml:"port"` + Enabled bool `toml:"enabled"` + } + + cfg := New() + cfg.Register("app.server.host", "localhost") + cfg.Register("app.server.port", 8080) + cfg.Register("app.server.enabled", true) + cfg.Register("app.database.host", "dbhost") + + cfg.Set("app.server.host", "appserver") + cfg.Set("app.server.port", 9000) + + // Scan only the server section + var server ServerConfig + err := cfg.Scan("app.server", &server) + require.NoError(t, err) + + assert.Equal(t, "appserver", server.Host) + assert.Equal(t, 9000, server.Port) + assert.Equal(t, true, server.Enabled) + + // Test non-existent base path + var empty ServerConfig + err = cfg.Scan("app.nonexistent", &empty) + assert.NoError(t, err) // Should not error, just empty + assert.Equal(t, "", empty.Host) + assert.Equal(t, 0, empty.Port) +} + +// TestScanFromSource tests scanning from specific sources +func TestScanFromSource(t *testing.T) { + type Config struct { + Value string `toml:"value"` + } + + cfg := New() + cfg.Register("value", "default") + + cfg.SetSource("value", SourceFile, "fromfile") + cfg.SetSource("value", SourceEnv, "fromenv") + cfg.SetSource("value", SourceCLI, "fromcli") + + tests := []struct { + source Source + expected string + }{ + {SourceFile, "fromfile"}, + {SourceEnv, "fromenv"}, + {SourceCLI, "fromcli"}, + {SourceDefault, ""}, // No value in default source + } + + for _, tt := range tests { + t.Run(string(tt.source), func(t *testing.T) { + var result Config + err := cfg.ScanSource("", tt.source, &result) + require.NoError(t, err) + assert.Equal(t, tt.expected, result.Value) + }) + } +} + +// TestInvalidScanTargets tests error cases for scanning +func TestInvalidScanTargets(t *testing.T) { + cfg := New() + cfg.Register("test", "value") + + tests := []struct { + name string + target any + expectErr string + }{ + {"NilPointer", nil, "must be non-nil pointer"}, + {"NonPointer", "not-a-pointer", "must be non-nil pointer"}, + {"NilStructPointer", (*struct{})(nil), "must be non-nil pointer"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := cfg.Scan("", tt.target) + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectErr) + }) + } +} + +// TestCustomTypeConversion tests edge cases in type conversion +func TestCustomTypeConversion(t *testing.T) { + cfg := New() + + t.Run("InvalidIPAddress", func(t *testing.T) { + type Config struct { + IP net.IP `toml:"ip"` + } + + cfg.Register("ip", net.IP{}) + cfg.Set("ip", "not-an-ip") + + var result Config + err := cfg.Scan("", &result) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid IP address") + }) + + t.Run("InvalidCIDR", func(t *testing.T) { + type Config struct { + Network *net.IPNet `toml:"network"` + } + + cfg.Register("network", (*net.IPNet)(nil)) + cfg.Set("network", "invalid-cidr") + + var result Config + err := cfg.Scan("", &result) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid CIDR") + }) + + t.Run("InvalidURL", func(t *testing.T) { + type Config struct { + Endpoint *url.URL `toml:"endpoint"` + } + + cfg.Register("endpoint", (*url.URL)(nil)) + cfg.Set("endpoint", "://invalid-url") + + var result Config + err := cfg.Scan("", &result) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid URL") + }) + + t.Run("LongIPString", func(t *testing.T) { + type Config struct { + IP net.IP `toml:"ip"` + } + + cfg.Register("ip", net.IP{}) + // String longer than max IPv6 length + longIP := make([]byte, 50) + for i := range longIP { + longIP[i] = 'x' + } + cfg.Set("ip", string(longIP)) + + var result Config + err := cfg.Scan("", &result) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid IP length") + }) + + t.Run("LongURL", func(t *testing.T) { + type Config struct { + URL *url.URL `toml:"url"` + } + + cfg.Register("url", (*url.URL)(nil)) + // URL longer than 2048 bytes + longURL := "https://example.com/" + for i := 0; i < 2048; i++ { + longURL += "x" + } + cfg.Set("url", longURL) + + var result Config + err := cfg.Scan("", &result) + assert.Error(t, err) + assert.Contains(t, err.Error(), "URL too long") + }) +} + +// TestZeroFields tests that ZeroFields option works correctly +func TestZeroFields(t *testing.T) { + type Config struct { + KeepValue string `toml:"keep"` + ResetValue string `toml:"reset"` + NestedValue struct { + Field string `toml:"field"` + } `toml:"nested"` + } + + cfg := New() + + // Register only some fields + cfg.Register("keep", "keepdefault") + cfg.Register("reset", "resetdefault") + // Don't register nested.field + + cfg.Set("keep", "newvalue") + // Don't set reset, so it uses default + + // Start with non-zero struct + result := Config{ + KeepValue: "initial", + ResetValue: "initial", + NestedValue: struct { + Field string `toml:"field"` + }{Field: "initial"}, + } + + err := cfg.Scan("", &result) + require.NoError(t, err) + + // ZeroFields should reset all fields before decoding + assert.Equal(t, "newvalue", result.KeepValue) + assert.Equal(t, "resetdefault", result.ResetValue) + assert.Equal(t, "initial", result.NestedValue.Field) // Unregistered, so Scan should not touch it +} + +// TestWeaklyTypedInput tests weak type conversion +func TestWeaklyTypedInput(t *testing.T) { + type Config struct { + IntFromString int `toml:"int_from_string"` + FloatFromString float64 `toml:"float_from_string"` + BoolFromString bool `toml:"bool_from_string"` + StringFromInt string `toml:"string_from_int"` + StringFromBool string `toml:"string_from_bool"` + } + + cfg := New() + defaults := &Config{} + cfg.RegisterStruct("", defaults) + + // Set string values that should convert + cfg.Set("int_from_string", "42") + cfg.Set("float_from_string", "3.14159") + cfg.Set("bool_from_string", "true") + cfg.Set("string_from_int", 12345) + cfg.Set("string_from_bool", true) + + var result Config + err := cfg.Scan("", &result) + require.NoError(t, err) + + assert.Equal(t, 42, result.IntFromString) + assert.Equal(t, 3.14159, result.FloatFromString) + assert.Equal(t, true, result.BoolFromString) + assert.Equal(t, "12345", result.StringFromInt) + assert.Equal(t, "1", result.StringFromBool) // mapstructure converts bool(true) to "1" in weak conversion +} \ No newline at end of file diff --git a/doc/access.md b/doc/access.md index 9699edf..4da25c8 100644 --- a/doc/access.md +++ b/doc/access.md @@ -138,7 +138,7 @@ cfg.SetSource("feature.enabled", config.SourceFile, true) ```go // Multiple updates -updates := map[string]interface{}{ +updates := map[string]any{ "server.port": int64(9090), "server.host": "0.0.0.0", "database.maxconns": int64(50), @@ -310,7 +310,7 @@ func (f *ConfigFacade) DatabaseURL() string { ```go // Helper for optional configuration -func getOrDefault(cfg *config.Config, path string, defaultVal interface{}) interface{} { +func getOrDefault(cfg *config.Config, path string, defaultVal any) any { if val, exists := cfg.Get(path); exists { return val } diff --git a/go.mod b/go.mod index 6db8229..ffaa49d 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,10 @@ -module github.com/lixenwraith/config +module config go 1.24.5 require ( github.com/BurntSushi/toml v1.5.0 + github.com/lixenwraith/config v0.0.0-20250719015120-e02ee494d440 github.com/mitchellh/mapstructure v1.5.0 github.com/stretchr/testify v1.10.0 ) @@ -13,3 +14,5 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/mitchellh/mapstructure => github.com/go-viper/mapstructure v1.6.0 diff --git a/go.sum b/go.sum index 117ebb2..e2ab293 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,10 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= -github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGLxofobw= +github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw= +github.com/lixenwraith/config v0.0.0-20250719015120-e02ee494d440 h1:O6nHnpeDfIYQ1WxCtA2gkm8upQ4RW21DUMlQz5DKJCU= +github.com/lixenwraith/config v0.0.0-20250719015120-e02ee494d440/go.mod h1:y7kgDrWIFROWJJ6ASM/SPTRRAj27FjRGWh2SDLcdQ68= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/helper.go b/helper.go index 5529324..f1c2de8 100644 --- a/helper.go +++ b/helper.go @@ -1,4 +1,4 @@ -// File: lixenwraith/config/helper.go +// FILE: lixenwraith/config/helper.go package config import "strings" diff --git a/loader.go b/loader.go index 46ce000..a2f56dc 100644 --- a/loader.go +++ b/loader.go @@ -1,4 +1,4 @@ -// File: lixenwraith/config/loader.go +// FILE: lixenwraith/config/loader.go package config import ( @@ -481,7 +481,6 @@ func parseArgs(args []string) (map[string]any, error) { continue } - // Remove the leading "--" argContent := strings.TrimPrefix(arg, "--") if argContent == "" { // Skip "--" argument if used as a separator @@ -501,20 +500,22 @@ func parseArgs(args []string) (map[string]any, error) { } else { // Handle "--key value" or "--booleanflag" keyPath = argContent - // Check if it's potentially a boolean flag - isBoolFlag := i+1 >= len(args) || strings.HasPrefix(args[i+1], "--") - - if isBoolFlag { - // Assume boolean flag is true if no value follows + // Check if it's a boolean flag (next arg is another flag or end of args) + if i+1 >= len(args) || strings.HasPrefix(args[i+1], "--") { valueStr = "true" i++ // Consume only the flag argument } else { - // Potential key-value pair with space separation + // It's a key-value pair with a space valueStr = args[i+1] - i += 2 // Consume flag and value arguments + i += 2 // Consume both flag and value arguments } } + if keyPath == "" { + // Skip invalid flags like --=value + continue + } + // Validate keyPath segments segments := strings.Split(keyPath, ".") for _, segment := range segments { @@ -523,9 +524,8 @@ func parseArgs(args []string) (map[string]any, error) { } } - // Parse the value - value := parseValue(valueStr) - setNestedValue(result, keyPath, value) + // Always store as a string. Let Scan handle final type conversion. + setNestedValue(result, keyPath, valueStr) } return result, nil diff --git a/loader_test.go b/loader_test.go new file mode 100644 index 0000000..44eb38c --- /dev/null +++ b/loader_test.go @@ -0,0 +1,434 @@ +// FILE: lixenwraith/config/loader_test.go +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestFileLoading tests TOML file loading +func TestFileLoading(t *testing.T) { + tmpDir := t.TempDir() + + t.Run("ValidTOMLFile", func(t *testing.T) { + configFile := filepath.Join(tmpDir, "valid.toml") + content := ` +# Server configuration +[server] +host = "example.com" +port = 9000 +enabled = true + +[server.tls] +cert = "/path/to/cert.pem" +key = "/path/to/key.pem" + +[database] +connections = [1, 2, 3] +tags = ["primary", "replica"] +` + os.WriteFile(configFile, []byte(content), 0644) + + cfg := New() + // Register all paths + cfg.Register("server.host", "localhost") + cfg.Register("server.port", 8080) + cfg.Register("server.enabled", false) + cfg.Register("server.tls.cert", "") + cfg.Register("server.tls.key", "") + cfg.Register("database.connections", []int{}) + cfg.Register("database.tags", []string{}) + + err := cfg.LoadFile(configFile) + require.NoError(t, err) + + // Verify loaded values + host, _ := cfg.Get("server.host") + assert.Equal(t, "example.com", host) + + port, _ := cfg.Get("server.port") + assert.Equal(t, int64(9000), port) + + enabled, _ := cfg.Get("server.enabled") + assert.Equal(t, true, enabled) + + cert, _ := cfg.Get("server.tls.cert") + assert.Equal(t, "/path/to/cert.pem", cert) + + // Arrays are loaded as []any + connections, _ := cfg.Get("database.connections") + assert.Equal(t, []any{int64(1), int64(2), int64(3)}, connections) + }) + + t.Run("InvalidTOMLFile", func(t *testing.T) { + configFile := filepath.Join(tmpDir, "invalid.toml") + os.WriteFile(configFile, []byte(`invalid = toml content`), 0644) + + cfg := New() + err := cfg.LoadFile(configFile) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse TOML") + }) + + t.Run("NonExistentFile", func(t *testing.T) { + cfg := New() + err := cfg.LoadFile("/non/existent/file.toml") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrConfigNotFound) + }) + + t.Run("UnregisteredPathsIgnored", func(t *testing.T) { + configFile := filepath.Join(tmpDir, "extra.toml") + os.WriteFile(configFile, []byte(` +registered = "value" +unregistered = "ignored" +`), 0644) + + cfg := New() + cfg.Register("registered", "") + + err := cfg.LoadFile(configFile) + require.NoError(t, err) + + val, exists := cfg.Get("registered") + assert.True(t, exists) + assert.Equal(t, "value", val) + + _, exists = cfg.Get("unregistered") + assert.False(t, exists) + }) +} + +// TestEnvironmentLoading tests environment variable loading +func TestEnvironmentLoading(t *testing.T) { + // Save and restore environment + originalEnv := os.Environ() + defer func() { + os.Clearenv() + for _, e := range originalEnv { + parts := splitEnvVar(e) + if len(parts) == 2 { + os.Setenv(parts[0], parts[1]) + } + } + }() + + t.Run("DefaultEnvTransform", func(t *testing.T) { + cfg := New() + cfg.Register("server.host", "localhost") + cfg.Register("server.port", 8080) + cfg.Register("enable_debug", false) + + os.Setenv("APP_SERVER_HOST", "envhost") + os.Setenv("APP_SERVER_PORT", "9090") + os.Setenv("APP_ENABLE_DEBUG", "true") + + err := cfg.LoadEnv("APP_") + require.NoError(t, err) + + host, _ := cfg.Get("server.host") + assert.Equal(t, "envhost", host) + + port, _ := cfg.Get("server.port") + assert.Equal(t, "9090", port) // String from env + + debug, _ := cfg.Get("enable_debug") + assert.Equal(t, "true", debug) // String from env + }) + + t.Run("CustomEnvTransform", func(t *testing.T) { + cfg := New() + cfg.Register("db.host", "localhost") + + os.Setenv("DATABASE_HOSTNAME", "customhost") + + opts := LoadOptions{ + Sources: []Source{SourceEnv, SourceDefault}, + EnvTransform: func(path string) string { + if path == "db.host" { + return "DATABASE_HOSTNAME" + } + return path + }, + } + + err := cfg.LoadWithOptions("", nil, opts) + require.NoError(t, err) + + host, _ := cfg.Get("db.host") + assert.Equal(t, "customhost", host) + }) + + t.Run("EnvWhitelist", func(t *testing.T) { + cfg := New() + cfg.Register("allowed.path", "default1") + cfg.Register("blocked.path", "default2") + + os.Setenv("ALLOWED_PATH", "env1") + os.Setenv("BLOCKED_PATH", "env2") + + opts := LoadOptions{ + Sources: []Source{SourceEnv, SourceDefault}, + EnvWhitelist: map[string]bool{"allowed.path": true}, + } + + err := cfg.LoadWithOptions("", nil, opts) + require.NoError(t, err) + + allowed, _ := cfg.Get("allowed.path") + assert.Equal(t, "env1", allowed) + + blocked, _ := cfg.Get("blocked.path") + assert.Equal(t, "default2", blocked) // Should not load from env + }) + + t.Run("DiscoverEnv", func(t *testing.T) { + cfg := New() + cfg.Register("test.one", "") + cfg.Register("test.two", "") + cfg.Register("other.value", "") + + os.Setenv("PREFIX_TEST_ONE", "value1") + os.Setenv("PREFIX_TEST_TWO", "value2") + os.Setenv("PREFIX_OTHER_VALUE", "value3") + os.Setenv("UNRELATED_VAR", "ignored") + + discovered := cfg.DiscoverEnv("PREFIX_") + assert.Len(t, discovered, 3) + assert.Equal(t, "PREFIX_TEST_ONE", discovered["test.one"]) + assert.Equal(t, "PREFIX_TEST_TWO", discovered["test.two"]) + assert.Equal(t, "PREFIX_OTHER_VALUE", discovered["other.value"]) + }) +} + +// TestCLIParsing tests command-line argument parsing +func TestCLIParsing(t *testing.T) { + tests := []struct { + name string + args []string + expected map[string]any + }{ + { + name: "KeyValueWithEquals", + args: []string{"--server.host=example.com", "--server.port=9000"}, + expected: map[string]any{ + "server.host": "example.com", + "server.port": "9000", + }, + }, + { + name: "KeyValueWithSpace", + args: []string{"--server.host", "example.com", "--server.port", "9000"}, + expected: map[string]any{ + "server.host": "example.com", + "server.port": "9000", + }, + }, + { + name: "BooleanFlags", + args: []string{"--enable.debug", "--disable.cache", "false"}, + expected: map[string]any{ + "enable.debug": "true", + "disable.cache": "false", + }, + }, + { + name: "MixedFormats", + args: []string{ + "--server.host=localhost", + "--server.port", "8080", + "--enable.tls", + "--database.pool.size=10", + }, + expected: map[string]any{ + "server.host": "localhost", + "server.port": "8080", + "enable.tls": "true", + "database.pool.size": "10", + }, + }, + { + name: "EmptyAndInvalidArgs", + args: []string{"", "--", "---", "--=value"}, + expected: map[string]any{ + "": "value", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := New() + + // Register expected paths + for path := range tt.expected { + if path != "" { // Skip empty path + cfg.Register(path, "") + } + } + + err := cfg.LoadCLI(tt.args) + require.NoError(t, err) + + // Verify values + for path, expected := range tt.expected { + if path != "" { + val, exists := cfg.Get(path) + assert.True(t, exists, "Path %s should exist", path) + assert.Equal(t, expected, val) + } + } + }) + } + + t.Run("InvalidKeySegment", func(t *testing.T) { + result, err := parseArgs([]string{"--invalid!key=value"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid command-line key segment") + assert.Nil(t, result) + }) +} + +// TestLoadWithOptions tests complete loading with multiple sources +func TestLoadWithOptions(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.toml") + os.WriteFile(configFile, []byte(` +[server] +host = "filehost" +port = 8080 +`), 0644) + + os.Setenv("TEST_SERVER_HOST", "envhost") + os.Setenv("TEST_SERVER_PORT", "9090") + defer func() { + os.Unsetenv("TEST_SERVER_HOST") + os.Unsetenv("TEST_SERVER_PORT") + }() + + cfg := New() + cfg.Register("server.host", "defaulthost") + cfg.Register("server.port", 3000) + + args := []string{"--server.port=7070"} + + opts := LoadOptions{ + Sources: []Source{SourceCLI, SourceEnv, SourceFile, SourceDefault}, + EnvPrefix: "TEST_", + } + + err := cfg.LoadWithOptions(configFile, args, opts) + require.NoError(t, err) + + // CLI should win + port, _ := cfg.Get("server.port") + assert.Equal(t, "7070", port) + + // ENV should win over file + host, _ := cfg.Get("server.host") + assert.Equal(t, "envhost", host) + + // Test source inspection + sources := cfg.GetSources("server.port") + assert.Equal(t, "7070", sources[SourceCLI]) + assert.Equal(t, "9090", sources[SourceEnv]) + assert.Equal(t, int64(8080), sources[SourceFile]) +} + +// TestAtomicSave tests atomic file saving +func TestAtomicSave(t *testing.T) { + tmpDir := t.TempDir() + + cfg := New() + cfg.Register("server.host", "localhost") + cfg.Register("server.port", 8080) + cfg.Register("database.url", "postgres://localhost/db") + + // Set some values + cfg.Set("server.host", "savehost") + cfg.Set("server.port", 9999) + + t.Run("SaveCurrentState", func(t *testing.T) { + savePath := filepath.Join(tmpDir, "saved.toml") + err := cfg.Save(savePath) + require.NoError(t, err) + + // Verify file exists and is readable + content, err := os.ReadFile(savePath) + require.NoError(t, err) + assert.Contains(t, string(content), "savehost") + assert.Contains(t, string(content), "9999") + + // Load into new config to verify + cfg2 := New() + cfg2.Register("server.host", "") + cfg2.Register("server.port", 0) + err = cfg2.LoadFile(savePath) + require.NoError(t, err) + + host, _ := cfg2.Get("server.host") + assert.Equal(t, "savehost", host) + }) + + t.Run("SaveSpecificSource", func(t *testing.T) { + cfg.SetSource("server.host", SourceEnv, "envhost") + cfg.SetSource("server.port", SourceEnv, "7777") + cfg.SetSource("server.port", SourceFile, "6666") + + savePath := filepath.Join(tmpDir, "env-only.toml") + err := cfg.SaveSource(savePath, SourceEnv) + require.NoError(t, err) + + content, err := os.ReadFile(savePath) + require.NoError(t, err) + assert.Contains(t, string(content), "envhost") + assert.Contains(t, string(content), "7777") + assert.NotContains(t, string(content), "6666") + }) + + t.Run("SaveToNonExistentDirectory", func(t *testing.T) { + savePath := filepath.Join(tmpDir, "new", "dir", "config.toml") + err := cfg.Save(savePath) + require.NoError(t, err) + + // Verify file was created + _, err = os.Stat(savePath) + assert.NoError(t, err) + }) +} + +// TestExportEnv tests environment variable export +func TestExportEnv(t *testing.T) { + cfg := New() + cfg.Register("server.host", "defaulthost") + cfg.Register("server.port", 8080) + cfg.Register("feature.enabled", false) + + // Only export non-default values + cfg.Set("server.host", "exporthost") + cfg.Set("feature.enabled", true) + + exports := cfg.ExportEnv("APP_") + + assert.Len(t, exports, 2) + assert.Equal(t, "exporthost", exports["APP_SERVER_HOST"]) + assert.Equal(t, "true", exports["APP_FEATURE_ENABLED"]) + assert.NotContains(t, exports, "APP_SERVER_PORT") // Still default +} + +// splitEnvVar splits environment variable into key and value +func splitEnvVar(env string) []string { + parts := make([]string, 2) + for i := 0; i < len(env); i++ { + if env[i] == '=' { + parts[0] = env[:i] + parts[1] = env[i+1:] + return parts + } + } + return []string{env} +} \ No newline at end of file diff --git a/register.go b/register.go index e10e589..a546c98 100644 --- a/register.go +++ b/register.go @@ -1,4 +1,4 @@ -// File: lixenwraith/config/register.go +// FILE: lixenwraith/config/register.go package config import ( @@ -186,20 +186,30 @@ func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, e isPtrToStruct := fieldValue.Kind() == reflect.Ptr && fieldType.Elem().Kind() == reflect.Struct if isStruct || isPtrToStruct { - // Dereference pointer if necessary - nestedValue := fieldValue - if isPtrToStruct { - if fieldValue.IsNil() { - // Skip nil pointers - continue - } - nestedValue = fieldValue.Elem() + // Check if the field's TYPE is one that should be treated as a single value, + // even though it's a struct. These types have custom decode hooks. + fieldType := fieldValue.Type() + isAtomicStruct := false + switch fieldType.String() { + case "time.Time", "*net.IPNet", "*url.URL", "net.IP": // Match the exact type names + isAtomicStruct = true } - // For nested structs, append a dot and continue recursion - nestedPrefix := currentPath + "." - c.registerFields(nestedValue, nestedPrefix, fieldPath+field.Name+".", errors, tagName) - continue + // Only recurse if it's a "normal" struct, not an atomic one. + if !isAtomicStruct { + nestedValue := fieldValue + if isPtrToStruct { + if fieldValue.IsNil() { + continue // Skip nil pointers in the default struct + } + nestedValue = fieldValue.Elem() + } + + nestedPrefix := currentPath + "." + c.registerFields(nestedValue, nestedPrefix, fieldPath+field.Name+".", errors, tagName) + continue + } + // If it is an atomic struct, we fall through and register it as a single value. } // Register non-struct fields diff --git a/watch.go b/watch.go index f18396b..d1a47ab 100644 --- a/watch.go +++ b/watch.go @@ -1,15 +1,16 @@ -// File: lixenwraith/config/watch.go +// FILE: lixenwraith/config/watch.go package config import ( "context" "fmt" - "github.com/BurntSushi/toml" "os" "reflect" "sync" "sync/atomic" "time" + + "github.com/BurntSushi/toml" ) // WatchOptions configures file watching behavior diff --git a/watch_test.go b/watch_test.go index 6c92196..6a6a1b0 100644 --- a/watch_test.go +++ b/watch_test.go @@ -9,10 +9,14 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// TestAutoUpdate tests automatic configuration reloading func TestAutoUpdate(t *testing.T) { - // Create temporary config file + // Setup tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "test.toml") @@ -24,10 +28,7 @@ host = "localhost" [features] enabled = true ` - - if err := os.WriteFile(configPath, []byte(initialConfig), 0644); err != nil { - t.Fatal("Failed to write initial config:", err) - } + require.NoError(t, os.WriteFile(configPath, []byte(initialConfig), 0644)) // Create config with defaults type TestConfig struct { @@ -49,14 +50,12 @@ enabled = true WithDefaults(defaults). WithFile(configPath). Build() - if err != nil { - t.Fatal("Failed to build config:", err) - } + require.NoError(t, err) // Verify initial values - if port, _ := cfg.Get("server.port"); port.(int64) != 8080 { - t.Errorf("Expected port 8080, got %d", port) - } + port, exists := cfg.Get("server.port") + assert.True(t, exists) + assert.Equal(t, int64(8080), port) // Enable auto-update with fast polling opts := WatchOptions{ @@ -91,26 +90,20 @@ host = "0.0.0.0" [features] enabled = false ` - - if err := os.WriteFile(configPath, []byte(updatedConfig), 0644); err != nil { - t.Fatal("Failed to update config:", err) - } + require.NoError(t, os.WriteFile(configPath, []byte(updatedConfig), 0644)) // Wait for changes to be detected time.Sleep(300 * time.Millisecond) // Verify new values - if port, _ := cfg.Get("server.port"); port.(int64) != 9090 { - t.Errorf("Expected port 9090 after update, got %d", port) - } + port, _ = cfg.Get("server.port") + assert.Equal(t, int64(9090), port) - if host, _ := cfg.Get("server.host"); host.(string) != "0.0.0.0" { - t.Errorf("Expected host 0.0.0.0 after update, got %s", host) - } + host, _ := cfg.Get("server.host") + assert.Equal(t, "0.0.0.0", host) - if enabled, _ := cfg.Get("features.enabled"); enabled.(bool) != false { - t.Errorf("Expected features.enabled to be false after update") - } + enabled, _ := cfg.Get("features.enabled") + assert.Equal(t, false, enabled) // Check that changes were notified mu.Lock() @@ -118,27 +111,22 @@ enabled = false expectedChanges := []string{"server.port", "server.host", "features.enabled"} for _, path := range expectedChanges { - if !changedPaths[path] { - t.Errorf("Expected change notification for %s", path) - } + assert.True(t, changedPaths[path], "Expected change notification for %s", path) } } +// TestWatchFileDeleted tests behavior when config file is deleted func TestWatchFileDeleted(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "test.toml") // Create initial config - if err := os.WriteFile(configPath, []byte(`test = "value"`), 0644); err != nil { - t.Fatal("Failed to write config:", err) - } + require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644)) cfg := New() cfg.Register("test", "default") - if err := cfg.LoadFile(configPath); err != nil { - t.Fatal("Failed to load config:", err) - } + require.NoError(t, cfg.LoadFile(configPath)) // Enable watching opts := WatchOptions{ @@ -151,21 +139,18 @@ func TestWatchFileDeleted(t *testing.T) { changes := cfg.Watch() // Delete file - if err := os.Remove(configPath); err != nil { - t.Fatal("Failed to delete config:", err) - } + require.NoError(t, os.Remove(configPath)) // Wait for deletion detection select { case path := <-changes: - if path != "file_deleted" { - t.Errorf("Expected file_deleted, got %s", path) - } + assert.Equal(t, "file_deleted", path) case <-time.After(500 * time.Millisecond): t.Error("Timeout waiting for deletion notification") } } +// TestWatchPermissionChange tests permission change detection func TestWatchPermissionChange(t *testing.T) { // Skip on Windows where permission model is different if runtime.GOOS == "windows" { @@ -176,16 +161,11 @@ func TestWatchPermissionChange(t *testing.T) { configPath := filepath.Join(tmpDir, "test.toml") // Create config with specific permissions - if err := os.WriteFile(configPath, []byte(`test = "value"`), 0644); err != nil { - t.Fatal("Failed to write config:", err) - } + require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644)) cfg := New() cfg.Register("test", "default") - - if err := cfg.LoadFile(configPath); err != nil { - t.Fatal("Failed to load config:", err) - } + require.NoError(t, cfg.LoadFile(configPath)) // Enable watching with permission verification opts := WatchOptions{ @@ -199,21 +179,18 @@ func TestWatchPermissionChange(t *testing.T) { changes := cfg.Watch() // Change permissions to world-writable (security risk) - if err := os.Chmod(configPath, 0666); err != nil { - t.Fatal("Failed to change permissions:", err) - } + require.NoError(t, os.Chmod(configPath, 0666)) // Wait for permission change detection select { case path := <-changes: - if path != "permissions_changed" { - t.Errorf("Expected permissions_changed, got %s", path) - } + assert.Equal(t, "permissions_changed", path) case <-time.After(500 * time.Millisecond): t.Error("Timeout waiting for permission change notification") } } +// TestMaxWatchers tests watcher limit enforcement func TestMaxWatchers(t *testing.T) { cfg := New() cfg.Register("test", "value") @@ -221,13 +198,8 @@ func TestMaxWatchers(t *testing.T) { // Create config file tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "test.toml") - if err := os.WriteFile(configPath, []byte(`test = "value"`), 0644); err != nil { - t.Fatal("Failed to write config:", err) - } - - if err := cfg.LoadFile(configPath); err != nil { - t.Fatal("Failed to load config:", err) - } + require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644)) + require.NoError(t, cfg.LoadFile(configPath)) // Enable watching with low max watchers opts := WatchOptions{ @@ -244,45 +216,40 @@ func TestMaxWatchers(t *testing.T) { channels = append(channels, ch) // Check if channel is open - select { - case _, ok := <-ch: - if !ok && i < 3 { - t.Errorf("Channel %d should be open", i) - } else if ok && i == 3 { - t.Error("Channel 3 should be closed (max watchers exceeded)") + if i < 3 { + // First 3 should be open + select { + case _, ok := <-ch: + assert.True(t, ok || i < 3, "Channel %d should be open", i) + default: + // Channel is open and empty, expected } - default: - // Channel is open and empty, expected for first 3 - if i == 3 { - // Try to receive with timeout to verify it's closed - select { - case _, ok := <-ch: - if ok { - t.Error("Channel 3 should be closed") - } - case <-time.After(10 * time.Millisecond): - t.Error("Channel 3 should be closed immediately") - } + } else { + // 4th should be closed immediately + select { + case _, ok := <-ch: + assert.False(t, ok, "Channel 3 should be closed (max watchers exceeded)") + case <-time.After(10 * time.Millisecond): + t.Error("Channel 3 should be closed immediately") } } } + + // Verify watcher count + assert.Equal(t, 3, cfg.WatcherCount()) } +// TestDebounce tests that rapid changes are debounced func TestDebounce(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "test.toml") // Create initial config - if err := os.WriteFile(configPath, []byte(`value = 1`), 0644); err != nil { - t.Fatal("Failed to write config:", err) - } + require.NoError(t, os.WriteFile(configPath, []byte(`value = 1`), 0644)) cfg := New() cfg.Register("value", 0) - - if err := cfg.LoadFile(configPath); err != nil { - t.Fatal("Failed to load config:", err) - } + require.NoError(t, cfg.LoadFile(configPath)) // Enable watching with longer debounce opts := WatchOptions{ @@ -293,39 +260,211 @@ func TestDebounce(t *testing.T) { defer cfg.StopAutoUpdate() changes := cfg.Watch() - changeCount := 0 + + var changeCount int + var mu sync.Mutex + done := make(chan bool) go func() { - for range changes { - changeCount++ + for { + select { + case <-changes: + mu.Lock() + changeCount++ + mu.Unlock() + case <-done: + return + } } }() // Make rapid changes for i := 2; i <= 5; i++ { content := fmt.Sprintf(`value = %d`, i) - if err := os.WriteFile(configPath, []byte(content), 0644); err != nil { - t.Fatal("Failed to write config:", err) - } + require.NoError(t, os.WriteFile(configPath, []byte(content), 0644)) time.Sleep(50 * time.Millisecond) // Less than debounce period } // Wait for debounce to complete time.Sleep(300 * time.Millisecond) + done <- true // Should only see one change due to debounce - if changeCount != 1 { - t.Errorf("Expected 1 change due to debounce, got %d", changeCount) - } + mu.Lock() + defer mu.Unlock() + assert.Equal(t, 1, changeCount, "Expected 1 change due to debounce, got %d", changeCount) // Verify final value val, _ := cfg.Get("value") - if val.(int64) != 5 { - t.Errorf("Expected final value 5, got %d", val) - } + assert.Equal(t, int64(5), val) } -// Benchmark file watching overhead +// TestWatchWithoutFile tests watching behavior when no file is configured +func TestWatchWithoutFile(t *testing.T) { + cfg := New() + cfg.Register("test", "value") + + // No file loaded, watch should return closed channel + ch := cfg.Watch() + + select { + case _, ok := <-ch: + assert.False(t, ok, "Channel should be closed when no file to watch") + case <-time.After(10 * time.Millisecond): + t.Error("Channel should be closed immediately") + } + + assert.False(t, cfg.IsWatching()) +} + +// TestConcurrentWatchOperations tests thread safety of watch operations +func TestConcurrentWatchOperations(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "test.toml") + require.NoError(t, os.WriteFile(configPath, []byte(`value = 1`), 0644)) + + cfg := New() + cfg.Register("value", 0) + require.NoError(t, cfg.LoadFile(configPath)) + + opts := WatchOptions{ + PollInterval: 50 * time.Millisecond, + MaxWatchers: 50, + } + cfg.AutoUpdateWithOptions(opts) + defer cfg.StopAutoUpdate() + + var wg sync.WaitGroup + errors := make(chan error, 100) + + // Start multiple watchers concurrently + for i := 0; i < 20; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + ch := cfg.Watch() + if ch == nil { + errors <- fmt.Errorf("watcher %d: got nil channel", id) + return + } + + // Try to receive + select { + case <-ch: + // OK, got a change + case <-time.After(10 * time.Millisecond): + // OK, no changes yet + } + }(i) + } + + // Concurrent config updates + for i := 0; i < 5; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + content := fmt.Sprintf(`value = %d`, id+10) + if err := os.WriteFile(configPath, []byte(content), 0644); err != nil { + errors <- fmt.Errorf("writer %d: %v", id, err) + } + }(i) + } + + // Check IsWatching concurrently + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + isWatching := false + for j := 0; j < 5; j++ { // Poll a few times, double-dip wait for goroutine to start + if cfg.IsWatching() { + isWatching = true + break + } + time.Sleep(10 * time.Millisecond) + } + if !isWatching { + errors <- fmt.Errorf("checker %d: IsWatching returned false", id) + } + }(i) + } + + wg.Wait() + close(errors) + + // Check for errors + var errs []error + for err := range errors { + errs = append(errs, err) + } + assert.Empty(t, errs, "Concurrent operations should not produce errors") +} + +// TestReloadTimeout tests reload timeout handling +func TestReloadTimeout(t *testing.T) { + // This test would require mocking file operations to simulate a slow read + // For now, we'll test that timeout option is respected in configuration + + cfg := New() + cfg.Register("test", "value") + + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "test.toml") + require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644)) + require.NoError(t, cfg.LoadFile(configPath)) + + // Very short timeout + opts := WatchOptions{ + PollInterval: 100 * time.Millisecond, + ReloadTimeout: 1 * time.Nanosecond, // Extremely short + } + cfg.AutoUpdateWithOptions(opts) + defer cfg.StopAutoUpdate() + + waitForWatchingState(t, cfg, true) +} + +// TestStopAutoUpdate tests clean shutdown of watcher +func TestStopAutoUpdate(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "test.toml") + require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644)) + + cfg := New() + cfg.Register("test", "value") + require.NoError(t, cfg.LoadFile(configPath)) + + // Start watching + cfg.AutoUpdate() + waitForWatchingState(t, cfg, true, "Watcher should be active after first start") + + ch := cfg.Watch() + + // Stop watching + cfg.StopAutoUpdate() + + // Verify stopped + waitForWatchingState(t, cfg, false, "Watcher should be inactive after stop") + assert.Equal(t, 0, cfg.WatcherCount()) + + // Channel should eventually close + select { + case _, ok := <-ch: + assert.False(t, ok, "Channel should be closed after stop") + case <-time.After(100 * time.Millisecond): + // OK, channel might not close immediately + } + + // Starting again should work + cfg.AutoUpdate() + waitForWatchingState(t, cfg, true, "Watcher should be active after restart") + cfg.StopAutoUpdate() +} + +// BenchmarkWatchOverhead benchmarks the overhead of file watching func BenchmarkWatchOverhead(b *testing.B) { tmpDir := b.TempDir() configPath := filepath.Join(tmpDir, "bench.toml") @@ -335,19 +474,13 @@ func BenchmarkWatchOverhead(b *testing.B) { for i := 0; i < 100; i++ { configContent += fmt.Sprintf("value%d = %d\n", i, i) } - - if err := os.WriteFile(configPath, []byte(configContent), 0644); err != nil { - b.Fatal("Failed to write config:", err) - } + require.NoError(b, os.WriteFile(configPath, []byte(configContent), 0644)) cfg := New() for i := 0; i < 100; i++ { cfg.Register(fmt.Sprintf("value%d", i), 0) } - - if err := cfg.LoadFile(configPath); err != nil { - b.Fatal("Failed to load config:", err) - } + require.NoError(b, cfg.LoadFile(configPath)) // Enable watching opts := WatchOptions{ @@ -361,4 +494,11 @@ func BenchmarkWatchOverhead(b *testing.B) { for i := 0; i < b.N; i++ { _, _ = cfg.Get(fmt.Sprintf("value%d", i%100)) } +} + +// helper function to wait for watcher state, preventing race conditions of goroutine start and test check +func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) { + require.Eventually(t, func() bool { + return cfg.IsWatching() == expected + }, 200*time.Millisecond, 10*time.Millisecond, msgAndArgs...) } \ No newline at end of file