e5.2.0 Decoder and loader refactored, bug fixes.

This commit is contained in:
2025-07-20 02:09:32 -04:00
parent 573eef8d78
commit 06cddbe00e
12 changed files with 474 additions and 178 deletions

3
.gitignore vendored
View File

@ -5,4 +5,5 @@ log
logs
script
*.log
bin
*.toml
bin

View File

@ -11,15 +11,16 @@ import (
// Builder provides a fluent API for constructing a Config instance. It allows for
// chaining configuration options before final build of the config object.
type Builder struct {
cfg *Config
opts LoadOptions
defaults any
tagName string
prefix string
file string
args []string
err error
validators []ValidatorFunc
cfg *Config
opts LoadOptions
defaults any
tagName string
prefix string
file string
args []string
err error
validators []ValidatorFunc
typedValidators []any
}
// ValidatorFunc defines the signature for a function that can validate a Config instance.
@ -29,10 +30,11 @@ type ValidatorFunc func(c *Config) error
// NewBuilder creates a new configuration builder
func NewBuilder() *Builder {
return &Builder{
cfg: New(),
opts: DefaultLoadOptions(),
args: os.Args[1:],
validators: make([]ValidatorFunc, 0),
cfg: New(),
opts: DefaultLoadOptions(),
args: os.Args[1:],
validators: make([]ValidatorFunc, 0),
typedValidators: make([]any, 0),
}
}
@ -48,9 +50,9 @@ func (b *Builder) Build() (*Config, error) {
tagName = "toml"
}
// The logic for registering defaults must be prioritized:
// 1. If WithDefaults() was called, it takes precedence.
// 2. If not, but WithTarget() was called, use the target struct for defaults.
// 1. Register defaults
// If WithDefaults() was called, it takes precedence.
// If not, but WithTarget() was called, use the target struct for defaults.
if b.defaults != nil {
// WithDefaults() was called explicitly.
if err := b.cfg.RegisterStructWithTags(b.prefix, b.defaults, tagName); err != nil {
@ -65,23 +67,50 @@ func (b *Builder) Build() (*Config, error) {
}
// Explicitly set the file path on the config object so the watcher can find it,
// even if the initial load fails with a non-fatal error (e.g., file not found).
// even if the initial load fails with a non-fatal error (file not found).
b.cfg.configFilePath = b.file
// Load configuration
// 2. Load configuration
loadErr := b.cfg.LoadWithOptions(b.file, b.args, b.opts)
if loadErr != nil && !errors.Is(loadErr, ErrConfigNotFound) {
// Return on fatal load errors. ErrConfigNotFound is not fatal.
return nil, loadErr
}
// Run validators
// 3. Run non-typed validators
for _, validator := range b.validators {
if err := validator(b.cfg); err != nil {
return nil, fmt.Errorf("configuration validation failed: %w", err)
}
}
// 4. Populate target and run typed validators
if b.cfg.structCache != nil && b.cfg.structCache.target != nil && len(b.typedValidators) > 0 {
// Populate the target struct first. This unifies all types (e.g., string "8888" -> int64 8888).
populatedTarget, err := b.cfg.AsStruct()
if err != nil {
return nil, fmt.Errorf("failed to populate target struct for validation: %w", err)
}
// Run the typed validators against the populated, type-safe struct.
for _, validator := range b.typedValidators {
validatorFunc := reflect.ValueOf(validator)
validatorType := validatorFunc.Type()
// Check if the validator's input type matches the target's type.
if validatorType.In(0) != reflect.TypeOf(populatedTarget) {
return nil, fmt.Errorf("typed validator signature %v does not match target type %T", validatorType, populatedTarget)
}
// Call the validator.
results := validatorFunc.Call([]reflect.Value{reflect.ValueOf(populatedTarget)})
if !results[0].IsNil() {
err := results[0].Interface().(error)
return nil, fmt.Errorf("typed configuration validation failed: %w", err)
}
}
}
// ErrConfigNotFound or nil
return b.cfg, loadErr
}
@ -188,13 +217,6 @@ func (b *Builder) WithTarget(target any) *Builder {
}
}
// NOTE: removed since it would cause issues when an empty struct is passed
// TODO: may cause issue in other scenarios, test extensively
// // Register struct fields automatically
// if b.defaults == nil {
// b.defaults = target
// }
return b
}
@ -207,4 +229,23 @@ func (b *Builder) WithValidator(fn ValidatorFunc) *Builder {
b.validators = append(b.validators, fn)
}
return b
}
// WithTypedValidator adds a type-safe validation function that runs at the end of the build process,
// after the target struct has been populated. The provided function must accept a single argument
// that is a pointer to the same type as the one provided to WithTarget, and must return an error.
func (b *Builder) WithTypedValidator(fn any) *Builder {
if fn == nil {
return b
}
// Basic reflection check to ensure it's a function that takes one argument and returns an error.
t := reflect.TypeOf(fn)
if t.Kind() != reflect.Func || t.NumIn() != 1 || t.NumOut() != 1 || t.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
b.err = fmt.Errorf("WithTypedValidator requires a function with signature func(*T) error")
return b
}
b.typedValidators = append(b.typedValidators, fn)
return b
}

