From 06cddbe00e3b9e91b1b61ee611ea36f328fdb59603a6b479ca2cb212281f5240 Mon Sep 17 00:00:00 2001 From: Lixen Wraith Date: Sun, 20 Jul 2025 02:09:32 -0400 Subject: [PATCH] e5.2.0 Decoder and loader refactored, bug fixes. --- .gitignore | 3 +- builder.go | 93 +++++++++++++++++++++++++++----------- builder_test.go | 59 ++++++++++++++++++++++++ convenience.go | 48 ++++++++++++++++++++ convenience_test.go | 40 +++++++++++++++++ decode.go | 71 ++++++++++++++++++++++++++--- doc/access.md | 31 +++++++++++++ doc/builder.md | 31 ++++++++++++- example/main.go | 84 +++++++++++----------------------- go.mod | 2 +- loader.go | 107 ++++++++++++++++++++++++++++++++------------ watch.go | 83 +++++++++++----------------------- 12 files changed, 474 insertions(+), 178 deletions(-) diff --git a/.gitignore b/.gitignore index 3b3200b..3de1e60 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ log logs script *.log -bin \ No newline at end of file +*.toml +bin diff --git a/builder.go b/builder.go index 6f0a9b5..dbe82b6 100644 --- a/builder.go +++ b/builder.go @@ -11,15 +11,16 @@ import ( // Builder provides a fluent API for constructing a Config instance. It allows for // chaining configuration options before final build of the config object. type Builder struct { - cfg *Config - opts LoadOptions - defaults any - tagName string - prefix string - file string - args []string - err error - validators []ValidatorFunc + cfg *Config + opts LoadOptions + defaults any + tagName string + prefix string + file string + args []string + err error + validators []ValidatorFunc + typedValidators []any } // ValidatorFunc defines the signature for a function that can validate a Config instance. @@ -29,10 +30,11 @@ type ValidatorFunc func(c *Config) error // NewBuilder creates a new configuration builder func NewBuilder() *Builder { return &Builder{ - cfg: New(), - opts: DefaultLoadOptions(), - args: os.Args[1:], - validators: make([]ValidatorFunc, 0), + cfg: New(), + opts: DefaultLoadOptions(), + args: os.Args[1:], + validators: make([]ValidatorFunc, 0), + typedValidators: make([]any, 0), } } @@ -48,9 +50,9 @@ func (b *Builder) Build() (*Config, error) { tagName = "toml" } - // The logic for registering defaults must be prioritized: - // 1. If WithDefaults() was called, it takes precedence. - // 2. If not, but WithTarget() was called, use the target struct for defaults. + // 1. Register defaults + // If WithDefaults() was called, it takes precedence. + // If not, but WithTarget() was called, use the target struct for defaults. if b.defaults != nil { // WithDefaults() was called explicitly. if err := b.cfg.RegisterStructWithTags(b.prefix, b.defaults, tagName); err != nil { @@ -65,23 +67,50 @@ func (b *Builder) Build() (*Config, error) { } // Explicitly set the file path on the config object so the watcher can find it, - // even if the initial load fails with a non-fatal error (e.g., file not found). + // even if the initial load fails with a non-fatal error (file not found). b.cfg.configFilePath = b.file - // Load configuration + // 2. Load configuration loadErr := b.cfg.LoadWithOptions(b.file, b.args, b.opts) if loadErr != nil && !errors.Is(loadErr, ErrConfigNotFound) { // Return on fatal load errors. ErrConfigNotFound is not fatal. return nil, loadErr } - // Run validators + // 3. Run non-typed validators for _, validator := range b.validators { if err := validator(b.cfg); err != nil { return nil, fmt.Errorf("configuration validation failed: %w", err) } } + // 4. Populate target and run typed validators + if b.cfg.structCache != nil && b.cfg.structCache.target != nil && len(b.typedValidators) > 0 { + // Populate the target struct first. This unifies all types (e.g., string "8888" -> int64 8888). + populatedTarget, err := b.cfg.AsStruct() + if err != nil { + return nil, fmt.Errorf("failed to populate target struct for validation: %w", err) + } + + // Run the typed validators against the populated, type-safe struct. + for _, validator := range b.typedValidators { + validatorFunc := reflect.ValueOf(validator) + validatorType := validatorFunc.Type() + + // Check if the validator's input type matches the target's type. + if validatorType.In(0) != reflect.TypeOf(populatedTarget) { + return nil, fmt.Errorf("typed validator signature %v does not match target type %T", validatorType, populatedTarget) + } + + // Call the validator. + results := validatorFunc.Call([]reflect.Value{reflect.ValueOf(populatedTarget)}) + if !results[0].IsNil() { + err := results[0].Interface().(error) + return nil, fmt.Errorf("typed configuration validation failed: %w", err) + } + } + } + // ErrConfigNotFound or nil return b.cfg, loadErr } @@ -188,13 +217,6 @@ func (b *Builder) WithTarget(target any) *Builder { } } - // NOTE: removed since it would cause issues when an empty struct is passed - // TODO: may cause issue in other scenarios, test extensively - // // Register struct fields automatically - // if b.defaults == nil { - // b.defaults = target - // } - return b } @@ -207,4 +229,23 @@ func (b *Builder) WithValidator(fn ValidatorFunc) *Builder { b.validators = append(b.validators, fn) } return b +} + +// WithTypedValidator adds a type-safe validation function that runs at the end of the build process, +// after the target struct has been populated. The provided function must accept a single argument +// that is a pointer to the same type as the one provided to WithTarget, and must return an error. +func (b *Builder) WithTypedValidator(fn any) *Builder { + if fn == nil { + return b + } + + // Basic reflection check to ensure it's a function that takes one argument and returns an error. + t := reflect.TypeOf(fn) + if t.Kind() != reflect.Func || t.NumIn() != 1 || t.NumOut() != 1 || t.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + b.err = fmt.Errorf("WithTypedValidator requires a function with signature func(*T) error") + return b + } + + b.typedValidators = append(b.typedValidators, fn) + return b } \ No newline at end of file diff --git a/builder_test.go b/builder_test.go index c78c9aa..d9564ea 100644 --- a/builder_test.go +++ b/builder_test.go @@ -300,4 +300,63 @@ func TestFileDiscovery(t *testing.T) { val, _ := cfg.Get("test") assert.Equal(t, "clifile", val) }) +} + +func TestBuilderWithTypedValidator(t *testing.T) { + type Cfg struct { + Port int `toml:"port"` + } + + // Case 1: Valid configuration + t.Run("ValidTyped", func(t *testing.T) { + target := &Cfg{Port: 8080} + validator := func(c *Cfg) error { + if c.Port < 1024 { + return fmt.Errorf("port too low") + } + return nil + } + + _, err := NewBuilder(). + WithTarget(target). + WithTypedValidator(validator). + Build() + + require.NoError(t, err) + }) + + // Case 2: Invalid configuration + t.Run("InvalidTyped", func(t *testing.T) { + target := &Cfg{Port: 80} + validator := func(c *Cfg) error { + if c.Port < 1024 { + return fmt.Errorf("port too low") + } + return nil + } + + _, err := NewBuilder(). + WithTarget(target). + WithTypedValidator(validator). + Build() + + require.Error(t, err) + assert.Contains(t, err.Error(), "typed configuration validation failed: port too low") + }) + + // Case 3: Mismatched validator signature + t.Run("MismatchedSignature", func(t *testing.T) { + target := &Cfg{} + validator := func(c *struct{ Name string }) error { // Different type + return nil + } + + _, err := NewBuilder(). + WithTarget(target). + WithTypedValidator(validator). + Build() + + require.Error(t, err) + assert.Contains(t, err.Error(), "typed validator signature") + }) } \ No newline at end of file diff --git a/convenience.go b/convenience.go index 1572809..1e6e488 100644 --- a/convenience.go +++ b/convenience.go @@ -4,6 +4,7 @@ package config import ( "flag" "fmt" + "github.com/mitchellh/mapstructure" "os" "reflect" "strings" @@ -234,4 +235,51 @@ func QuickTyped[T any](target *T, envPrefix, configFile string) (*Config, error) WithEnvPrefix(envPrefix). WithFile(configFile). Build() +} + +// GetTyped retrieves a configuration value and decodes it into the specified type T. +// It leverages the same decoding hooks as the Scan and AsStruct methods, +// providing type conversion from strings, numbers, etc. +func GetTyped[T any](c *Config, path string) (T, error) { + var zero T + + rawValue, exists := c.Get(path) + if !exists { + return zero, fmt.Errorf("path %q not found", path) + } + + // Prepare the input map and target struct for the decoder. + inputMap := map[string]any{"value": rawValue} + var target struct { + Value T `mapstructure:"value"` + } + + // Create a new decoder configured with the same hooks as the main config. + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + Result: &target, + TagName: c.tagName, + WeaklyTypedInput: true, + DecodeHook: c.getDecodeHook(), + Metadata: nil, + }) + if err != nil { + return zero, fmt.Errorf("failed to create decoder for path %q: %w", path, err) + } + + // Decode the single value. + if err := decoder.Decode(inputMap); err != nil { + return zero, fmt.Errorf("failed to decode value for path %q into type %T: %w", path, zero, err) + } + + return target.Value, nil +} + +// ScanTyped is a generic wrapper around Scan. It allocates a new instance of type T, +// populates it with configuration data from the given base path, and returns a pointer to it. +func ScanTyped[T any](c *Config, basePath string) (*T, error) { + var target T + if err := c.Scan(basePath, &target); err != nil { + return nil, err + } + return &target, nil } \ No newline at end of file diff --git a/convenience_test.go b/convenience_test.go index 803331b..3370cb6 100644 --- a/convenience_test.go +++ b/convenience_test.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -284,4 +285,43 @@ func TestClone(t *testing.T) { // Verify source data is copied sources := clone.GetSources("shared.value") assert.Equal(t, "envvalue", sources[SourceEnv]) +} + +func TestGenericHelpers(t *testing.T) { + cfg := New() + cfg.Register("server.host", "localhost") + cfg.Register("server.port", "8080") // Note: string value + cfg.Register("features.dark_mode", true) + cfg.Register("timeouts.read", "5s") + + t.Run("GetTyped", func(t *testing.T) { + port, err := GetTyped[int](cfg, "server.port") + require.NoError(t, err) + assert.Equal(t, 8080, port) + + host, err := GetTyped[string](cfg, "server.host") + require.NoError(t, err) + assert.Equal(t, "localhost", host) + + // Test with custom decode hook type + readTimeout, err := GetTyped[time.Duration](cfg, "timeouts.read") + require.NoError(t, err) + assert.Equal(t, 5*time.Second, readTimeout) + + _, err = GetTyped[int](cfg, "nonexistent.path") + assert.Error(t, err) + }) + + t.Run("ScanTyped", func(t *testing.T) { + type ServerConfig struct { + Host string `toml:"host"` + Port int `toml:"port"` + } + + serverConf, err := ScanTyped[ServerConfig](cfg, "server") + require.NoError(t, err) + require.NotNil(t, serverConf) + assert.Equal(t, "localhost", serverConf.Host) + assert.Equal(t, 8080, serverConf.Port) + }) } \ No newline at end of file diff --git a/decode.go b/decode.go index af4d6f6..5e16cf0 100644 --- a/decode.go +++ b/decode.go @@ -44,12 +44,13 @@ func (c *Config) unmarshal(basePath string, source Source, target any) error { // Navigate to basePath section sectionData := navigateToPath(nestedMap, basePath) - // Ensure we have a map to decode - sectionMap, ok := sectionData.(map[string]any) - if !ok { + // Ensure we have a map to decode, normalizing if necessary. + sectionMap, err := normalizeMap(sectionData) + if err != nil { if sectionData == nil { - sectionMap = make(map[string]any) // Empty section + sectionMap = make(map[string]any) // Empty section is valid. } else { + // Path points to a non-map value, which is an error for Scan. return fmt.Errorf("path %q refers to non-map value (type %T)", basePath, sectionData) } } @@ -74,6 +75,66 @@ func (c *Config) unmarshal(basePath string, source Source, target any) error { return nil } +// // Ensure we have a map to decode +// sectionMap, ok := sectionData.(map[string]any) +// if !ok { +// if sectionData == nil { +// sectionMap = make(map[string]any) // Empty section +// } else { +// return fmt.Errorf("path %q refers to non-map value (type %T)", basePath, sectionData) +// } +// } +// +// // Create decoder with comprehensive hooks +// decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ +// Result: target, +// TagName: c.tagName, +// WeaklyTypedInput: true, +// DecodeHook: c.getDecodeHook(), +// ZeroFields: true, +// Metadata: nil, +// }) +// if err != nil { +// return fmt.Errorf("decoder creation failed: %w", err) +// } +// +// if err := decoder.Decode(sectionMap); err != nil { +// return fmt.Errorf("decode failed for path %q: %w", basePath, err) +// } +// +// return nil +// } + +// normalizeMap ensures that the input data is a map[string]any for the decoder. +func normalizeMap(data any) (map[string]any, error) { + if data == nil { + return make(map[string]any), nil + } + + // If it's already the correct type, return it. + if m, ok := data.(map[string]any); ok { + return m, nil + } + + // Use reflection to handle other map types (e.g., map[string]bool) + v := reflect.ValueOf(data) + if v.Kind() == reflect.Map { + if v.Type().Key().Kind() != reflect.String { + return nil, fmt.Errorf("map keys must be strings, but got %v", v.Type().Key()) + } + + // Create a new map[string]any and copy the values. + normalized := make(map[string]any, v.Len()) + iter := v.MapRange() + for iter.Next() { + normalized[iter.Key().String()] = iter.Value().Interface() + } + return normalized, nil + } + + return nil, fmt.Errorf("expected a map but got %T", data) +} + // getDecodeHook returns the composite decode hook for all type conversions func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc { return mapstructure.ComposeDecodeHookFunc( @@ -217,4 +278,4 @@ func navigateToPath(nested map[string]any, path string) any { } return current -} +} \ No newline at end of file diff --git a/doc/access.md b/doc/access.md index 4da25c8..1be3767 100644 --- a/doc/access.md +++ b/doc/access.md @@ -95,6 +95,37 @@ cfg, _ := config.NewBuilder(). fmt.Println(config.Server.Port) ``` +### GetTyped + +Retrieves a single configuration value and decodes it to the specified type. + +```go +import "time" + +// Returns an int, converting from string "9090" if necessary. +port, err := config.GetTyped[int](cfg, "server.port") + +// Returns a time.Duration, converting from string "5m30s". +timeout, err := config.GetTyped[time.Duration](cfg, "server.timeout") +``` + +### ScanTyped + +A generic wrapper around `Scan` that allocates, populates, and returns a pointer to a struct of the specified type. + +```go +// Instead of: +// var dbConf DBConfig +// if err := cfg.Scan("database", &dbConf); err != nil { ... } + +// You can write: +dbConf, err := config.ScanTyped[DBConfig](cfg, "database") +if err != nil { + // ... +} +// dbConf is a *DBConfig``` +``` + ### Type-Aware Mode ```go diff --git a/doc/builder.md b/doc/builder.md index 6abba10..fb91be8 100644 --- a/doc/builder.md +++ b/doc/builder.md @@ -152,9 +152,10 @@ cfg, _ := config.NewBuilder(). ### WithValidator -Add validation functions that run after loading: +Add validation functions that run *before* the target struct is populated. These validators operate on the raw `*config.Config` object and are suitable for checking required paths or formats before type conversion. ```go +// Validator runs on raw, pre-decoded values. cfg, _ := config.NewBuilder(). WithDefaults(defaults). WithValidator(func(c *config.Config) error { @@ -172,6 +173,34 @@ cfg, _ := config.NewBuilder(). Build() ``` +For type-safe validation, see `WithTypedValidator`. + +### WithTypedValidator + +Add a type-safe validation function that runs *after* the configuration has been fully loaded and decoded into the target struct (set by `WithTarget`). This is the recommended approach for most validation logic. + +The validation function must accept a single argument: a pointer to the same struct type that was passed to `WithTarget`. +```go +type AppConfig struct { + Server struct { + Port int64 `toml:"port"` + } `toml:"server"` +} + +var target AppConfig + +cfg, err := config.NewBuilder(). + WithTarget(&target). + WithFile("config.toml"). + WithTypedValidator(func(conf *AppConfig) error { + if conf.Server.Port < 1024 || conf.Server.Port > 65535 { + return fmt.Errorf("port %d is outside the valid range", conf.Server.Port) + } + return nil + }). + Build() +``` + ### WithFile Set configuration file path: diff --git a/example/main.go b/example/main.go index c8c0ee8..7d66f38 100644 --- a/example/main.go +++ b/example/main.go @@ -5,11 +5,10 @@ import ( "fmt" "log" "os" - "strconv" "sync" "time" - "config" + "github.com/lixenwraith/config" ) // AppConfig defines a richer configuration structure to showcase more features. @@ -68,42 +67,20 @@ func main() { // and keep it updated when using `AsStruct()`. target := &AppConfig{} - // Define a custom validator function. - validator := func(c *config.Config) error { - p, _ := c.Get("server.port") - // 'p' can be an int64 (from defaults/TOML) or a string (from environment variables). - - var port int64 - var err error - - switch v := p.(type) { - case string: - // If it's a string from an env var, parse it. - port, err = strconv.ParseInt(v, 10, 64) - if err != nil { - return fmt.Errorf("could not parse port from string '%s': %w", v, err) - } - case int64: - // If it's already an int64, just use it. - port = v - default: - // Handle any other unexpected types. - return fmt.Errorf("unexpected type for server.port: %T", p) - } - - if port < 1024 || port > 65535 { - return fmt.Errorf("port %d is outside the recommended range (1024-65535)", port) - } - return nil - } - // Use the builder to chain multiple configuration options. builder := config.NewBuilder(). - WithTarget(target). // Enables type-safe `AsStruct()` and auto-registration. - WithDefaults(initialData). // Explicitly set the source of defaults. - WithFile(configFilePath). // Specifies the config file to read. - WithEnvPrefix("APP_"). // Sets prefix for environment variables (e.g., APP_SERVER_PORT). - WithValidator(validator) // Adds a validation function to run at the end of the build. + WithTarget(target). // Enables type-safe `AsStruct()` and auto-registration. + WithDefaults(initialData). // Explicitly set the source of defaults. + WithFile(configFilePath). // Specifies the config file to read. + WithEnvPrefix("APP_"). // Sets prefix for environment variables (e.g., APP_SERVER_PORT). + WithTypedValidator(func(cfg *AppConfig) error { // <-- NEW METHOD + // No type assertion needed! `cfg.Server.Port` is guaranteed to be an int64 + // because the validator runs *after* the target struct is populated. + if cfg.Server.Port < 1024 || cfg.Server.Port > 65535 { + return fmt.Errorf("port %d is outside the recommended range (1024-65535)", cfg.Server.Port) + } + return nil + }) // Build the final config object. cfg, err := builder.Build() @@ -175,44 +152,37 @@ func createInitialConfigFile(data *AppConfig) error { return cfg.Save(configFilePath) } -// modifyFileOnDiskStructurally simulates an external program robustly changing the config file. +// modifyFileOnDiskStructurally simulates an external program that changes the config file. func modifyFileOnDiskStructurally(wg *sync.WaitGroup) { defer wg.Done() time.Sleep(1 * time.Second) log.Println(" (Modifier goroutine: now changing file on disk...)") + // Create a new, independent config instance to simulate an external process. modifierCfg := config.New() + // Register the struct shape so the loader knows what paths are valid. if err := modifierCfg.RegisterStruct("", &AppConfig{}); err != nil { log.Fatalf("❌ Modifier failed to register struct: %v", err) } + // Load the current state from disk. if err := modifierCfg.LoadFile(configFilePath); err != nil { log.Fatalf("❌ Modifier failed to load file: %v", err) } - // Change the log level and add a new feature flag. + // Change the log level. modifierCfg.Set("server.log_level", "debug") - rawFlags, _ := modifierCfg.Get("feature_flags") - newFlags := make(map[string]any) - - // Use a type switch to robustly handle the map, regardless of its source. - switch flags := rawFlags.(type) { - case map[string]bool: - for k, v := range flags { - newFlags[k] = v - } - case map[string]any: - for k, v := range flags { - newFlags[k] = v - } - default: - log.Fatalf("❌ Modifier encountered unexpected type for feature_flags: %T", rawFlags) + // Use the generic GetTyped function. This is safe because modifierCfg has loaded the file. + featureFlags, err := config.GetTyped[map[string]bool](modifierCfg, "feature_flags") + if err != nil { + log.Fatalf("❌ Modifier failed to get typed feature_flags: %v", err) } - // Now modify the generic map and set it back. - newFlags["enable_tracing"] = false - modifierCfg.Set("feature_flags", newFlags) + // Modify the typed map and set it back. + featureFlags["enable_metrics"] = false + modifierCfg.Set("feature_flags", featureFlags) + // Save the changes back to disk, which will trigger the watcher in the main goroutine. if err := modifierCfg.Save(configFilePath); err != nil { log.Fatalf("❌ Modifier failed to save file: %v", err) } @@ -229,4 +199,4 @@ func printCurrentState(cfg *AppConfig, title string) { fmt.Printf(" Server Log Level: %s\n", cfg.Server.LogLevel) fmt.Printf(" Feature Flags: %v\n", cfg.FeatureFlags) fmt.Println(" --------------------------------------------------") -} \ No newline at end of file +} diff --git a/go.mod b/go.mod index d541f78..83d5b1d 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module config +module github.com/lixenwraith/config go 1.24.5 diff --git a/loader.go b/loader.go index a2f56dc..1f2ee9b 100644 --- a/loader.go +++ b/loader.go @@ -143,6 +143,7 @@ func (c *Config) LoadFile(filePath string) error { // loadFile reads and parses a TOML configuration file func (c *Config) loadFile(path string) error { + // 1. Read and Parse (No Lock) fileData, err := os.ReadFile(path) if err != nil { if errors.Is(err, os.ErrNotExist) { @@ -156,36 +157,58 @@ func (c *Config) loadFile(path string) error { return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err) } - // Flatten and apply file data - flattenedFileConfig := flattenMap(fileConfig, "") + // 2. Prepare New State (Read-Lock Only) + newFileData := make(map[string]any) + // Briefly acquire a read-lock to safely get the list of registered paths. + c.mutex.RLock() + registeredPaths := make(map[string]bool, len(c.items)) + for p := range c.items { + registeredPaths[p] = true + } + c.mutex.RUnlock() + + // Define a recursive function to populate newFileData. This runs without any lock. + var apply func(prefix string, data map[string]any) + apply = func(prefix string, data map[string]any) { + for key, value := range data { + fullPath := key + if prefix != "" { + fullPath = prefix + "." + key + } + if registeredPaths[fullPath] { + newFileData[fullPath] = value + } else if subMap, isMap := value.(map[string]any); isMap { + apply(fullPath, subMap) + } + } + } + apply("", fileConfig) + + // -- 3. Atomically Update Config (Write-Lock) c.mutex.Lock() defer c.mutex.Unlock() - // Track the config file path for watching c.configFilePath = path + c.fileData = newFileData - defer c.invalidateCache() // Invalidate cache after changes - - // Store in cache - c.fileData = flattenedFileConfig - - // Apply to registered paths - for path, value := range flattenedFileConfig { - if item, exists := c.items[path]; exists { + // Apply the new state to the main config items. + for path, item := range c.items { + if value, exists := newFileData[path]; exists { if item.values == nil { item.values = make(map[Source]any) } - if str, ok := value.(string); ok && len(str) > MaxValueSize { - return ErrValueSize - } item.values[SourceFile] = value - item.currentValue = c.computeValue(path, item) - c.items[path] = item + } else { + // Key was not in the new file, so remove its old file-sourced value. + delete(item.values, SourceFile) } - // Ignore unregistered paths from file + // Recompute the current value based on new source precedence. + item.currentValue = c.computeValue(path, item) + c.items[path] = item } + c.invalidateCache() return nil } @@ -196,27 +219,47 @@ func (c *Config) loadEnv(opts LoadOptions) error { transform = defaultEnvTransform(opts.EnvPrefix) } - c.mutex.Lock() - defer c.mutex.Unlock() + // -- 1. Prepare data (Read-Lock to get paths) + c.mutex.RLock() + paths := make([]string, 0, len(c.items)) + for p := range c.items { + paths = append(paths, p) + } + c.mutex.RUnlock() - defer c.invalidateCache() // Invalidate cache after changes - - c.envData = make(map[string]any) - - for path, item := range c.items { + // -- 2. Process env vars (No Lock) + foundEnvVars := make(map[string]string) + for _, path := range paths { if opts.EnvWhitelist != nil && !opts.EnvWhitelist[path] { continue } envVar := transform(path) if value, exists := os.LookupEnv(envVar); exists { - // Store raw string value - mapstructure will handle conversion - if item.values == nil { - item.values = make(map[Source]any) - } if len(value) > MaxValueSize { return ErrValueSize } + foundEnvVars[path] = value + } + } + + // If no relevant env vars were found, we are done. + if len(foundEnvVars) == 0 { + return nil + } + + // -- 3. Atomically update config (Write-Lock) + c.mutex.Lock() + defer c.mutex.Unlock() + + c.envData = make(map[string]any, len(foundEnvVars)) + + for path, value := range foundEnvVars { + // Store raw string value - mapstructure will handle conversion later. + if item, exists := c.items[path]; exists { + if item.values == nil { + item.values = make(map[Source]any) + } item.values[SourceEnv] = value // Store as string item.currentValue = c.computeValue(path, item) c.items[path] = item @@ -224,18 +267,24 @@ func (c *Config) loadEnv(opts LoadOptions) error { } } + c.invalidateCache() return nil } // loadCLI loads configuration from command-line arguments func (c *Config) loadCLI(args []string) error { + // -- 1. Prepare data (No Lock) parsedCLI, err := parseArgs(args) if err != nil { return fmt.Errorf("%w: %w", ErrCLIParse, err) } flattenedCLI := flattenMap(parsedCLI, "") + if len(flattenedCLI) == 0 { + return nil // No CLI args to process. + } + // 2. Atomically update config (Write-Lock) c.mutex.Lock() defer c.mutex.Unlock() @@ -252,7 +301,7 @@ func (c *Config) loadCLI(args []string) error { } } - c.invalidateCache() // Invalidate cache after changes + c.invalidateCache() return nil } diff --git a/watch.go b/watch.go index d1a47ab..d155a20 100644 --- a/watch.go +++ b/watch.go @@ -9,8 +9,6 @@ import ( "sync" "sync/atomic" "time" - - "github.com/BurntSushi/toml" ) // WatchOptions configures file watching behavior @@ -132,6 +130,30 @@ func (c *Config) Watch() <-chan string { return c.WatchWithOptions(DefaultWatchOptions()) } +// WatchFile stops any existing file watcher, loads a new configuration file, +// and starts a new watcher on that file path. +func (c *Config) WatchFile(filePath string) error { + // Stop any currently running watcher to prevent orphaned goroutines. + c.StopAutoUpdate() + + // Load the new file and set `configFilePath` to the new path + if err := c.LoadFile(filePath); err != nil { + return fmt.Errorf("failed to load new file for watching: %w", err) + } + + // Start a new watcher on the new file + c.mutex.RLock() + opts := DefaultWatchOptions() + if c.watcher != nil { + opts = c.watcher.opts + } + c.mutex.RUnlock() + + c.AutoUpdateWithOptions(opts) + + return nil +} + // WatchWithOptions returns a channel with custom watch options func (c *Config) WatchWithOptions(opts WatchOptions) <-chan string { // First ensure auto-update is running @@ -260,7 +282,7 @@ func (w *watcher) performReload(c *Config) { // Reload file in a goroutine with timeout done := make(chan error, 1) go func() { - done <- c.reloadFileAtomic(w.filePath) + done <- c.loadFile(w.filePath) }() select { @@ -372,59 +394,4 @@ func (c *Config) snapshot() map[string]any { snapshot[path] = item.currentValue } return snapshot -} - -// reloadFileAtomic atomically reloads the configuration file -func (c *Config) reloadFileAtomic(filePath string) error { - // Read file - fileData, err := os.ReadFile(filePath) - if err != nil { - return fmt.Errorf("failed to read config file: %w", err) - } - - // SECURITY: Check file size to prevent DoS - if len(fileData) > MaxValueSize*10 { // 10MB max for config file - return fmt.Errorf("config file too large: %d bytes", len(fileData)) - } - - // Parse TOML - fileConfig := make(map[string]any) - if err := toml.Unmarshal(fileData, &fileConfig); err != nil { - return fmt.Errorf("failed to parse TOML: %w", err) - } - - // Flatten the configuration - flattenedFileConfig := flattenMap(fileConfig, "") - - // Apply atomically - c.mutex.Lock() - defer c.mutex.Unlock() - - // Clear old file data - c.fileData = make(map[string]any) - - // Apply new values - for path, value := range flattenedFileConfig { - if item, exists := c.items[path]; exists { - if item.values == nil { - item.values = make(map[Source]any) - } - item.values[SourceFile] = value - item.currentValue = c.computeValue(path, item) - c.items[path] = item - c.fileData[path] = value - } - } - - // Remove file values not in new config - for path, item := range c.items { - if _, exists := flattenedFileConfig[path]; !exists { - delete(item.values, SourceFile) - item.currentValue = c.computeValue(path, item) - c.items[path] = item - } - } - - c.invalidateCache() - return nil } \ No newline at end of file