e5.2.0 Decoder and loader refactored, bug fixes.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -5,4 +5,5 @@ log
|
||||
logs
|
||||
script
|
||||
*.log
|
||||
*.toml
|
||||
bin
|
||||
67
builder.go
67
builder.go
@ -20,6 +20,7 @@ type Builder struct {
|
||||
args []string
|
||||
err error
|
||||
validators []ValidatorFunc
|
||||
typedValidators []any
|
||||
}
|
||||
|
||||
// ValidatorFunc defines the signature for a function that can validate a Config instance.
|
||||
@ -33,6 +34,7 @@ func NewBuilder() *Builder {
|
||||
opts: DefaultLoadOptions(),
|
||||
args: os.Args[1:],
|
||||
validators: make([]ValidatorFunc, 0),
|
||||
typedValidators: make([]any, 0),
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,9 +50,9 @@ func (b *Builder) Build() (*Config, error) {
|
||||
tagName = "toml"
|
||||
}
|
||||
|
||||
// The logic for registering defaults must be prioritized:
|
||||
// 1. If WithDefaults() was called, it takes precedence.
|
||||
// 2. If not, but WithTarget() was called, use the target struct for defaults.
|
||||
// 1. Register defaults
|
||||
// If WithDefaults() was called, it takes precedence.
|
||||
// If not, but WithTarget() was called, use the target struct for defaults.
|
||||
if b.defaults != nil {
|
||||
// WithDefaults() was called explicitly.
|
||||
if err := b.cfg.RegisterStructWithTags(b.prefix, b.defaults, tagName); err != nil {
|
||||
@ -65,23 +67,50 @@ func (b *Builder) Build() (*Config, error) {
|
||||
}
|
||||
|
||||
// Explicitly set the file path on the config object so the watcher can find it,
|
||||
// even if the initial load fails with a non-fatal error (e.g., file not found).
|
||||
// even if the initial load fails with a non-fatal error (file not found).
|
||||
b.cfg.configFilePath = b.file
|
||||
|
||||
// Load configuration
|
||||
// 2. Load configuration
|
||||
loadErr := b.cfg.LoadWithOptions(b.file, b.args, b.opts)
|
||||
if loadErr != nil && !errors.Is(loadErr, ErrConfigNotFound) {
|
||||
// Return on fatal load errors. ErrConfigNotFound is not fatal.
|
||||
return nil, loadErr
|
||||
}
|
||||
|
||||
// Run validators
|
||||
// 3. Run non-typed validators
|
||||
for _, validator := range b.validators {
|
||||
if err := validator(b.cfg); err != nil {
|
||||
return nil, fmt.Errorf("configuration validation failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Populate target and run typed validators
|
||||
if b.cfg.structCache != nil && b.cfg.structCache.target != nil && len(b.typedValidators) > 0 {
|
||||
// Populate the target struct first. This unifies all types (e.g., string "8888" -> int64 8888).
|
||||
populatedTarget, err := b.cfg.AsStruct()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to populate target struct for validation: %w", err)
|
||||
}
|
||||
|
||||
// Run the typed validators against the populated, type-safe struct.
|
||||
for _, validator := range b.typedValidators {
|
||||
validatorFunc := reflect.ValueOf(validator)
|
||||
validatorType := validatorFunc.Type()
|
||||
|
||||
// Check if the validator's input type matches the target's type.
|
||||
if validatorType.In(0) != reflect.TypeOf(populatedTarget) {
|
||||
return nil, fmt.Errorf("typed validator signature %v does not match target type %T", validatorType, populatedTarget)
|
||||
}
|
||||
|
||||
// Call the validator.
|
||||
results := validatorFunc.Call([]reflect.Value{reflect.ValueOf(populatedTarget)})
|
||||
if !results[0].IsNil() {
|
||||
err := results[0].Interface().(error)
|
||||
return nil, fmt.Errorf("typed configuration validation failed: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ErrConfigNotFound or nil
|
||||
return b.cfg, loadErr
|
||||
}
|
||||
@ -188,13 +217,6 @@ func (b *Builder) WithTarget(target any) *Builder {
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: removed since it would cause issues when an empty struct is passed
|
||||
// TODO: may cause issue in other scenarios, test extensively
|
||||
// // Register struct fields automatically
|
||||
// if b.defaults == nil {
|
||||
// b.defaults = target
|
||||
// }
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
@ -208,3 +230,22 @@ func (b *Builder) WithValidator(fn ValidatorFunc) *Builder {
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// WithTypedValidator adds a type-safe validation function that runs at the end of the build process,
|
||||
// after the target struct has been populated. The provided function must accept a single argument
|
||||
// that is a pointer to the same type as the one provided to WithTarget, and must return an error.
|
||||
func (b *Builder) WithTypedValidator(fn any) *Builder {
|
||||
if fn == nil {
|
||||
return b
|
||||
}
|
||||
|
||||
// Basic reflection check to ensure it's a function that takes one argument and returns an error.
|
||||
t := reflect.TypeOf(fn)
|
||||
if t.Kind() != reflect.Func || t.NumIn() != 1 || t.NumOut() != 1 || t.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
|
||||
b.err = fmt.Errorf("WithTypedValidator requires a function with signature func(*T) error")
|
||||
return b
|
||||
}
|
||||
|
||||
b.typedValidators = append(b.typedValidators, fn)
|
||||
return b
|
||||
}
|
||||
@ -301,3 +301,62 @@ func TestFileDiscovery(t *testing.T) {
|
||||
assert.Equal(t, "clifile", val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuilderWithTypedValidator(t *testing.T) {
|
||||
type Cfg struct {
|
||||
Port int `toml:"port"`
|
||||
}
|
||||
|
||||
// Case 1: Valid configuration
|
||||
t.Run("ValidTyped", func(t *testing.T) {
|
||||
target := &Cfg{Port: 8080}
|
||||
validator := func(c *Cfg) error {
|
||||
if c.Port < 1024 {
|
||||
return fmt.Errorf("port too low")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := NewBuilder().
|
||||
WithTarget(target).
|
||||
WithTypedValidator(validator).
|
||||
Build()
|
||||
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
// Case 2: Invalid configuration
|
||||
t.Run("InvalidTyped", func(t *testing.T) {
|
||||
target := &Cfg{Port: 80}
|
||||
validator := func(c *Cfg) error {
|
||||
if c.Port < 1024 {
|
||||
return fmt.Errorf("port too low")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := NewBuilder().
|
||||
WithTarget(target).
|
||||
WithTypedValidator(validator).
|
||||
Build()
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "typed configuration validation failed: port too low")
|
||||
})
|
||||
|
||||
// Case 3: Mismatched validator signature
|
||||
t.Run("MismatchedSignature", func(t *testing.T) {
|
||||
target := &Cfg{}
|
||||
validator := func(c *struct{ Name string }) error { // Different type
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := NewBuilder().
|
||||
WithTarget(target).
|
||||
WithTypedValidator(validator).
|
||||
Build()
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "typed validator signature")
|
||||
})
|
||||
}
|
||||
@ -4,6 +4,7 @@ package config
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
@ -235,3 +236,50 @@ func QuickTyped[T any](target *T, envPrefix, configFile string) (*Config, error)
|
||||
WithFile(configFile).
|
||||
Build()
|
||||
}
|
||||
|
||||
// GetTyped retrieves a configuration value and decodes it into the specified type T.
|
||||
// It leverages the same decoding hooks as the Scan and AsStruct methods,
|
||||
// providing type conversion from strings, numbers, etc.
|
||||
func GetTyped[T any](c *Config, path string) (T, error) {
|
||||
var zero T
|
||||
|
||||
rawValue, exists := c.Get(path)
|
||||
if !exists {
|
||||
return zero, fmt.Errorf("path %q not found", path)
|
||||
}
|
||||
|
||||
// Prepare the input map and target struct for the decoder.
|
||||
inputMap := map[string]any{"value": rawValue}
|
||||
var target struct {
|
||||
Value T `mapstructure:"value"`
|
||||
}
|
||||
|
||||
// Create a new decoder configured with the same hooks as the main config.
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
Result: &target,
|
||||
TagName: c.tagName,
|
||||
WeaklyTypedInput: true,
|
||||
DecodeHook: c.getDecodeHook(),
|
||||
Metadata: nil,
|
||||
})
|
||||
if err != nil {
|
||||
return zero, fmt.Errorf("failed to create decoder for path %q: %w", path, err)
|
||||
}
|
||||
|
||||
// Decode the single value.
|
||||
if err := decoder.Decode(inputMap); err != nil {
|
||||
return zero, fmt.Errorf("failed to decode value for path %q into type %T: %w", path, zero, err)
|
||||
}
|
||||
|
||||
return target.Value, nil
|
||||
}
|
||||
|
||||
// ScanTyped is a generic wrapper around Scan. It allocates a new instance of type T,
|
||||
// populates it with configuration data from the given base path, and returns a pointer to it.
|
||||
func ScanTyped[T any](c *Config, basePath string) (*T, error) {
|
||||
var target T
|
||||
if err := c.Scan(basePath, &target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &target, nil
|
||||
}
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -285,3 +286,42 @@ func TestClone(t *testing.T) {
|
||||
sources := clone.GetSources("shared.value")
|
||||
assert.Equal(t, "envvalue", sources[SourceEnv])
|
||||
}
|
||||
|
||||
func TestGenericHelpers(t *testing.T) {
|
||||
cfg := New()
|
||||
cfg.Register("server.host", "localhost")
|
||||
cfg.Register("server.port", "8080") // Note: string value
|
||||
cfg.Register("features.dark_mode", true)
|
||||
cfg.Register("timeouts.read", "5s")
|
||||
|
||||
t.Run("GetTyped", func(t *testing.T) {
|
||||
port, err := GetTyped[int](cfg, "server.port")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 8080, port)
|
||||
|
||||
host, err := GetTyped[string](cfg, "server.host")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "localhost", host)
|
||||
|
||||
// Test with custom decode hook type
|
||||
readTimeout, err := GetTyped[time.Duration](cfg, "timeouts.read")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5*time.Second, readTimeout)
|
||||
|
||||
_, err = GetTyped[int](cfg, "nonexistent.path")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("ScanTyped", func(t *testing.T) {
|
||||
type ServerConfig struct {
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
}
|
||||
|
||||
serverConf, err := ScanTyped[ServerConfig](cfg, "server")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, serverConf)
|
||||
assert.Equal(t, "localhost", serverConf.Host)
|
||||
assert.Equal(t, 8080, serverConf.Port)
|
||||
})
|
||||
}
|
||||
69
decode.go
69
decode.go
@ -44,12 +44,13 @@ func (c *Config) unmarshal(basePath string, source Source, target any) error {
|
||||
// Navigate to basePath section
|
||||
sectionData := navigateToPath(nestedMap, basePath)
|
||||
|
||||
// Ensure we have a map to decode
|
||||
sectionMap, ok := sectionData.(map[string]any)
|
||||
if !ok {
|
||||
// Ensure we have a map to decode, normalizing if necessary.
|
||||
sectionMap, err := normalizeMap(sectionData)
|
||||
if err != nil {
|
||||
if sectionData == nil {
|
||||
sectionMap = make(map[string]any) // Empty section
|
||||
sectionMap = make(map[string]any) // Empty section is valid.
|
||||
} else {
|
||||
// Path points to a non-map value, which is an error for Scan.
|
||||
return fmt.Errorf("path %q refers to non-map value (type %T)", basePath, sectionData)
|
||||
}
|
||||
}
|
||||
@ -74,6 +75,66 @@ func (c *Config) unmarshal(basePath string, source Source, target any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// // Ensure we have a map to decode
|
||||
// sectionMap, ok := sectionData.(map[string]any)
|
||||
// if !ok {
|
||||
// if sectionData == nil {
|
||||
// sectionMap = make(map[string]any) // Empty section
|
||||
// } else {
|
||||
// return fmt.Errorf("path %q refers to non-map value (type %T)", basePath, sectionData)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Create decoder with comprehensive hooks
|
||||
// decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
// Result: target,
|
||||
// TagName: c.tagName,
|
||||
// WeaklyTypedInput: true,
|
||||
// DecodeHook: c.getDecodeHook(),
|
||||
// ZeroFields: true,
|
||||
// Metadata: nil,
|
||||
// })
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("decoder creation failed: %w", err)
|
||||
// }
|
||||
//
|
||||
// if err := decoder.Decode(sectionMap); err != nil {
|
||||
// return fmt.Errorf("decode failed for path %q: %w", basePath, err)
|
||||
// }
|
||||
//
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// normalizeMap ensures that the input data is a map[string]any for the decoder.
|
||||
func normalizeMap(data any) (map[string]any, error) {
|
||||
if data == nil {
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
// If it's already the correct type, return it.
|
||||
if m, ok := data.(map[string]any); ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Use reflection to handle other map types (e.g., map[string]bool)
|
||||
v := reflect.ValueOf(data)
|
||||
if v.Kind() == reflect.Map {
|
||||
if v.Type().Key().Kind() != reflect.String {
|
||||
return nil, fmt.Errorf("map keys must be strings, but got %v", v.Type().Key())
|
||||
}
|
||||
|
||||
// Create a new map[string]any and copy the values.
|
||||
normalized := make(map[string]any, v.Len())
|
||||
iter := v.MapRange()
|
||||
for iter.Next() {
|
||||
normalized[iter.Key().String()] = iter.Value().Interface()
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("expected a map but got %T", data)
|
||||
}
|
||||
|
||||
// getDecodeHook returns the composite decode hook for all type conversions
|
||||
func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
|
||||
return mapstructure.ComposeDecodeHookFunc(
|
||||
|
||||
@ -95,6 +95,37 @@ cfg, _ := config.NewBuilder().
|
||||
fmt.Println(config.Server.Port)
|
||||
```
|
||||
|
||||
### GetTyped
|
||||
|
||||
Retrieves a single configuration value and decodes it to the specified type.
|
||||
|
||||
```go
|
||||
import "time"
|
||||
|
||||
// Returns an int, converting from string "9090" if necessary.
|
||||
port, err := config.GetTyped[int](cfg, "server.port")
|
||||
|
||||
// Returns a time.Duration, converting from string "5m30s".
|
||||
timeout, err := config.GetTyped[time.Duration](cfg, "server.timeout")
|
||||
```
|
||||
|
||||
### ScanTyped
|
||||
|
||||
A generic wrapper around `Scan` that allocates, populates, and returns a pointer to a struct of the specified type.
|
||||
|
||||
```go
|
||||
// Instead of:
|
||||
// var dbConf DBConfig
|
||||
// if err := cfg.Scan("database", &dbConf); err != nil { ... }
|
||||
|
||||
// You can write:
|
||||
dbConf, err := config.ScanTyped[DBConfig](cfg, "database")
|
||||
if err != nil {
|
||||
// ...
|
||||
}
|
||||
// dbConf is a *DBConfig```
|
||||
```
|
||||
|
||||
### Type-Aware Mode
|
||||
|
||||
```go
|
||||
|
||||
@ -152,9 +152,10 @@ cfg, _ := config.NewBuilder().
|
||||
|
||||
### WithValidator
|
||||
|
||||
Add validation functions that run after loading:
|
||||
Add validation functions that run *before* the target struct is populated. These validators operate on the raw `*config.Config` object and are suitable for checking required paths or formats before type conversion.
|
||||
|
||||
```go
|
||||
// Validator runs on raw, pre-decoded values.
|
||||
cfg, _ := config.NewBuilder().
|
||||
WithDefaults(defaults).
|
||||
WithValidator(func(c *config.Config) error {
|
||||
@ -172,6 +173,34 @@ cfg, _ := config.NewBuilder().
|
||||
Build()
|
||||
```
|
||||
|
||||
For type-safe validation, see `WithTypedValidator`.
|
||||
|
||||
### WithTypedValidator
|
||||
|
||||
Add a type-safe validation function that runs *after* the configuration has been fully loaded and decoded into the target struct (set by `WithTarget`). This is the recommended approach for most validation logic.
|
||||
|
||||
The validation function must accept a single argument: a pointer to the same struct type that was passed to `WithTarget`.
|
||||
```go
|
||||
type AppConfig struct {
|
||||
Server struct {
|
||||
Port int64 `toml:"port"`
|
||||
} `toml:"server"`
|
||||
}
|
||||
|
||||
var target AppConfig
|
||||
|
||||
cfg, err := config.NewBuilder().
|
||||
WithTarget(&target).
|
||||
WithFile("config.toml").
|
||||
WithTypedValidator(func(conf *AppConfig) error {
|
||||
if conf.Server.Port < 1024 || conf.Server.Port > 65535 {
|
||||
return fmt.Errorf("port %d is outside the valid range", conf.Server.Port)
|
||||
}
|
||||
return nil
|
||||
}).
|
||||
Build()
|
||||
```
|
||||
|
||||
### WithFile
|
||||
|
||||
Set configuration file path:
|
||||
|
||||
@ -5,11 +5,10 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"config"
|
||||
"github.com/lixenwraith/config"
|
||||
)
|
||||
|
||||
// AppConfig defines a richer configuration structure to showcase more features.
|
||||
@ -68,42 +67,20 @@ func main() {
|
||||
// and keep it updated when using `AsStruct()`.
|
||||
target := &AppConfig{}
|
||||
|
||||
// Define a custom validator function.
|
||||
validator := func(c *config.Config) error {
|
||||
p, _ := c.Get("server.port")
|
||||
// 'p' can be an int64 (from defaults/TOML) or a string (from environment variables).
|
||||
|
||||
var port int64
|
||||
var err error
|
||||
|
||||
switch v := p.(type) {
|
||||
case string:
|
||||
// If it's a string from an env var, parse it.
|
||||
port, err = strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse port from string '%s': %w", v, err)
|
||||
}
|
||||
case int64:
|
||||
// If it's already an int64, just use it.
|
||||
port = v
|
||||
default:
|
||||
// Handle any other unexpected types.
|
||||
return fmt.Errorf("unexpected type for server.port: %T", p)
|
||||
}
|
||||
|
||||
if port < 1024 || port > 65535 {
|
||||
return fmt.Errorf("port %d is outside the recommended range (1024-65535)", port)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use the builder to chain multiple configuration options.
|
||||
builder := config.NewBuilder().
|
||||
WithTarget(target). // Enables type-safe `AsStruct()` and auto-registration.
|
||||
WithDefaults(initialData). // Explicitly set the source of defaults.
|
||||
WithFile(configFilePath). // Specifies the config file to read.
|
||||
WithEnvPrefix("APP_"). // Sets prefix for environment variables (e.g., APP_SERVER_PORT).
|
||||
WithValidator(validator) // Adds a validation function to run at the end of the build.
|
||||
WithTypedValidator(func(cfg *AppConfig) error { // <-- NEW METHOD
|
||||
// No type assertion needed! `cfg.Server.Port` is guaranteed to be an int64
|
||||
// because the validator runs *after* the target struct is populated.
|
||||
if cfg.Server.Port < 1024 || cfg.Server.Port > 65535 {
|
||||
return fmt.Errorf("port %d is outside the recommended range (1024-65535)", cfg.Server.Port)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Build the final config object.
|
||||
cfg, err := builder.Build()
|
||||
@ -175,44 +152,37 @@ func createInitialConfigFile(data *AppConfig) error {
|
||||
return cfg.Save(configFilePath)
|
||||
}
|
||||
|
||||
// modifyFileOnDiskStructurally simulates an external program robustly changing the config file.
|
||||
// modifyFileOnDiskStructurally simulates an external program that changes the config file.
|
||||
func modifyFileOnDiskStructurally(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
time.Sleep(1 * time.Second)
|
||||
log.Println(" (Modifier goroutine: now changing file on disk...)")
|
||||
|
||||
// Create a new, independent config instance to simulate an external process.
|
||||
modifierCfg := config.New()
|
||||
// Register the struct shape so the loader knows what paths are valid.
|
||||
if err := modifierCfg.RegisterStruct("", &AppConfig{}); err != nil {
|
||||
log.Fatalf("❌ Modifier failed to register struct: %v", err)
|
||||
}
|
||||
// Load the current state from disk.
|
||||
if err := modifierCfg.LoadFile(configFilePath); err != nil {
|
||||
log.Fatalf("❌ Modifier failed to load file: %v", err)
|
||||
}
|
||||
|
||||
// Change the log level and add a new feature flag.
|
||||
// Change the log level.
|
||||
modifierCfg.Set("server.log_level", "debug")
|
||||
|
||||
rawFlags, _ := modifierCfg.Get("feature_flags")
|
||||
newFlags := make(map[string]any)
|
||||
|
||||
// Use a type switch to robustly handle the map, regardless of its source.
|
||||
switch flags := rawFlags.(type) {
|
||||
case map[string]bool:
|
||||
for k, v := range flags {
|
||||
newFlags[k] = v
|
||||
}
|
||||
case map[string]any:
|
||||
for k, v := range flags {
|
||||
newFlags[k] = v
|
||||
}
|
||||
default:
|
||||
log.Fatalf("❌ Modifier encountered unexpected type for feature_flags: %T", rawFlags)
|
||||
// Use the generic GetTyped function. This is safe because modifierCfg has loaded the file.
|
||||
featureFlags, err := config.GetTyped[map[string]bool](modifierCfg, "feature_flags")
|
||||
if err != nil {
|
||||
log.Fatalf("❌ Modifier failed to get typed feature_flags: %v", err)
|
||||
}
|
||||
|
||||
// Now modify the generic map and set it back.
|
||||
newFlags["enable_tracing"] = false
|
||||
modifierCfg.Set("feature_flags", newFlags)
|
||||
// Modify the typed map and set it back.
|
||||
featureFlags["enable_metrics"] = false
|
||||
modifierCfg.Set("feature_flags", featureFlags)
|
||||
|
||||
// Save the changes back to disk, which will trigger the watcher in the main goroutine.
|
||||
if err := modifierCfg.Save(configFilePath); err != nil {
|
||||
log.Fatalf("❌ Modifier failed to save file: %v", err)
|
||||
}
|
||||
|
||||
105
loader.go
105
loader.go
@ -143,6 +143,7 @@ func (c *Config) LoadFile(filePath string) error {
|
||||
|
||||
// loadFile reads and parses a TOML configuration file
|
||||
func (c *Config) loadFile(path string) error {
|
||||
// 1. Read and Parse (No Lock)
|
||||
fileData, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
@ -156,36 +157,58 @@ func (c *Config) loadFile(path string) error {
|
||||
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err)
|
||||
}
|
||||
|
||||
// Flatten and apply file data
|
||||
flattenedFileConfig := flattenMap(fileConfig, "")
|
||||
// 2. Prepare New State (Read-Lock Only)
|
||||
newFileData := make(map[string]any)
|
||||
|
||||
// Briefly acquire a read-lock to safely get the list of registered paths.
|
||||
c.mutex.RLock()
|
||||
registeredPaths := make(map[string]bool, len(c.items))
|
||||
for p := range c.items {
|
||||
registeredPaths[p] = true
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
// Define a recursive function to populate newFileData. This runs without any lock.
|
||||
var apply func(prefix string, data map[string]any)
|
||||
apply = func(prefix string, data map[string]any) {
|
||||
for key, value := range data {
|
||||
fullPath := key
|
||||
if prefix != "" {
|
||||
fullPath = prefix + "." + key
|
||||
}
|
||||
if registeredPaths[fullPath] {
|
||||
newFileData[fullPath] = value
|
||||
} else if subMap, isMap := value.(map[string]any); isMap {
|
||||
apply(fullPath, subMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
apply("", fileConfig)
|
||||
|
||||
// -- 3. Atomically Update Config (Write-Lock)
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Track the config file path for watching
|
||||
c.configFilePath = path
|
||||
c.fileData = newFileData
|
||||
|
||||
defer c.invalidateCache() // Invalidate cache after changes
|
||||
|
||||
// Store in cache
|
||||
c.fileData = flattenedFileConfig
|
||||
|
||||
// Apply to registered paths
|
||||
for path, value := range flattenedFileConfig {
|
||||
if item, exists := c.items[path]; exists {
|
||||
// Apply the new state to the main config items.
|
||||
for path, item := range c.items {
|
||||
if value, exists := newFileData[path]; exists {
|
||||
if item.values == nil {
|
||||
item.values = make(map[Source]any)
|
||||
}
|
||||
if str, ok := value.(string); ok && len(str) > MaxValueSize {
|
||||
return ErrValueSize
|
||||
}
|
||||
item.values[SourceFile] = value
|
||||
} else {
|
||||
// Key was not in the new file, so remove its old file-sourced value.
|
||||
delete(item.values, SourceFile)
|
||||
}
|
||||
// Recompute the current value based on new source precedence.
|
||||
item.currentValue = c.computeValue(path, item)
|
||||
c.items[path] = item
|
||||
}
|
||||
// Ignore unregistered paths from file
|
||||
}
|
||||
|
||||
c.invalidateCache()
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -196,27 +219,47 @@ func (c *Config) loadEnv(opts LoadOptions) error {
|
||||
transform = defaultEnvTransform(opts.EnvPrefix)
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
// -- 1. Prepare data (Read-Lock to get paths)
|
||||
c.mutex.RLock()
|
||||
paths := make([]string, 0, len(c.items))
|
||||
for p := range c.items {
|
||||
paths = append(paths, p)
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
defer c.invalidateCache() // Invalidate cache after changes
|
||||
|
||||
c.envData = make(map[string]any)
|
||||
|
||||
for path, item := range c.items {
|
||||
// -- 2. Process env vars (No Lock)
|
||||
foundEnvVars := make(map[string]string)
|
||||
for _, path := range paths {
|
||||
if opts.EnvWhitelist != nil && !opts.EnvWhitelist[path] {
|
||||
continue
|
||||
}
|
||||
|
||||
envVar := transform(path)
|
||||
if value, exists := os.LookupEnv(envVar); exists {
|
||||
// Store raw string value - mapstructure will handle conversion
|
||||
if item.values == nil {
|
||||
item.values = make(map[Source]any)
|
||||
}
|
||||
if len(value) > MaxValueSize {
|
||||
return ErrValueSize
|
||||
}
|
||||
foundEnvVars[path] = value
|
||||
}
|
||||
}
|
||||
|
||||
// If no relevant env vars were found, we are done.
|
||||
if len(foundEnvVars) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// -- 3. Atomically update config (Write-Lock)
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.envData = make(map[string]any, len(foundEnvVars))
|
||||
|
||||
for path, value := range foundEnvVars {
|
||||
// Store raw string value - mapstructure will handle conversion later.
|
||||
if item, exists := c.items[path]; exists {
|
||||
if item.values == nil {
|
||||
item.values = make(map[Source]any)
|
||||
}
|
||||
item.values[SourceEnv] = value // Store as string
|
||||
item.currentValue = c.computeValue(path, item)
|
||||
c.items[path] = item
|
||||
@ -224,18 +267,24 @@ func (c *Config) loadEnv(opts LoadOptions) error {
|
||||
}
|
||||
}
|
||||
|
||||
c.invalidateCache()
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadCLI loads configuration from command-line arguments
|
||||
func (c *Config) loadCLI(args []string) error {
|
||||
// -- 1. Prepare data (No Lock)
|
||||
parsedCLI, err := parseArgs(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrCLIParse, err)
|
||||
}
|
||||
|
||||
flattenedCLI := flattenMap(parsedCLI, "")
|
||||
if len(flattenedCLI) == 0 {
|
||||
return nil // No CLI args to process.
|
||||
}
|
||||
|
||||
// 2. Atomically update config (Write-Lock)
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
@ -252,7 +301,7 @@ func (c *Config) loadCLI(args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
c.invalidateCache() // Invalidate cache after changes
|
||||
c.invalidateCache()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
83
watch.go
83
watch.go
@ -9,8 +9,6 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
// WatchOptions configures file watching behavior
|
||||
@ -132,6 +130,30 @@ func (c *Config) Watch() <-chan string {
|
||||
return c.WatchWithOptions(DefaultWatchOptions())
|
||||
}
|
||||
|
||||
// WatchFile stops any existing file watcher, loads a new configuration file,
|
||||
// and starts a new watcher on that file path.
|
||||
func (c *Config) WatchFile(filePath string) error {
|
||||
// Stop any currently running watcher to prevent orphaned goroutines.
|
||||
c.StopAutoUpdate()
|
||||
|
||||
// Load the new file and set `configFilePath` to the new path
|
||||
if err := c.LoadFile(filePath); err != nil {
|
||||
return fmt.Errorf("failed to load new file for watching: %w", err)
|
||||
}
|
||||
|
||||
// Start a new watcher on the new file
|
||||
c.mutex.RLock()
|
||||
opts := DefaultWatchOptions()
|
||||
if c.watcher != nil {
|
||||
opts = c.watcher.opts
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
c.AutoUpdateWithOptions(opts)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WatchWithOptions returns a channel with custom watch options
|
||||
func (c *Config) WatchWithOptions(opts WatchOptions) <-chan string {
|
||||
// First ensure auto-update is running
|
||||
@ -260,7 +282,7 @@ func (w *watcher) performReload(c *Config) {
|
||||
// Reload file in a goroutine with timeout
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.reloadFileAtomic(w.filePath)
|
||||
done <- c.loadFile(w.filePath)
|
||||
}()
|
||||
|
||||
select {
|
||||
@ -373,58 +395,3 @@ func (c *Config) snapshot() map[string]any {
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
|
||||
// reloadFileAtomic atomically reloads the configuration file
|
||||
func (c *Config) reloadFileAtomic(filePath string) error {
|
||||
// Read file
|
||||
fileData, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
// SECURITY: Check file size to prevent DoS
|
||||
if len(fileData) > MaxValueSize*10 { // 10MB max for config file
|
||||
return fmt.Errorf("config file too large: %d bytes", len(fileData))
|
||||
}
|
||||
|
||||
// Parse TOML
|
||||
fileConfig := make(map[string]any)
|
||||
if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse TOML: %w", err)
|
||||
}
|
||||
|
||||
// Flatten the configuration
|
||||
flattenedFileConfig := flattenMap(fileConfig, "")
|
||||
|
||||
// Apply atomically
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Clear old file data
|
||||
c.fileData = make(map[string]any)
|
||||
|
||||
// Apply new values
|
||||
for path, value := range flattenedFileConfig {
|
||||
if item, exists := c.items[path]; exists {
|
||||
if item.values == nil {
|
||||
item.values = make(map[Source]any)
|
||||
}
|
||||
item.values[SourceFile] = value
|
||||
item.currentValue = c.computeValue(path, item)
|
||||
c.items[path] = item
|
||||
c.fileData[path] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Remove file values not in new config
|
||||
for path, item := range c.items {
|
||||
if _, exists := flattenedFileConfig[path]; !exists {
|
||||
delete(item.values, SourceFile)
|
||||
item.currentValue = c.computeValue(path, item)
|
||||
c.items[path] = item
|
||||
}
|
||||
}
|
||||
|
||||
c.invalidateCache()
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user