View File

@ -300,4 +300,63 @@ func TestFileDiscovery(t *testing.T) {
val, _ := cfg.Get("test")
assert.Equal(t, "clifile", val)
})
}
func TestBuilderWithTypedValidator(t *testing.T) {
type Cfg struct {
Port int `toml:"port"`
}
// Case 1: Valid configuration
t.Run("ValidTyped", func(t *testing.T) {
target := &Cfg{Port: 8080}
validator := func(c *Cfg) error {
if c.Port < 1024 {
return fmt.Errorf("port too low")
}
return nil
}
_, err := NewBuilder().
WithTarget(target).
WithTypedValidator(validator).
Build()
require.NoError(t, err)
})
// Case 2: Invalid configuration
t.Run("InvalidTyped", func(t *testing.T) {
target := &Cfg{Port: 80}
validator := func(c *Cfg) error {
if c.Port < 1024 {
return fmt.Errorf("port too low")
}
return nil
}
_, err := NewBuilder().
WithTarget(target).
WithTypedValidator(validator).
Build()
require.Error(t, err)
assert.Contains(t, err.Error(), "typed configuration validation failed: port too low")
})
// Case 3: Mismatched validator signature
t.Run("MismatchedSignature", func(t *testing.T) {
target := &Cfg{}
validator := func(c *struct{ Name string }) error { // Different type
return nil
}
_, err := NewBuilder().
WithTarget(target).
WithTypedValidator(validator).
Build()
require.Error(t, err)
assert.Contains(t, err.Error(), "typed validator signature")
})
}

View File

@ -4,6 +4,7 @@ package config
import (
"flag"
"fmt"
"github.com/mitchellh/mapstructure"
"os"
"reflect"
"strings"
@ -234,4 +235,51 @@ func QuickTyped[T any](target *T, envPrefix, configFile string) (*Config, error)
WithEnvPrefix(envPrefix).
WithFile(configFile).
Build()
}
// GetTyped retrieves a configuration value and decodes it into the specified type T.
// It leverages the same decoding hooks as the Scan and AsStruct methods,
// providing type conversion from strings, numbers, etc.
func GetTyped[T any](c *Config, path string) (T, error) {
var zero T
rawValue, exists := c.Get(path)
if !exists {
return zero, fmt.Errorf("path %q not found", path)
}
// Prepare the input map and target struct for the decoder.
inputMap := map[string]any{"value": rawValue}
var target struct {
Value T `mapstructure:"value"`
}
// Create a new decoder configured with the same hooks as the main config.
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Result: &target,
TagName: c.tagName,
WeaklyTypedInput: true,
DecodeHook: c.getDecodeHook(),
Metadata: nil,
})
if err != nil {
return zero, fmt.Errorf("failed to create decoder for path %q: %w", path, err)
}
// Decode the single value.
if err := decoder.Decode(inputMap); err != nil {
return zero, fmt.Errorf("failed to decode value for path %q into type %T: %w", path, zero, err)
}
return target.Value, nil
}
// ScanTyped is a generic wrapper around Scan. It allocates a new instance of type T,
// populates it with configuration data from the given base path, and returns a pointer to it.
func ScanTyped[T any](c *Config, basePath string) (*T, error) {
var target T
if err := c.Scan(basePath, &target); err != nil {
return nil, err
}
return &target, nil
}

