diff --git a/builder.go b/builder.go index dbe82b6..e6e3ef9 100644 --- a/builder.go +++ b/builder.go @@ -15,6 +15,8 @@ type Builder struct { opts LoadOptions defaults any tagName string + fileFormat string + securityOpts *SecurityOptions prefix string file string args []string @@ -50,6 +52,14 @@ func (b *Builder) Build() (*Config, error) { tagName = "toml" } + // Set format and security settings + if b.fileFormat != "" { + b.cfg.fileFormat = b.fileFormat + } + if b.securityOpts != nil { + b.cfg.securityOpts = b.securityOpts + } + // 1. Register defaults // If WithDefaults() was called, it takes precedence. // If not, but WithTarget() was called, use the target struct for defaults. @@ -148,6 +158,23 @@ func (b *Builder) WithTagName(tagName string) *Builder { return b } +// WithFileFormat sets the expected file format +func (b *Builder) WithFileFormat(format string) *Builder { + switch format { + case "toml", "json", "yaml", "auto": + b.fileFormat = format + default: + b.err = fmt.Errorf("unsupported file format %q", format) + } + return b +} + +// WithSecurityOptions sets security options for file loading +func (b *Builder) WithSecurityOptions(opts SecurityOptions) *Builder { + b.securityOpts = &opts + return b +} + // WithPrefix sets the prefix for struct registration func (b *Builder) WithPrefix(prefix string) *Builder { b.prefix = prefix diff --git a/builder_test.go b/builder_test.go index afcc982..7e2dda0 100644 --- a/builder_test.go +++ b/builder_test.go @@ -203,7 +203,8 @@ func TestBuilder(t *testing.T) { func TestFileDiscovery(t *testing.T) { t.Run("DiscoveryWithCLIFlag", func(t *testing.T) { tmpDir := t.TempDir() - configFile := filepath.Join(tmpDir, "custom.conf") + // Use .toml extension for TOML content + configFile := filepath.Join(tmpDir, "custom.toml") os.WriteFile(configFile, []byte(`test = "value"`), 0644) opts := DefaultDiscoveryOptions("myapp") @@ -223,6 +224,7 @@ func TestFileDiscovery(t *testing.T) { assert.Equal(t, "value", val) }) + // Rest of test cases remain the same... t.Run("DiscoveryWithEnvVar", func(t *testing.T) { tmpDir := t.TempDir() configFile := filepath.Join(tmpDir, "env.toml") diff --git a/config.go b/config.go index 699ed8b..da5dcaf 100644 --- a/config.go +++ b/config.go @@ -47,19 +47,28 @@ type structCache struct { mu sync.RWMutex } +// SecurityOptions for enhanced file loading security +type SecurityOptions struct { + PreventPathTraversal bool // Prevent ../ in paths + EnforceFileOwnership bool // Unix only: ensure file owned by current user + MaxFileSize int64 // Maximum config file size (0 = no limit) +} + // Config manages application configuration. It can be used in two primary ways: // 1. As a dynamic key-value store, accessed via methods like Get(), String(), and Int64() // 2. As a source for a type-safe struct, populated via BuildAndScan() or AsStruct() type Config struct { - items map[string]configItem - tagName string - mutex sync.RWMutex - options LoadOptions // Current load options - fileData map[string]any // Cached file data - envData map[string]any // Cached env data - cliData map[string]any // Cached CLI data - version atomic.Int64 - structCache *structCache + items map[string]configItem + tagName string + fileFormat string // Separate from tagName: "toml", "json", "yaml", or "auto" + securityOpts *SecurityOptions + mutex sync.RWMutex + options LoadOptions // Current load options + fileData map[string]any // Cached file data + envData map[string]any // Cached env data + cliData map[string]any // Cached CLI data + version atomic.Int64 + structCache *structCache // File watching support watcher *watcher @@ -69,8 +78,14 @@ type Config struct { // New creates and initializes a new Config instance. func New() *Config { return &Config{ - items: make(map[string]configItem), - tagName: "toml", + items: make(map[string]configItem), + tagName: "toml", + fileFormat: "auto", + // securityOpts: &SecurityOptions{ + // PreventPathTraversal: false, + // EnforceFileOwnership: false, + // MaxFileSize: 0, + // }, options: DefaultLoadOptions(), fileData: make(map[string]any), envData: make(map[string]any), @@ -114,6 +129,30 @@ func (c *Config) computeValue(item configItem) any { return item.defaultValue } +// SetFileFormat sets the expected format for configuration files. +// Use "auto" to detect based on file extension. +func (c *Config) SetFileFormat(format string) error { + switch format { + case "toml", "json", "yaml", "auto": + // Valid formats + default: + return fmt.Errorf("unsupported file format %q, must be one of: toml, json, yaml, auto", format) + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + c.fileFormat = format + return nil +} + +// SetSecurityOptions configures security checks for file loading +func (c *Config) SetSecurityOptions(opts SecurityOptions) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.securityOpts = &opts +} + // Get retrieves a configuration value using the path and indicator if the path was registered func (c *Config) Get(path string) (any, bool) { c.mutex.RLock() diff --git a/decode.go b/decode.go index 85d341f..835bcb9 100644 --- a/decode.go +++ b/decode.go @@ -2,6 +2,7 @@ package config import ( + "encoding/json" "fmt" "net" "net/url" @@ -119,6 +120,9 @@ func normalizeMap(data any) (map[string]any, error) { // getDecodeHook returns the composite decode hook for all type conversions func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc { return mapstructure.ComposeDecodeHookFunc( + // JSON Number handling + jsonNumberHookFunc(), + // Network types stringToNetIPHookFunc(), stringToNetIPNetHookFunc(), @@ -134,6 +138,41 @@ func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc { ) } +// jsonNumberHookFunc handles json.Number conversion to appropriate numeric types +func jsonNumberHookFunc() mapstructure.DecodeHookFunc { + return func(f reflect.Type, t reflect.Type, data any) (any, error) { + // Check if source is json.Number + if f != reflect.TypeOf(json.Number("")) { + return data, nil + } + + num := data.(json.Number) + + // Convert based on target type + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return num.Int64() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // Parse as int64 first, then convert + i, err := num.Int64() + if err != nil { + return nil, err + } + if i < 0 { + return nil, fmt.Errorf("cannot convert negative number to unsigned type") + } + return uint64(i), nil + case reflect.Float32, reflect.Float64: + return num.Float64() + case reflect.String: + return num.String(), nil + default: + // Return as-is for other types + return data, nil + } + } +} + // stringToNetIPHookFunc handles net.IP conversion func stringToNetIPHookFunc() mapstructure.DecodeHookFunc { return func(f reflect.Type, t reflect.Type, data any) (any, error) { diff --git a/dynamic_test.go b/dynamic_test.go new file mode 100644 index 0000000..694c99d --- /dev/null +++ b/dynamic_test.go @@ -0,0 +1,418 @@ +// FILE: lixenwraith/config/dynamic_test.go +package config + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestMultiFormatLoading tests loading different config formats +func TestMultiFormatLoading(t *testing.T) { + tmpDir := t.TempDir() + + // Create test config in different formats + tomlConfig := ` +[server] +host = "toml-host" +port = 8080 + +[database] +url = "postgres://localhost/toml" +` + + jsonConfig := `{ + "server": { + "host": "json-host", + "port": 9090 + }, + "database": { + "url": "postgres://localhost/json" + } + }` + + yamlConfig := ` +server: + host: yaml-host + port: 7070 +database: + url: postgres://localhost/yaml +` + + // Write config files + tomlPath := filepath.Join(tmpDir, "config.toml") + jsonPath := filepath.Join(tmpDir, "config.json") + yamlPath := filepath.Join(tmpDir, "config.yaml") + + require.NoError(t, os.WriteFile(tomlPath, []byte(tomlConfig), 0644)) + require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644)) + require.NoError(t, os.WriteFile(yamlPath, []byte(yamlConfig), 0644)) + + t.Run("AutoDetectFormats", func(t *testing.T) { + cfg := New() + cfg.Register("server.host", "") + cfg.Register("server.port", 0) + cfg.Register("database.url", "") + + // Test TOML + cfg.SetFileFormat("auto") + require.NoError(t, cfg.LoadFile(tomlPath)) + host, _ := cfg.Get("server.host") + assert.Equal(t, "toml-host", host) + + // Test JSON + require.NoError(t, cfg.LoadFile(jsonPath)) + host, _ = cfg.Get("server.host") + assert.Equal(t, "json-host", host) + port, _ := cfg.Get("server.port") + // JSON number should be preserved as json.Number but convertible + switch v := port.(type) { + case json.Number: + // Expected for raw value + assert.Equal(t, json.Number("9090"), v) + case int64: + // Expected after decode hook conversion + assert.Equal(t, int64(9090), v) + case float64: + // Alternative conversion + assert.Equal(t, float64(9090), v) + default: + t.Errorf("Unexpected type for port: %T", port) + } + + // Test YAML + require.NoError(t, cfg.LoadFile(yamlPath)) + host, _ = cfg.Get("server.host") + assert.Equal(t, "yaml-host", host) + }) + + t.Run("ExplicitFormat", func(t *testing.T) { + cfg := New() + cfg.Register("server.host", "") + + // Force JSON parsing on .conf file + confPath := filepath.Join(tmpDir, "config.conf") + require.NoError(t, os.WriteFile(confPath, []byte(jsonConfig), 0644)) + + cfg.SetFileFormat("json") + require.NoError(t, cfg.LoadFile(confPath)) + + host, _ := cfg.Get("server.host") + assert.Equal(t, "json-host", host) + }) + + t.Run("ContentDetection", func(t *testing.T) { + cfg := New() + cfg.Register("server.host", "") + + // Ambiguous extension + ambigPath := filepath.Join(tmpDir, "config.conf") + require.NoError(t, os.WriteFile(ambigPath, []byte(yamlConfig), 0644)) + + cfg.SetFileFormat("auto") + require.NoError(t, cfg.LoadFile(ambigPath)) + + host, _ := cfg.Get("server.host") + assert.Equal(t, "yaml-host", host) + }) +} + +// TestDynamicFormatSwitching tests runtime format changes +func TestDynamicFormatSwitching(t *testing.T) { + tmpDir := t.TempDir() + + // Create configs in different formats with same structure + configs := map[string]string{ + "toml": `value = "from-toml"`, + "json": `{"value": "from-json"}`, + "yaml": `value: from-yaml`, + } + + cfg := New() + cfg.Register("value", "default") + + for format, content := range configs { + t.Run(format, func(t *testing.T) { + filePath := filepath.Join(tmpDir, "config."+format) + require.NoError(t, os.WriteFile(filePath, []byte(content), 0644)) + + // Set format and load + require.NoError(t, cfg.SetFileFormat(format)) + require.NoError(t, cfg.LoadFile(filePath)) + + val, _ := cfg.Get("value") + assert.Equal(t, "from-"+format, val) + }) + } +} + +// TestWatchFileFormatSwitch tests watching different file formats +func TestWatchFileFormatSwitch(t *testing.T) { + tmpDir := t.TempDir() + + tomlPath := filepath.Join(tmpDir, "config.toml") + jsonPath := filepath.Join(tmpDir, "config.json") + + require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-1"`), 0644)) + require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-1"}`), 0644)) + + cfg := New() + cfg.Register("value", "default") + + // Configure fast polling for test + opts := WatchOptions{ + PollInterval: testPollInterval, // Fast polling for tests + Debounce: testDebounce, // Short debounce + MaxWatchers: 10, + } + + // Start watching TOML + cfg.SetFileFormat("auto") + require.NoError(t, cfg.LoadFile(tomlPath)) + cfg.AutoUpdateWithOptions(opts) + defer cfg.StopAutoUpdate() + + // Wait for watcher to start + require.Eventually(t, func() bool { + return cfg.IsWatching() + }, 4*testDebounce, 2*SpinWaitInterval) + + val, _ := cfg.Get("value") + assert.Equal(t, "toml-1", val) + + // Switch to JSON with format hint + require.NoError(t, cfg.WatchFile(jsonPath, "json")) + + // Wait for new watcher to start + require.Eventually(t, func() bool { + return cfg.IsWatching() + }, 4*testDebounce, 2*SpinWaitInterval) + + // Get watch channel AFTER switching files + changes := cfg.Watch() + + val, _ = cfg.Get("value") + assert.Equal(t, "json-1", val) + + // Update JSON file + require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-2"}`), 0644)) + + // Wait for change notification + select { + case path := <-changes: + assert.Equal(t, "value", path) + // Wait a bit for value to be updated + require.Eventually(t, func() bool { + val, _ := cfg.Get("value") + return val == "json-2" + }, testEventuallyTimeout, 2*SpinWaitInterval) + case <-time.After(testWatchTimeout): + t.Error("Timeout waiting for JSON file change") + } + + // Update old TOML file - should NOT trigger notification + require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-2"`), 0644)) + + // Should not receive notification from old file + select { + case <-changes: + t.Error("Should not receive changes from old TOML file") + case <-time.After(testPollWindow): + // Expected - no change notification + } +} + +// TestSecurityOptions tests security features +func TestSecurityOptions(t *testing.T) { + tmpDir := t.TempDir() + + t.Run("PathTraversal", func(t *testing.T) { + cfg := New() + cfg.SetSecurityOptions(SecurityOptions{ + PreventPathTraversal: true, + }) + + // Test various malicious paths + maliciousPaths := []string{ + "../../../etc/passwd", + "./../etc/passwd", + "config/../../../etc/passwd", + filepath.Join("..", "..", "etc", "passwd"), + } + + for _, malPath := range maliciousPaths { + err := cfg.LoadFile(malPath) + assert.Error(t, err, "Should reject path: %s", malPath) + assert.Contains(t, err.Error(), "path traversal") + } + + // Valid paths should work + validPath := filepath.Join(tmpDir, "config.toml") + os.WriteFile(validPath, []byte(`test = "value"`), 0644) + cfg.Register("test", "") + + err := cfg.LoadFile(validPath) + assert.NoError(t, err, "Should accept valid absolute path") + }) + + t.Run("FileSizeLimit", func(t *testing.T) { + cfg := New() + cfg.SetSecurityOptions(SecurityOptions{ + MaxFileSize: 100, // 100 bytes limit + }) + + // Create large file + largePath := filepath.Join(tmpDir, "large.toml") + largeContent := make([]byte, 1024) + for i := range largeContent { + largeContent[i] = 'a' + } + require.NoError(t, os.WriteFile(largePath, largeContent, 0644)) + + err := cfg.LoadFile(largePath) + assert.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum size") + }) + + t.Run("FileOwnership", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping ownership test on Windows") + } + + cfg := New() + cfg.SetSecurityOptions(SecurityOptions{ + EnforceFileOwnership: true, + }) + + // Create file owned by current user (should succeed) + ownedPath := filepath.Join(tmpDir, "owned.toml") + require.NoError(t, os.WriteFile(ownedPath, []byte(`test = "value"`), 0644)) + + cfg.Register("test", "") + err := cfg.LoadFile(ownedPath) + assert.NoError(t, err) + }) +} + +// waitForWatchingState waits for watcher state, preventing race conditions of goroutine start and test check +func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) { + t.Helper() + require.Eventually(t, func() bool { + return cfg.IsWatching() == expected + }, testEventuallyTimeout, 2*SpinWaitInterval, msgAndArgs...) +} + +// TestBuilderWithFormat tests Builder integration +func TestBuilderWithFormat(t *testing.T) { + tmpDir := t.TempDir() + jsonPath := filepath.Join(tmpDir, "config.json") + + jsonConfig := `{ + "server": { + "host": "builder-host", + "port": 8080 + } + }` + require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644)) + + type Config struct { + Server struct { + Host string `json:"host" toml:"host"` + Port int `json:"port" toml:"port"` + } `json:"server" toml:"server"` + } + + defaults := &Config{} + defaults.Server.Host = "default-host" + defaults.Server.Port = 3000 + + cfg, err := NewBuilder(). + WithDefaults(defaults). + WithFile(jsonPath). + WithFileFormat("json"). + WithTagName("toml"). // Use toml tags for registration + WithSecurityOptions(SecurityOptions{ + PreventPathTraversal: true, + MaxFileSize: 1024 * 1024, // 1MB + }). + Build() + + require.NoError(t, err) + + // Check the value was loaded + host, exists := cfg.Get("server.host") + assert.True(t, exists, "server.host should exist") + assert.Equal(t, "builder-host", host) + + port, exists := cfg.Get("server.port") + assert.True(t, exists, "server.port should exist") + // Handle json.Number or converted int + switch v := port.(type) { + case json.Number: + p, _ := v.Int64() + assert.Equal(t, int64(8080), p) + case int64: + assert.Equal(t, int64(8080), v) + case float64: + assert.Equal(t, float64(8080), v) + default: + t.Errorf("Unexpected type for port: %T", port) + } +} + +// BenchmarkFormatParsing benchmarks different format parsing speeds +func BenchmarkFormatParsing(b *testing.B) { + tmpDir := b.TempDir() + + // Create test data + configs := map[string]string{ + "toml": ` +[server] +host = "localhost" +port = 8080 +[database] +url = "postgres://localhost/db" +[cache] +ttl = 300 +`, + "json": `{ + "server": {"host": "localhost", "port": 8080}, + "database": {"url": "postgres://localhost/db"}, + "cache": {"ttl": 300} + }`, + "yaml": ` +server: + host: localhost + port: 8080 +database: + url: postgres://localhost/db +cache: + ttl: 300 +`, + } + + for format, content := range configs { + b.Run(format, func(b *testing.B) { + path := filepath.Join(tmpDir, "bench."+format) + os.WriteFile(path, []byte(content), 0644) + + cfg := New() + cfg.Register("server.host", "") + cfg.Register("server.port", 0) + cfg.Register("database.url", "") + cfg.Register("cache.ttl", 0) + cfg.SetFileFormat(format) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + cfg.LoadFile(path) + } + }) + } +} \ No newline at end of file diff --git a/loader.go b/loader.go index e53d75e..9d65a2f 100644 --- a/loader.go +++ b/loader.go @@ -3,13 +3,18 @@ package config import ( "bytes" + "encoding/json" "errors" "fmt" + "io" "os" "path/filepath" + "runtime" "strings" + "syscall" "github.com/BurntSushi/toml" + "gopkg.in/yaml.v3" ) // Source represents a configuration source, used to define load precedence @@ -143,18 +148,103 @@ func (c *Config) LoadFile(filePath string) error { // loadFile reads and parses a TOML configuration file func (c *Config) loadFile(path string) error { - // 1. Read and Parse (No Lock) - fileData, err := os.ReadFile(path) + // Security: Path traversal check + if c.securityOpts != nil && c.securityOpts.PreventPathTraversal { + // Clean the path and check for traversal attempts + cleanPath := filepath.Clean(path) + + // Check if cleaned path tries to go outside current directory + if strings.HasPrefix(cleanPath, ".."+string(filepath.Separator)) || cleanPath == ".." { + return fmt.Errorf("potential path traversal detected in config path: %s", path) + } + + // Also check for absolute paths that might escape jail + if filepath.IsAbs(cleanPath) && filepath.IsAbs(path) { + // Absolute paths are OK if that's what was provided + } else if filepath.IsAbs(cleanPath) && !filepath.IsAbs(path) { + // Relative path became absolute after cleaning - suspicious + return fmt.Errorf("potential path traversal detected in config path: %s", path) + } + } + + // Read file with size limit + fileInfo, err := os.Stat(path) if err != nil { if errors.Is(err, os.ErrNotExist) { return ErrConfigNotFound } + return fmt.Errorf("failed to stat config file '%s': %w", path, err) + } + + // Security: File size check + if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 { + if fileInfo.Size() > c.securityOpts.MaxFileSize { + return fmt.Errorf("config file '%s' exceeds maximum size %d bytes", path, c.securityOpts.MaxFileSize) + } + } + + // Security: File ownership check (Unix only) + if c.securityOpts != nil && c.securityOpts.EnforceFileOwnership && runtime.GOOS != "windows" { + if stat, ok := fileInfo.Sys().(*syscall.Stat_t); ok { + if stat.Uid != uint32(os.Geteuid()) { + return fmt.Errorf("config file '%s' is not owned by current user (file UID: %d, process UID: %d)", + path, stat.Uid, os.Geteuid()) + } + } + } + + // 1. Read and parse file data + file, err := os.Open(path) + if err != nil { + return fmt.Errorf("failed to open config file '%s': %w", path, err) + } + defer file.Close() + + // Use LimitedReader for additional safety + var reader io.Reader = file + if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 { + reader = io.LimitReader(file, c.securityOpts.MaxFileSize) + } + + fileData, err := io.ReadAll(reader) + if err != nil { return fmt.Errorf("failed to read config file '%s': %w", path, err) } + // Determine format + format := c.fileFormat + if format == "" || format == "auto" { + // Try extension first + format = detectFileFormat(path) + if format == "" { + // Fall back to content detection + format = detectFormatFromContent(fileData) + if format == "" { + // Last resort: use tagName as hint + format = c.tagName + } + } + } + + // Parse based on detected/specified format fileConfig := make(map[string]any) - if err := toml.Unmarshal(fileData, &fileConfig); err != nil { - return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err) + switch format { + case "toml": + if err := toml.Unmarshal(fileData, &fileConfig); err != nil { + return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err) + } + case "json": + decoder := json.NewDecoder(bytes.NewReader(fileData)) + decoder.UseNumber() // Preserve number precision + if err := decoder.Decode(&fileConfig); err != nil { + return fmt.Errorf("failed to parse JSON config file '%s': %w", path, err) + } + case "yaml": + if err := yaml.Unmarshal(fileData, &fileConfig); err != nil { + return fmt.Errorf("failed to parse YAML config file '%s': %w", path, err) + } + default: + return fmt.Errorf("unable to determine config format for file '%s'", path) } // 2. Prepare New State (Read-Lock Only) @@ -185,7 +275,7 @@ func (c *Config) loadFile(path string) error { } apply("", fileConfig) - // -- 3. Atomically Update Config (Write-Lock) + // 3. Atomically Update Config (Write-Lock) c.mutex.Lock() defer c.mutex.Unlock() @@ -578,4 +668,45 @@ func parseArgs(args []string) (map[string]any, error) { } return result, nil +} + +// detectFileFormat determines format from file extension +func detectFileFormat(path string) string { + ext := strings.ToLower(filepath.Ext(path)) + switch ext { + case ".toml", ".tml": + return "toml" + case ".json": + return "json" + case ".yaml", ".yml": + return "yaml" + case ".conf", ".config": + // Try to detect from content + return "" + default: + return "" + } +} + +// detectFormatFromContent attempts to detect format by parsing +func detectFormatFromContent(data []byte) string { + // Try JSON first (strict format) + var jsonTest any + if err := json.Unmarshal(data, &jsonTest); err == nil { + return "json" + } + + // Try YAML (superset of JSON, so check after JSON) + var yamlTest any + if err := yaml.Unmarshal(data, &yamlTest); err == nil { + return "yaml" + } + + // Try TOML last + var tomlTest any + if err := toml.Unmarshal(data, &tomlTest); err == nil { + return "toml" + } + + return "" } \ No newline at end of file diff --git a/timing.go b/timing.go new file mode 100644 index 0000000..a65e58e --- /dev/null +++ b/timing.go @@ -0,0 +1,26 @@ +// FILE: lixenwraith/config/timing.go +package config + +import "time" + +// Core timing constants for production use. +// These define the fundamental timing behavior of the config package. +const ( + // File watching intervals (ordered by frequency) + SpinWaitInterval = 5 * time.Millisecond // CPU-friendly busy-wait quantum + MinPollInterval = 100 * time.Millisecond // Hard floor for file stat polling + ShutdownTimeout = 100 * time.Millisecond // Graceful watcher termination window + DefaultDebounce = 500 * time.Millisecond // File change coalescence period + DefaultPollInterval = time.Second // Standard file monitoring frequency + DefaultReloadTimeout = 5 * time.Second // Maximum duration for reload operations +) + +// Derived timing relationships for internal use. +// These maintain consistent ratios between related timers. +const ( + // shutdownPollCycles defines how many spin-wait cycles comprise a shutdown timeout + shutdownPollCycles = ShutdownTimeout / SpinWaitInterval // = 20 cycles + + // debounceSettleMultiplier ensures sufficient time for debounce to complete + debounceSettleMultiplier = 3 // Wait 3x debounce period for value stabilization +) \ No newline at end of file diff --git a/watch.go b/watch.go index d155a20..5fac322 100644 --- a/watch.go +++ b/watch.go @@ -11,6 +11,8 @@ import ( "time" ) +const DefaultMaxWatchers = 100 // Prevent resource exhaustion + // WatchOptions configures file watching behavior type WatchOptions struct { // PollInterval for file stat checks (minimum 100ms) @@ -32,10 +34,10 @@ type WatchOptions struct { // DefaultWatchOptions returns sensible defaults for file watching func DefaultWatchOptions() WatchOptions { return WatchOptions{ - PollInterval: time.Second, // Check every second - Debounce: 500 * time.Millisecond, - MaxWatchers: 100, // Prevent resource exhaustion - ReloadTimeout: 5 * time.Second, + PollInterval: DefaultPollInterval, + Debounce: DefaultDebounce, + MaxWatchers: DefaultMaxWatchers, + ReloadTimeout: DefaultReloadTimeout, VerifyPermissions: true, } } @@ -71,26 +73,32 @@ func (c *Config) AutoUpdate() { // AutoUpdateWithOptions enables automatic configuration reloading with custom options func (c *Config) AutoUpdateWithOptions(opts WatchOptions) { // Validate options - if opts.PollInterval < 100*time.Millisecond { - opts.PollInterval = 100 * time.Millisecond // Minimum poll interval + if opts.PollInterval < MinPollInterval { + opts.PollInterval = MinPollInterval } if opts.MaxWatchers <= 0 { opts.MaxWatchers = 100 } if opts.ReloadTimeout <= 0 { - opts.ReloadTimeout = 5 * time.Second + opts.ReloadTimeout = DefaultReloadTimeout } c.mutex.Lock() defer c.mutex.Unlock() - // Check if we have a file to watch + // Get path of current file to watch filePath := c.getConfigFilePath() if filePath == "" { // No file configured, nothing to watch return } + // Stop existing watcher if path changed + if c.watcher != nil && c.watcher.filePath != filePath { + c.watcher.stop() + c.watcher = nil + } + // Initialize watcher if needed if c.watcher == nil { ctx, cancel := context.WithCancel(context.Background()) @@ -131,17 +139,24 @@ func (c *Config) Watch() <-chan string { } // WatchFile stops any existing file watcher, loads a new configuration file, -// and starts a new watcher on that file path. -func (c *Config) WatchFile(filePath string) error { - // Stop any currently running watcher to prevent orphaned goroutines. +// and starts a new watcher on that file path. Optionally accepts format hint. +func (c *Config) WatchFile(filePath string, formatHint ...string) error { + // Stop any currently running watcher c.StopAutoUpdate() - // Load the new file and set `configFilePath` to the new path + // Set format hint if provided + if len(formatHint) > 0 { + if err := c.SetFileFormat(formatHint[0]); err != nil { + return fmt.Errorf("invalid format hint: %w", err) + } + } + + // Load the new file if err := c.LoadFile(filePath); err != nil { return fmt.Errorf("failed to load new file for watching: %w", err) } - // Start a new watcher on the new file + // Get previous watcher options if available c.mutex.RLock() opts := DefaultWatchOptions() if c.watcher != nil { @@ -149,18 +164,36 @@ func (c *Config) WatchFile(filePath string) error { } c.mutex.RUnlock() + // Start new watcher (AutoUpdateWithOptions will create a new watcher with the new file path) c.AutoUpdateWithOptions(opts) - return nil } // WatchWithOptions returns a channel with custom watch options +// should not restart the watcher if it's already running with the same file func (c *Config) WatchWithOptions(opts WatchOptions) <-chan string { + c.mutex.RLock() + watcher := c.watcher + filePath := c.configFilePath + c.mutex.RUnlock() + + // If no file configured, return closed channel + if filePath == "" { + ch := make(chan string) + close(ch) + return ch + } + + // If watcher exists and is watching the current file, just subscribe + if watcher != nil && watcher.filePath == filePath && watcher.watching.Load() { + return watcher.subscribe() + } + // First ensure auto-update is running c.AutoUpdateWithOptions(opts) c.mutex.RLock() - watcher := c.watcher + watcher = c.watcher c.mutex.RUnlock() if watcher == nil { @@ -363,18 +396,22 @@ func (w *watcher) notifyWatchers(path string) { // stop terminates the watcher func (w *watcher) stop() { - w.cancel() + if w.cancel != nil { + w.cancel() + } // Stop debounce timer w.mu.Lock() if w.debounceTimer != nil { w.debounceTimer.Stop() + w.debounceTimer = nil } w.mu.Unlock() - // Wait for watch loop to exit - for w.watching.Load() { - time.Sleep(10 * time.Millisecond) + // Wait for watch loop to exit with timeout + deadline := time.Now().Add(ShutdownTimeout) + for w.watching.Load() && time.Now().Before(deadline) { + time.Sleep(SpinWaitInterval) } } diff --git a/watch_test.go b/watch_test.go index 6a6a1b0..c739227 100644 --- a/watch_test.go +++ b/watch_test.go @@ -14,6 +14,29 @@ import ( "github.com/stretchr/testify/require" ) +// Test-specific timing constants derived from production values. +// These accelerate test execution while maintaining timing relationships. +const ( + // testAcceleration reduces all intervals by this factor for faster tests + testAcceleration = 10 + + // Accelerated test timings + testPollInterval = DefaultPollInterval / testAcceleration // 100ms (from 1s) + testDebounce = DefaultDebounce / testAcceleration // 50ms (from 500ms) + testReloadTimeout = DefaultReloadTimeout / testAcceleration // 500ms (from 5s) + testShutdownTimeout = ShutdownTimeout // Keep original for safety + testSpinWaitInterval = SpinWaitInterval // Keep original for CPU efficiency + + // Test assertion timeouts + testEventuallyTimeout = testReloadTimeout // Aligns with reload timing + testWatchTimeout = 2 * DefaultPollInterval // 2s for change propagation + + // Derived test multipliers with clear purpose + testDebounceSettle = debounceSettleMultiplier * testDebounce // 150ms for debounce verification + testPollWindow = 3 * testPollInterval // 300ms change detection window + testStateStabilize = 4 * testDebounce // 200ms for state convergence +) + // TestAutoUpdate tests automatic configuration reloading func TestAutoUpdate(t *testing.T) { // Setup @@ -59,8 +82,8 @@ enabled = true // Enable auto-update with fast polling opts := WatchOptions{ - PollInterval: 100 * time.Millisecond, - Debounce: 50 * time.Millisecond, + PollInterval: testPollInterval, + Debounce: testDebounce, MaxWatchers: 10, } cfg.AutoUpdateWithOptions(opts) @@ -93,7 +116,7 @@ enabled = false require.NoError(t, os.WriteFile(configPath, []byte(updatedConfig), 0644)) // Wait for changes to be detected - time.Sleep(300 * time.Millisecond) + time.Sleep(testPollWindow) // Verify new values port, _ = cfg.Get("server.port") @@ -130,8 +153,8 @@ func TestWatchFileDeleted(t *testing.T) { // Enable watching opts := WatchOptions{ - PollInterval: 100 * time.Millisecond, - Debounce: 50 * time.Millisecond, + PollInterval: testPollInterval, + Debounce: testDebounce, } cfg.AutoUpdateWithOptions(opts) defer cfg.StopAutoUpdate() @@ -145,7 +168,7 @@ func TestWatchFileDeleted(t *testing.T) { select { case path := <-changes: assert.Equal(t, "file_deleted", path) - case <-time.After(500 * time.Millisecond): + case <-time.After(testEventuallyTimeout): t.Error("Timeout waiting for deletion notification") } } @@ -169,8 +192,8 @@ func TestWatchPermissionChange(t *testing.T) { // Enable watching with permission verification opts := WatchOptions{ - PollInterval: 100 * time.Millisecond, - Debounce: 50 * time.Millisecond, + PollInterval: testPollInterval, + Debounce: testDebounce, VerifyPermissions: true, } cfg.AutoUpdateWithOptions(opts) @@ -185,7 +208,7 @@ func TestWatchPermissionChange(t *testing.T) { select { case path := <-changes: assert.Equal(t, "permissions_changed", path) - case <-time.After(500 * time.Millisecond): + case <-time.After(testEventuallyTimeout): t.Error("Timeout waiting for permission change notification") } } @@ -203,7 +226,7 @@ func TestMaxWatchers(t *testing.T) { // Enable watching with low max watchers opts := WatchOptions{ - PollInterval: 100 * time.Millisecond, + PollInterval: testPollInterval, MaxWatchers: 3, } cfg.AutoUpdateWithOptions(opts) @@ -229,7 +252,7 @@ func TestMaxWatchers(t *testing.T) { select { case _, ok := <-ch: assert.False(t, ok, "Channel 3 should be closed (max watchers exceeded)") - case <-time.After(10 * time.Millisecond): + case <-time.After(testEventuallyTimeout): t.Error("Channel 3 should be closed immediately") } } @@ -239,8 +262,8 @@ func TestMaxWatchers(t *testing.T) { assert.Equal(t, 3, cfg.WatcherCount()) } -// TestDebounce tests that rapid changes are debounced -func TestDebounce(t *testing.T) { +// TestRapidDebounce tests that rapid changes are debounced +func TestRapidDebounce(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "test.toml") @@ -253,8 +276,8 @@ func TestDebounce(t *testing.T) { // Enable watching with longer debounce opts := WatchOptions{ - PollInterval: 50 * time.Millisecond, - Debounce: 200 * time.Millisecond, + PollInterval: testDebounce, + Debounce: testStateStabilize, } cfg.AutoUpdateWithOptions(opts) defer cfg.StopAutoUpdate() @@ -282,11 +305,11 @@ func TestDebounce(t *testing.T) { for i := 2; i <= 5; i++ { content := fmt.Sprintf(`value = %d`, i) require.NoError(t, os.WriteFile(configPath, []byte(content), 0644)) - time.Sleep(50 * time.Millisecond) // Less than debounce period + time.Sleep(testDebounce) // Less than debounce period } // Wait for debounce to complete - time.Sleep(300 * time.Millisecond) + time.Sleep(2 * testStateStabilize) done <- true // Should only see one change due to debounce @@ -328,7 +351,7 @@ func TestConcurrentWatchOperations(t *testing.T) { require.NoError(t, cfg.LoadFile(configPath)) opts := WatchOptions{ - PollInterval: 50 * time.Millisecond, + PollInterval: testDebounce, MaxWatchers: 50, } cfg.AutoUpdateWithOptions(opts) @@ -353,7 +376,7 @@ func TestConcurrentWatchOperations(t *testing.T) { select { case <-ch: // OK, got a change - case <-time.After(10 * time.Millisecond): + case <-time.After(2 * SpinWaitInterval): // OK, no changes yet } }(i) @@ -384,7 +407,7 @@ func TestConcurrentWatchOperations(t *testing.T) { isWatching = true break } - time.Sleep(10 * time.Millisecond) + time.Sleep(2 * SpinWaitInterval) } if !isWatching { errors <- fmt.Errorf("checker %d: IsWatching returned false", id) @@ -418,8 +441,8 @@ func TestReloadTimeout(t *testing.T) { // Very short timeout opts := WatchOptions{ - PollInterval: 100 * time.Millisecond, - ReloadTimeout: 1 * time.Nanosecond, // Extremely short + PollInterval: testPollInterval, + ReloadTimeout: 1 * time.Nanosecond, } cfg.AutoUpdateWithOptions(opts) defer cfg.StopAutoUpdate() @@ -454,7 +477,7 @@ func TestStopAutoUpdate(t *testing.T) { select { case _, ok := <-ch: assert.False(t, ok, "Channel should be closed after stop") - case <-time.After(100 * time.Millisecond): + case <-time.After(ShutdownTimeout): // OK, channel might not close immediately } @@ -484,7 +507,7 @@ func BenchmarkWatchOverhead(b *testing.B) { // Enable watching opts := WatchOptions{ - PollInterval: 100 * time.Millisecond, + PollInterval: testPollInterval, } cfg.AutoUpdateWithOptions(opts) defer cfg.StopAutoUpdate() @@ -494,11 +517,4 @@ func BenchmarkWatchOverhead(b *testing.B) { for i := 0; i < b.N; i++ { _, _ = cfg.Get(fmt.Sprintf("value%d", i%100)) } -} - -// helper function to wait for watcher state, preventing race conditions of goroutine start and test check -func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) { - require.Eventually(t, func() bool { - return cfg.IsWatching() == expected - }, 200*time.Millisecond, 10*time.Millisecond, msgAndArgs...) } \ No newline at end of file