View File

@ -6,6 +6,7 @@ import (
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -284,4 +285,43 @@ func TestClone(t *testing.T) {
// Verify source data is copied
sources := clone.GetSources("shared.value")
assert.Equal(t, "envvalue", sources[SourceEnv])
}
func TestGenericHelpers(t *testing.T) {
cfg := New()
cfg.Register("server.host", "localhost")
cfg.Register("server.port", "8080") // Note: string value
cfg.Register("features.dark_mode", true)
cfg.Register("timeouts.read", "5s")
t.Run("GetTyped", func(t *testing.T) {
port, err := GetTyped[int](cfg, "server.port")
require.NoError(t, err)
assert.Equal(t, 8080, port)
host, err := GetTyped[string](cfg, "server.host")
require.NoError(t, err)
assert.Equal(t, "localhost", host)
// Test with custom decode hook type
readTimeout, err := GetTyped[time.Duration](cfg, "timeouts.read")
require.NoError(t, err)
assert.Equal(t, 5*time.Second, readTimeout)
_, err = GetTyped[int](cfg, "nonexistent.path")
assert.Error(t, err)
})
t.Run("ScanTyped", func(t *testing.T) {
type ServerConfig struct {
Host string `toml:"host"`
Port int `toml:"port"`
}
serverConf, err := ScanTyped[ServerConfig](cfg, "server")
require.NoError(t, err)
require.NotNil(t, serverConf)
assert.Equal(t, "localhost", serverConf.Host)
assert.Equal(t, 8080, serverConf.Port)
})
}

View File

@ -44,12 +44,13 @@ func (c *Config) unmarshal(basePath string, source Source, target any) error {
// Navigate to basePath section
sectionData := navigateToPath(nestedMap, basePath)
// Ensure we have a map to decode
sectionMap, ok := sectionData.(map[string]any)
if !ok {
// Ensure we have a map to decode, normalizing if necessary.
sectionMap, err := normalizeMap(sectionData)
if err != nil {
if sectionData == nil {
sectionMap = make(map[string]any) // Empty section
sectionMap = make(map[string]any) // Empty section is valid.
} else {
// Path points to a non-map value, which is an error for Scan.
return fmt.Errorf("path %q refers to non-map value (type %T)", basePath, sectionData)
}
}
@ -74,6 +75,66 @@ func (c *Config) unmarshal(basePath string, source Source, target any) error {
return nil
}
// // Ensure we have a map to decode
// sectionMap, ok := sectionData.(map[string]any)
// if !ok {
// if sectionData == nil {
// sectionMap = make(map[string]any) // Empty section
// } else {
// return fmt.Errorf("path %q refers to non-map value (type %T)", basePath, sectionData)
// }
// }
//
// // Create decoder with comprehensive hooks
// decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
// Result: target,
// TagName: c.tagName,
// WeaklyTypedInput: true,
// DecodeHook: c.getDecodeHook(),
// ZeroFields: true,
// Metadata: nil,
// })
// if err != nil {
// return fmt.Errorf("decoder creation failed: %w", err)
// }
//
// if err := decoder.Decode(sectionMap); err != nil {
// return fmt.Errorf("decode failed for path %q: %w", basePath, err)
// }
//
// return nil
// }
// normalizeMap ensures that the input data is a map[string]any for the decoder.
func normalizeMap(data any) (map[string]any, error) {
if data == nil {
return make(map[string]any), nil
}
// If it's already the correct type, return it.
if m, ok := data.(map[string]any); ok {
return m, nil
}
// Use reflection to handle other map types (e.g., map[string]bool)
v := reflect.ValueOf(data)
if v.Kind() == reflect.Map {
if v.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("map keys must be strings, but got %v", v.Type().Key())
}
// Create a new map[string]any and copy the values.
normalized := make(map[string]any, v.Len())
iter := v.MapRange()
for iter.Next() {
normalized[iter.Key().String()] = iter.Value().Interface()
}
return normalized, nil
}
return nil, fmt.Errorf("expected a map but got %T", data)
}
// getDecodeHook returns the composite decode hook for all type conversions
func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
return mapstructure.ComposeDecodeHookFunc(
@ -217,4 +278,4 @@ func navigateToPath(nested map[string]any, path string) any {
}
return current
}
}

View File

@ -95,6 +95,37 @@ cfg, _ := config.NewBuilder().
fmt.Println(config.Server.Port)
```
### GetTyped
Retrieves a single configuration value and decodes it to the specified type.
```go
import "time"
// Returns an int, converting from string "9090" if necessary.
port, err := config.GetTyped[int](cfg, "server.port")
// Returns a time.Duration, converting from string "5m30s".
timeout, err := config.GetTyped[time.Duration](cfg, "server.timeout")
```
### ScanTyped
A generic wrapper around `Scan` that allocates, populates, and returns a pointer to a struct of the specified type.
```go
// Instead of:
// var dbConf DBConfig
// if err := cfg.Scan("database", &dbConf); err != nil { ... }
// You can write:
dbConf, err := config.ScanTyped[DBConfig](cfg, "database")
if err != nil {
// ...
}
// dbConf is a *DBConfig```
```
### Type-Aware Mode
```go

View File

@ -152,9 +152,10 @@ cfg, _ := config.NewBuilder().
### WithValidator
Add validation functions that run after loading:
Add validation functions that run *before* the target struct is populated. These validators operate on the raw `*config.Config` object and are suitable for checking required paths or formats before type conversion.
```go
// Validator runs on raw, pre-decoded values.
cfg, _ := config.NewBuilder().
WithDefaults(defaults).
WithValidator(func(c *config.Config) error {
@ -172,6 +173,34 @@ cfg, _ := config.NewBuilder().
Build()
```
For type-safe validation, see `WithTypedValidator`.
### WithTypedValidator
Add a type-safe validation function that runs *after* the configuration has been fully loaded and decoded into the target struct (set by `WithTarget`). This is the recommended approach for most validation logic.
The validation function must accept a single argument: a pointer to the same struct type that was passed to `WithTarget`.
```go
type AppConfig struct {
Server struct {
Port int64 `toml:"port"`
} `toml:"server"`
}
var target AppConfig
cfg, err := config.NewBuilder().
WithTarget(&target).
WithFile("config.toml").
WithTypedValidator(func(conf *AppConfig) error {
if conf.Server.Port < 1024 || conf.Server.Port > 65535 {
return fmt.Errorf("port %d is outside the valid range", conf.Server.Port)
}
return nil
}).
Build()
```
### WithFile
Set configuration file path:

View File

@ -5,11 +5,10 @@ import (
"fmt"
"log"
"os"
"strconv"
"sync"
"time"
"config"
"github.com/lixenwraith/config"
)
// AppConfig defines a richer configuration structure to showcase more features.
@ -68,42 +67,20 @@ func main() {
// and keep it updated when using `AsStruct()`.
target := &AppConfig{}
// Define a custom validator function.
validator := func(c *config.Config) error {
p, _ := c.Get("server.port")
// 'p' can be an int64 (from defaults/TOML) or a string (from environment variables).
var port int64
var err error
switch v := p.(type) {
case string:
// If it's a string from an env var, parse it.
port, err = strconv.ParseInt(v, 10, 64)
if err != nil {
return fmt.Errorf("could not parse port from string '%s': %w", v, err)
}
case int64:
// If it's already an int64, just use it.
port = v
default:
// Handle any other unexpected types.
return fmt.Errorf("unexpected type for server.port: %T", p)
}
if port < 1024 || port > 65535 {
return fmt.Errorf("port %d is outside the recommended range (1024-65535)", port)
}
return nil
}
// Use the builder to chain multiple configuration options.
builder := config.NewBuilder().
WithTarget(target). // Enables type-safe `AsStruct()` and auto-registration.
WithDefaults(initialData). // Explicitly set the source of defaults.
WithFile(configFilePath). // Specifies the config file to read.
WithEnvPrefix("APP_"). // Sets prefix for environment variables (e.g., APP_SERVER_PORT).
WithValidator(validator) // Adds a validation function to run at the end of the build.
WithTarget(target). // Enables type-safe `AsStruct()` and auto-registration.
WithDefaults(initialData). // Explicitly set the source of defaults.
WithFile(configFilePath). // Specifies the config file to read.
WithEnvPrefix("APP_"). // Sets prefix for environment variables (e.g., APP_SERVER_PORT).
WithTypedValidator(func(cfg *AppConfig) error { // <-- NEW METHOD
// No type assertion needed! `cfg.Server.Port` is guaranteed to be an int64
// because the validator runs *after* the target struct is populated.
if cfg.Server.Port < 1024 || cfg.Server.Port > 65535 {
return fmt.Errorf("port %d is outside the recommended range (1024-65535)", cfg.Server.Port)
}
return nil
})
// Build the final config object.
cfg, err := builder.Build()
@ -175,44 +152,37 @@ func createInitialConfigFile(data *AppConfig) error {
return cfg.Save(configFilePath)
}
// modifyFileOnDiskStructurally simulates an external program robustly changing the config file.
// modifyFileOnDiskStructurally simulates an external program that changes the config file.
func modifyFileOnDiskStructurally(wg *sync.WaitGroup) {
defer wg.Done()
time.Sleep(1 * time.Second)
log.Println(" (Modifier goroutine: now changing file on disk...)")
// Create a new, independent config instance to simulate an external process.
modifierCfg := config.New()
// Register the struct shape so the loader knows what paths are valid.
if err := modifierCfg.RegisterStruct("", &AppConfig{}); err != nil {
log.Fatalf("❌ Modifier failed to register struct: %v", err)
}
// Load the current state from disk.
if err := modifierCfg.LoadFile(configFilePath); err != nil {
log.Fatalf("❌ Modifier failed to load file: %v", err)
}
// Change the log level and add a new feature flag.
// Change the log level.
modifierCfg.Set("server.log_level", "debug")
rawFlags, _ := modifierCfg.Get("feature_flags")
newFlags := make(map[string]any)
// Use a type switch to robustly handle the map, regardless of its source.
switch flags := rawFlags.(type) {
case map[string]bool:
for k, v := range flags {
newFlags[k] = v
}
case map[string]any:
for k, v := range flags {
newFlags[k] = v
}
default:
log.Fatalf("❌ Modifier encountered unexpected type for feature_flags: %T", rawFlags)
// Use the generic GetTyped function. This is safe because modifierCfg has loaded the file.
featureFlags, err := config.GetTyped[map[string]bool](modifierCfg, "feature_flags")
if err != nil {
log.Fatalf("❌ Modifier failed to get typed feature_flags: %v", err)
}
// Now modify the generic map and set it back.
newFlags["enable_tracing"] = false
modifierCfg.Set("feature_flags", newFlags)
// Modify the typed map and set it back.
featureFlags["enable_metrics"] = false
modifierCfg.Set("feature_flags", featureFlags)
// Save the changes back to disk, which will trigger the watcher in the main goroutine.
if err := modifierCfg.Save(configFilePath); err != nil {
log.Fatalf("❌ Modifier failed to save file: %v", err)
}
@ -229,4 +199,4 @@ func printCurrentState(cfg *AppConfig, title string) {
fmt.Printf(" Server Log Level: %s\n", cfg.Server.LogLevel)
fmt.Printf(" Feature Flags: %v\n", cfg.FeatureFlags)
fmt.Println(" --------------------------------------------------")
}
}

2
go.mod
View File

@ -1,4 +1,4 @@
module config
module github.com/lixenwraith/config
go 1.24.5

107
loader.go
View File

@ -143,6 +143,7 @@ func (c *Config) LoadFile(filePath string) error {
// loadFile reads and parses a TOML configuration file
func (c *Config) loadFile(path string) error {
// 1. Read and Parse (No Lock)
fileData, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
@ -156,36 +157,58 @@ func (c *Config) loadFile(path string) error {
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err)
}
// Flatten and apply file data
flattenedFileConfig := flattenMap(fileConfig, "")
// 2. Prepare New State (Read-Lock Only)
newFileData := make(map[string]any)
// Briefly acquire a read-lock to safely get the list of registered paths.
c.mutex.RLock()
registeredPaths := make(map[string]bool, len(c.items))
for p := range c.items {
registeredPaths[p] = true
}
c.mutex.RUnlock()
// Define a recursive function to populate newFileData. This runs without any lock.
var apply func(prefix string, data map[string]any)
apply = func(prefix string, data map[string]any) {
for key, value := range data {
fullPath := key
if prefix != "" {
fullPath = prefix + "." + key
}
if registeredPaths[fullPath] {
newFileData[fullPath] = value
} else if subMap, isMap := value.(map[string]any); isMap {
apply(fullPath, subMap)
}
}
}
apply("", fileConfig)
// -- 3. Atomically Update Config (Write-Lock)
c.mutex.Lock()
defer c.mutex.Unlock()
// Track the config file path for watching
c.configFilePath = path
c.fileData = newFileData
defer c.invalidateCache() // Invalidate cache after changes
// Store in cache
c.fileData = flattenedFileConfig
// Apply to registered paths
for path, value := range flattenedFileConfig {
if item, exists := c.items[path]; exists {
// Apply the new state to the main config items.
for path, item := range c.items {
if value, exists := newFileData[path]; exists {
if item.values == nil {
item.values = make(map[Source]any)
}
if str, ok := value.(string); ok && len(str) > MaxValueSize {
return ErrValueSize
}
item.values[SourceFile] = value
item.currentValue = c.computeValue(path, item)
c.items[path] = item
} else {
// Key was not in the new file, so remove its old file-sourced value.
delete(item.values, SourceFile)
}
// Ignore unregistered paths from file
// Recompute the current value based on new source precedence.
item.currentValue = c.computeValue(path, item)
c.items[path] = item
}
c.invalidateCache()
return nil
}
@ -196,27 +219,47 @@ func (c *Config) loadEnv(opts LoadOptions) error {
transform = defaultEnvTransform(opts.EnvPrefix)
}
c.mutex.Lock()
defer c.mutex.Unlock()
// -- 1. Prepare data (Read-Lock to get paths)
c.mutex.RLock()
paths := make([]string, 0, len(c.items))
for p := range c.items {
paths = append(paths, p)
}
c.mutex.RUnlock()
defer c.invalidateCache() // Invalidate cache after changes
c.envData = make(map[string]any)
for path, item := range c.items {
// -- 2. Process env vars (No Lock)
foundEnvVars := make(map[string]string)
for _, path := range paths {
if opts.EnvWhitelist != nil && !opts.EnvWhitelist[path] {
continue
}
envVar := transform(path)
if value, exists := os.LookupEnv(envVar); exists {
// Store raw string value - mapstructure will handle conversion
if item.values == nil {
item.values = make(map[Source]any)
}
if len(value) > MaxValueSize {
return ErrValueSize
}
foundEnvVars[path] = value
}
}
// If no relevant env vars were found, we are done.
if len(foundEnvVars) == 0 {
return nil
}
// -- 3. Atomically update config (Write-Lock)
c.mutex.Lock()
defer c.mutex.Unlock()
c.envData = make(map[string]any, len(foundEnvVars))
for path, value := range foundEnvVars {
// Store raw string value - mapstructure will handle conversion later.
if item, exists := c.items[path]; exists {
if item.values == nil {
item.values = make(map[Source]any)
}
item.values[SourceEnv] = value // Store as string
item.currentValue = c.computeValue(path, item)
c.items[path] = item
@ -224,18 +267,24 @@ func (c *Config) loadEnv(opts LoadOptions) error {
}
}
c.invalidateCache()
return nil
}
// loadCLI loads configuration from command-line arguments
func (c *Config) loadCLI(args []string) error {
// -- 1. Prepare data (No Lock)
parsedCLI, err := parseArgs(args)
if err != nil {
return fmt.Errorf("%w: %w", ErrCLIParse, err)
}
flattenedCLI := flattenMap(parsedCLI, "")
if len(flattenedCLI) == 0 {
return nil // No CLI args to process.
}
// 2. Atomically update config (Write-Lock)
c.mutex.Lock()
defer c.mutex.Unlock()
@ -252,7 +301,7 @@ func (c *Config) loadCLI(args []string) error {
}
}
c.invalidateCache() // Invalidate cache after changes
c.invalidateCache()
return nil
}

View File

@ -9,8 +9,6 @@ import (
"sync"
"sync/atomic"
"time"
"github.com/BurntSushi/toml"
)
// WatchOptions configures file watching behavior
@ -132,6 +130,30 @@ func (c *Config) Watch() <-chan string {
return c.WatchWithOptions(DefaultWatchOptions())
}
// WatchFile stops any existing file watcher, loads a new configuration file,
// and starts a new watcher on that file path.
func (c *Config) WatchFile(filePath string) error {
// Stop any currently running watcher to prevent orphaned goroutines.
c.StopAutoUpdate()
// Load the new file and set `configFilePath` to the new path
if err := c.LoadFile(filePath); err != nil {
return fmt.Errorf("failed to load new file for watching: %w", err)
}
// Start a new watcher on the new file
c.mutex.RLock()
opts := DefaultWatchOptions()
if c.watcher != nil {
opts = c.watcher.opts
}
c.mutex.RUnlock()
c.AutoUpdateWithOptions(opts)
return nil
}
// WatchWithOptions returns a channel with custom watch options
func (c *Config) WatchWithOptions(opts WatchOptions) <-chan string {
// First ensure auto-update is running
@ -260,7 +282,7 @@ func (w *watcher) performReload(c *Config) {
// Reload file in a goroutine with timeout
done := make(chan error, 1)
go func() {
done <- c.reloadFileAtomic(w.filePath)
done <- c.loadFile(w.filePath)
}()
select {
@ -372,59 +394,4 @@ func (c *Config) snapshot() map[string]any {
snapshot[path] = item.currentValue
}
return snapshot
}
// reloadFileAtomic atomically reloads the configuration file
func (c *Config) reloadFileAtomic(filePath string) error {
// Read file
fileData, err := os.ReadFile(filePath)
if err != nil {
return fmt.Errorf("failed to read config file: %w", err)
}
// SECURITY: Check file size to prevent DoS
if len(fileData) > MaxValueSize*10 { // 10MB max for config file
return fmt.Errorf("config file too large: %d bytes", len(fileData))
}
// Parse TOML
fileConfig := make(map[string]any)
if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
return fmt.Errorf("failed to parse TOML: %w", err)
}
// Flatten the configuration
flattenedFileConfig := flattenMap(fileConfig, "")
// Apply atomically
c.mutex.Lock()
defer c.mutex.Unlock()
// Clear old file data
c.fileData = make(map[string]any)
// Apply new values
for path, value := range flattenedFileConfig {
if item, exists := c.items[path]; exists {
if item.values == nil {
item.values = make(map[Source]any)
}
item.values[SourceFile] = value
item.currentValue = c.computeValue(path, item)
c.items[path] = item
c.fileData[path] = value
}
}
// Remove file values not in new config
for path, item := range c.items {
if _, exists := flattenedFileConfig[path]; !exists {
delete(item.values, SourceFile)
item.currentValue = c.computeValue(path, item)
c.items[path] = item
}
}
c.invalidateCache()
return nil
}