Compare commits

..

10 Commits

37 changed files with 7606 additions and 1826 deletions

1
.gitignore vendored
View File

@ -5,4 +5,5 @@ log
logs logs
script script
*.log *.log
*.toml
bin bin

161
README.md
View File

@ -1,11 +1,21 @@
# Config # Config
Thread-safe configuration management for Go with support for TOML files, environment variables, command-line arguments, and defaults with configurable precedence. Thread-safe configuration management for Go applications with support for multiple sources (files, environment variables, command-line arguments, defaults) and configurable precedence.
## Features
- **Multiple Sources**: Load configuration from defaults, files, environment variables, and CLI arguments
- **Configurable Precedence**: Control which sources override others
- **Type Safety**: Struct-based configuration with automatic validation
- **Thread-Safe**: Concurrent access with read-write locking
- **File Watching**: Automatic reloading on configuration changes
- **Source Tracking**: Know exactly where each value came from
- **Tag Support**: Use `toml`, `json`, or `yaml` struct tags
## Installation ## Installation
```bash ```bash
go get github.com/LixenWraith/config go get github.com/lixenwraith/config
``` ```
## Quick Start ## Quick Start
@ -15,7 +25,6 @@ package main
import ( import (
"log" "log"
"github.com/lixenwraith/config" "github.com/lixenwraith/config"
) )
@ -24,157 +33,33 @@ type AppConfig struct {
Host string `toml:"host"` Host string `toml:"host"`
Port int `toml:"port"` Port int `toml:"port"`
} `toml:"server"` } `toml:"server"`
Database struct {
URL string `toml:"url"`
MaxConns int `toml:"max_conns"`
} `toml:"database"`
Debug bool `toml:"debug"` Debug bool `toml:"debug"`
} }
func main() { func main() {
// Define defaults defaults := &AppConfig{}
defaults := AppConfig{}
defaults.Server.Host = "localhost" defaults.Server.Host = "localhost"
defaults.Server.Port = 8080 defaults.Server.Port = 8080
defaults.Database.URL = "postgres://localhost/myapp"
defaults.Database.MaxConns = 10
// Initialize with environment prefix and config file
cfg, err := config.Quick(defaults, "MYAPP_", "config.toml") cfg, err := config.Quick(defaults, "MYAPP_", "config.toml")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// Access values port, _ := cfg.Get("server.port")
host, _ := cfg.String("server.host") log.Printf("Server port: %d", port.(int64))
port, _ := cfg.Int64("server.port")
dbURL, _ := cfg.String("database.url")
debug, _ := cfg.Bool("debug")
log.Printf("Server: %s:%d, DB: %s, Debug: %v", host, port, dbURL, debug)
} }
``` ```
**config.toml:** ## Documentation
```toml
[server]
host = "production.example.com"
port = 9090
[database] - [Quick Start Guide](doc/quick-start.md) - Get up and running quickly
url = "postgres://prod-db/myapp" - [Builder Pattern](doc/builder.md) - Advanced configuration with the builder
max_conns = 50 - [Command Line](doc/cli.md) - CLI argument handling
- [Environment Variables](doc/env.md) - Environment variable configuration
debug = false - [Configuration Files](doc/file.md) - File loading and formats
``` - [Access Patterns](doc/access.md) - Getting and setting values
- [Live Reconfiguration](doc/reconfiguration.md) - File watching and updates
**Usage:**
```bash
# Override with environment variables
export MYAPP_SERVER_PORT=8443
export MYAPP_DEBUG=true
# Override with CLI arguments
./myapp --server.port=9999 --debug
```
## Key Features
- **Multiple Sources**: Defaults → File → Environment → CLI (configurable order)
- **Type Safety**: Automatic conversion with detailed error messages
- **Thread-Safe**: Concurrent reads with protected writes
- **Builder Pattern**: Fluent interface for advanced configuration
- **Source Tracking**: See which source provided each value
- **Zero Dependencies**: Only stdlib + minimal parsers
## Common Patterns
### Custom Precedence
```go
cfg, _ := config.NewBuilder().
WithDefaults(defaults).
WithSources(
config.SourceEnv, // Env vars highest priority
config.SourceFile,
config.SourceCLI,
config.SourceDefault,
).
Build()
```
### Environment Variable Mapping
```go
// Custom env var names
opts := config.LoadOptions{
EnvTransform: func(path string) string {
switch path {
case "server.port": return "PORT"
case "database.url": return "DATABASE_URL"
default: return ""
}
},
}
cfg.LoadWithOptions("config.toml", os.Args[1:], opts)
```
### Validation
```go
// Register and validate required fields
cfg.RegisterRequired("api.key", "")
cfg.RegisterRequired("database.url", "")
if err := cfg.Validate("api.key", "database.url"); err != nil {
log.Fatal("Missing required config: ", err)
}
```
### Source Inspection
```go
// See all sources for a value
sources := cfg.GetSources("server.port")
for source, value := range sources {
fmt.Printf("%s: %v\n", source, value)
}
// Get value from specific source
envPort, exists := cfg.GetSource("server.port", config.SourceEnv)
```
### Struct Scanning
```go
var serverConfig struct {
Host string `toml:"host"`
Port int `toml:"port"`
}
cfg.Scan("server", &serverConfig)
```
### Environment Whitelist
```go
// Only load specific env vars
cfg, _ := config.NewBuilder().
WithDefaults(defaults).
WithEnvPrefix("MYAPP_").
WithEnvWhitelist("api.key", "database.password").
Build()
```
## API Reference
### Core Methods
- `Quick(defaults, envPrefix, configFile)` - Quick initialization
- `Register(path, defaultValue)` - Register configuration path
- `Get/String/Int64/Bool/Float64(path)` - Type-safe accessors
- `Set(path, value)` - Update configuration
- `Validate(paths...)` - Ensure required values are set
### Advanced Methods
- `NewBuilder()` - Create custom configuration
- `GetSource(path, source)` - Get value from specific source
- `GetSources(path)` - Get all source values
- `Scan(basePath, target)` - Unmarshal into struct
- `Clone()` - Deep copy configuration
- `Debug()` - Show all values and sources
## License ## License

View File

@ -1,28 +1,34 @@
// File: lixenwraith/config/builder.go // FILE: lixenwraith/config/builder.go
package config package config
import ( import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"reflect"
) )
// ValidatorFunc defines the signature for a function that can validate a Config instance. // Builder provides a fluent API for constructing a Config instance. It allows for
// It receives the fully loaded *Config object and should return an error if validation fails. // chaining configuration options before final build of the config object.
type ValidatorFunc func(c *Config) error
// Builder provides a fluent interface for building configurations
type Builder struct { type Builder struct {
cfg *Config cfg *Config
opts LoadOptions opts LoadOptions
defaults any defaults any
tagName string
fileFormat string
securityOpts *SecurityOptions
prefix string prefix string
file string file string
args []string args []string
err error err error
validators []ValidatorFunc validators []ValidatorFunc
typedValidators []any
} }
// ValidatorFunc defines the signature for a function that can validate a Config instance.
// It receives the fully loaded *Config object and should return an error if validation fails.
type ValidatorFunc func(c *Config) error
// NewBuilder creates a new configuration builder // NewBuilder creates a new configuration builder
func NewBuilder() *Builder { func NewBuilder() *Builder {
return &Builder{ return &Builder{
@ -30,15 +36,145 @@ func NewBuilder() *Builder {
opts: DefaultLoadOptions(), opts: DefaultLoadOptions(),
args: os.Args[1:], args: os.Args[1:],
validators: make([]ValidatorFunc, 0), validators: make([]ValidatorFunc, 0),
typedValidators: make([]any, 0),
} }
} }
// Build creates the Config instance with all specified options
func (b *Builder) Build() (*Config, error) {
if b.err != nil {
return nil, b.err
}
// Use tagName if set, default to "toml"
tagName := b.tagName
if tagName == "" {
tagName = "toml"
}
// Set format and security settings
if b.fileFormat != "" {
b.cfg.fileFormat = b.fileFormat
}
if b.securityOpts != nil {
b.cfg.securityOpts = b.securityOpts
}
// 1. Register defaults
// If WithDefaults() was called, it takes precedence.
// If not, but WithTarget() was called, use the target struct for defaults.
if b.defaults != nil {
// WithDefaults() was called explicitly.
if err := b.cfg.RegisterStructWithTags(b.prefix, b.defaults, tagName); err != nil {
return nil, fmt.Errorf("failed to register defaults: %w", err)
}
} else if b.cfg.structCache != nil && b.cfg.structCache.target != nil {
// No explicit defaults, so use the target struct as the source of defaults.
// This is the behavior the tests rely on.
if err := b.cfg.RegisterStructWithTags(b.prefix, b.cfg.structCache.target, tagName); err != nil {
return nil, fmt.Errorf("failed to register target struct as defaults: %w", err)
}
}
// Explicitly set the file path on the config object so the watcher can find it,
// even if the initial load fails with a non-fatal error (file not found).
b.cfg.configFilePath = b.file
// 2. Load configuration
loadErr := b.cfg.LoadWithOptions(b.file, b.args, b.opts)
if loadErr != nil && !errors.Is(loadErr, ErrConfigNotFound) {
// Return on fatal load errors. ErrConfigNotFound is not fatal.
return nil, loadErr
}
// 3. Run non-typed validators
for _, validator := range b.validators {
if err := validator(b.cfg); err != nil {
return nil, fmt.Errorf("configuration validation failed: %w", err)
}
}
// 4. Populate target and run typed validators
if b.cfg.structCache != nil && b.cfg.structCache.target != nil && len(b.typedValidators) > 0 {
// Populate the target struct first. This unifies all types (e.g., string "8888" -> int64 8888).
populatedTarget, err := b.cfg.AsStruct()
if err != nil {
return nil, fmt.Errorf("failed to populate target struct for validation: %w", err)
}
// Run the typed validators against the populated, type-safe struct.
for _, validator := range b.typedValidators {
validatorFunc := reflect.ValueOf(validator)
validatorType := validatorFunc.Type()
// Check if the validator's input type matches the target's type.
if validatorType.In(0) != reflect.TypeOf(populatedTarget) {
return nil, fmt.Errorf("typed validator signature %v does not match target type %T", validatorType, populatedTarget)
}
// Call the validator.
results := validatorFunc.Call([]reflect.Value{reflect.ValueOf(populatedTarget)})
if !results[0].IsNil() {
err := results[0].Interface().(error)
return nil, fmt.Errorf("typed configuration validation failed: %w", err)
}
}
}
// ErrConfigNotFound or nil
return b.cfg, loadErr
}
// MustBuild is like Build but panics on error
func (b *Builder) MustBuild() *Config {
cfg, err := b.Build()
if err != nil {
// Ignore ErrConfigNotFound for app to proceed with defaults/env vars
if !errors.Is(err, ErrConfigNotFound) {
panic(fmt.Sprintf("config build failed: %v", err))
}
}
return cfg
}
// WithDefaults sets the struct containing default values // WithDefaults sets the struct containing default values
func (b *Builder) WithDefaults(defaults any) *Builder { func (b *Builder) WithDefaults(defaults any) *Builder {
b.defaults = defaults b.defaults = defaults
return b return b
} }
// WithTagName sets the struct tag name to use for field mapping
// Supported values: "toml" (default), "json", "yaml"
func (b *Builder) WithTagName(tagName string) *Builder {
switch tagName {
case "toml", "json", "yaml":
b.tagName = tagName
if b.cfg != nil { // Ensure cfg exists
b.cfg.tagName = tagName
}
default:
b.err = fmt.Errorf("unsupported tag name %q, must be one of: toml, json, yaml", tagName)
}
return b
}
// WithFileFormat sets the expected file format
func (b *Builder) WithFileFormat(format string) *Builder {
switch format {
case "toml", "json", "yaml", "auto":
b.fileFormat = format
default:
b.err = fmt.Errorf("unsupported file format %q", format)
}
return b
}
// WithSecurityOptions sets security options for file loading
func (b *Builder) WithSecurityOptions(opts SecurityOptions) *Builder {
b.securityOpts = &opts
return b
}
// WithPrefix sets the prefix for struct registration // WithPrefix sets the prefix for struct registration
func (b *Builder) WithPrefix(prefix string) *Builder { func (b *Builder) WithPrefix(prefix string) *Builder {
b.prefix = prefix b.prefix = prefix
@ -86,8 +222,35 @@ func (b *Builder) WithEnvWhitelist(paths ...string) *Builder {
return b return b
} }
// WithTarget enables type-aware mode for the builder
func (b *Builder) WithTarget(target any) *Builder {
rv := reflect.ValueOf(target)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
b.err = fmt.Errorf("WithTarget requires non-nil pointer to struct, got %T", target)
return b
}
elem := rv.Elem()
if elem.Kind() != reflect.Struct {
b.err = fmt.Errorf("WithTarget requires pointer to struct, got pointer to %v", elem.Kind())
return b
}
// Initialize struct cache
if b.cfg.structCache == nil {
b.cfg.structCache = &structCache{
target: target,
targetType: elem.Type(),
}
}
return b
}
// WithValidator adds a validation function that runs at the end of the build process // WithValidator adds a validation function that runs at the end of the build process
// Multiple validators can be added and are executed in the order they are added // Multiple validators can be added and are executed in the order they are added
// Validation runs after all sources are loaded
// If any validator returns error, build fails without running subsequent validators
func (b *Builder) WithValidator(fn ValidatorFunc) *Builder { func (b *Builder) WithValidator(fn ValidatorFunc) *Builder {
if fn != nil { if fn != nil {
b.validators = append(b.validators, fn) b.validators = append(b.validators, fn)
@ -95,63 +258,21 @@ func (b *Builder) WithValidator(fn ValidatorFunc) *Builder {
return b return b
} }
// Build creates the Config instance with all specified options // WithTypedValidator adds a type-safe validation function that runs at the end of the build process,
func (b *Builder) Build() (*Config, error) { // after the target struct has been populated. The provided function must accept a single argument
if b.err != nil { // that is a pointer to the same type as the one provided to WithTarget, and must return an error.
return nil, b.err func (b *Builder) WithTypedValidator(fn any) *Builder {
if fn == nil {
return b
} }
// Register defaults if provided // Basic reflection check to ensure it's a function that takes one argument and returns an error.
if b.defaults != nil { t := reflect.TypeOf(fn)
if err := b.cfg.RegisterStruct(b.prefix, b.defaults); err != nil { if t.Kind() != reflect.Func || t.NumIn() != 1 || t.NumOut() != 1 || t.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
return nil, fmt.Errorf("failed to register defaults: %w", err) b.err = fmt.Errorf("WithTypedValidator requires a function with signature func(*T) error")
} return b
} }
// Load configuration b.typedValidators = append(b.typedValidators, fn)
loadErr := b.cfg.LoadWithOptions(b.file, b.args, b.opts) return b
if loadErr != nil && !errors.Is(loadErr, ErrConfigNotFound) {
// Return on fatal load errors. ErrConfigNotFound is not fatal.
return nil, loadErr
}
// Run validators
for _, validator := range b.validators {
if err := validator(b.cfg); err != nil {
return nil, fmt.Errorf("configuration validation failed: %w", err)
}
}
// ErrConfigNotFound or nil
return b.cfg, loadErr
}
// MustBuild is like Build but panics on error
func (b *Builder) MustBuild() *Config {
cfg, err := b.Build()
if err != nil {
// Ignore ErrConfigNotFound as it is not a fatal error for MustBuild.
// The application can proceed with defaults/env vars.
if !errors.Is(err, ErrConfigNotFound) {
panic(fmt.Sprintf("config build failed: %v", err))
}
}
return cfg
}
// BuildAndScan builds and unmarshals the final configuration into the provided target struct pointer
func (b *Builder) BuildAndScan(target any) error {
cfg, err := b.Build()
if err != nil && !errors.Is(err, ErrConfigNotFound) {
return err
}
// Use Scan to populate the target struct.
// The prefix used during registration is the base path for scanning.
if err := cfg.Scan(b.prefix, target); err != nil {
return fmt.Errorf("failed to scan final config into target: %w", err)
}
// ErrConfigNotFound or nil
return err
} }

View File

@ -1,339 +1,364 @@
// File: lixenwraith/config/builder_test.go // FILE: lixenwraith/config/builder_test.go
package config_test package config
import ( import (
"errors" "fmt"
"os" "os"
"path/filepath"
"testing" "testing"
"github.com/lixenwraith/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// TestBuilder tests the builder pattern
func TestBuilder(t *testing.T) { func TestBuilder(t *testing.T) {
t.Run("BasicBuilder", func(t *testing.T) { t.Run("BasicBuilder", func(t *testing.T) {
type AppConfig struct {
Name string `toml:"name"`
Version string `toml:"version"`
Debug bool `toml:"debug"`
}
defaults := AppConfig{
Name: "testapp",
Version: "1.0.0",
Debug: false,
}
cfg, err := config.NewBuilder().
WithDefaults(defaults).
WithPrefix("app.").
Build()
require.NoError(t, err)
// Check registered paths
paths := cfg.GetRegisteredPaths("app.")
assert.Len(t, paths, 3)
// Check values
name, err := cfg.String("app.name")
require.NoError(t, err)
assert.Equal(t, "testapp", name)
})
t.Run("Builder with All Options", func(t *testing.T) {
os.Setenv("BUILDER_SERVER_PORT", "5555")
defer os.Unsetenv("BUILDER_SERVER_PORT")
type Config struct { type Config struct {
Server struct {
Host string `toml:"host"` Host string `toml:"host"`
Port int `toml:"port"` Port int `toml:"port"`
} `toml:"server"`
API struct {
Key string `toml:"key"`
Timeout int `toml:"timeout"`
} `toml:"api"`
} }
defaults := Config{} defaults := &Config{
defaults.Server.Host = "localhost" Host: "localhost",
defaults.Server.Port = 8080 Port: 8080,
defaults.API.Timeout = 30 }
cfg, err := config.NewBuilder(). cfg, err := NewBuilder().
WithDefaults(defaults). WithDefaults(defaults).
WithEnvPrefix("BUILDER_"). WithEnvPrefix("TEST_").
WithArgs([]string{"--api.key=test-key"}).
WithSources(
config.SourceCLI,
config.SourceEnv,
config.SourceDefault,
).
WithEnvWhitelist("server.port", "api.key").
Build() Build()
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, cfg)
// CLI should provide api.key val, exists := cfg.Get("host")
apiKey, err := cfg.String("api.key") assert.True(t, exists)
require.NoError(t, err) assert.Equal(t, "localhost", val)
assert.Equal(t, "test-key", apiKey)
// Env should provide server.port (whitelisted)
port, err := cfg.Int64("server.port")
require.NoError(t, err)
assert.Equal(t, int64(5555), port)
// Non-whitelisted env should not load
os.Setenv("BUILDER_API_TIMEOUT", "99")
defer os.Unsetenv("BUILDER_API_TIMEOUT")
cfg2, err := config.NewBuilder().
WithDefaults(defaults).
WithEnvPrefix("BUILDER_").
WithEnvWhitelist("server.port"). // api.timeout NOT whitelisted
Build()
require.NoError(t, err)
timeout, err := cfg2.Int64("api.timeout")
require.NoError(t, err)
assert.Equal(t, int64(30), timeout, "non-whitelisted env should not load")
}) })
t.Run("Builder Custom Transform", func(t *testing.T) { t.Run("BuilderWithAllOptions", func(t *testing.T) {
os.Setenv("PORT", "3333") tmpDir := t.TempDir()
os.Setenv("DB_URL", "postgres://custom") configFile := filepath.Join(tmpDir, "test.toml")
defer func() { os.WriteFile(configFile, []byte(`host = "filehost"`), 0644)
os.Unsetenv("PORT")
os.Unsetenv("DB_URL")
}()
type Config struct { type Config struct {
Server struct { Host string `json:"hostname"`
Port int `toml:"port"` Port int `json:"port"`
} `toml:"server"`
Database struct {
URL string `toml:"url"`
} `toml:"database"`
} }
cfg, err := config.NewBuilder(). defaults := &Config{
WithDefaults(Config{}). Host: "defaulthost",
WithEnvTransform(func(path string) string { Port: 3000,
switch path {
case "server.port":
return "PORT"
case "database.url":
return "DB_URL"
default:
return ""
} }
}).
// Custom env transform
envTransform := func(path string) string {
return "CUSTOM_" + path
}
cfg, err := NewBuilder().
WithDefaults(defaults).
WithTagName("json").
WithPrefix("server").
WithEnvPrefix("APP_").
WithFile(configFile).
WithArgs([]string{"--server.hostname=clihost"}).
WithSources(SourceCLI, SourceFile, SourceEnv, SourceDefault).
WithEnvTransform(envTransform).
WithEnvWhitelist("server.hostname").
Build() Build()
require.NoError(t, err) require.NoError(t, err)
port, err := cfg.Int64("server.port") // CLI should take precedence
require.NoError(t, err) val, _ := cfg.Get("server.hostname")
assert.Equal(t, int64(3333), port) assert.Equal(t, "clihost", val)
})
t.Run("BuilderWithTarget", func(t *testing.T) {
type Config struct {
Database struct {
Host string `toml:"host"`
Port int `toml:"port"`
} `toml:"db"`
Cache struct {
TTL int `toml:"ttl"`
} `toml:"cache"`
}
target := &Config{}
target.Database.Host = "localhost"
target.Database.Port = 5432
target.Cache.TTL = 300
cfg, err := NewBuilder().
WithTarget(target).
Build()
dbURL, err := cfg.String("database.url")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "postgres://custom", dbURL)
// Verify paths were registered
paths := cfg.GetRegisteredPaths()
assert.True(t, paths["db.host"])
assert.True(t, paths["db.port"])
assert.True(t, paths["cache.ttl"])
// Test AsStruct
result, err := cfg.AsStruct()
require.NoError(t, err)
assert.Equal(t, target, result)
})
t.Run("BuilderWithValidator", func(t *testing.T) {
type UserConfig struct {
Port int `toml:"port"`
}
validatorCalled := false
validator := func(cfg *Config) error {
validatorCalled = true
val, exists := cfg.Get("port")
if !exists {
return fmt.Errorf("port not found")
}
// Convert to int - could be int64 from storage
var port int
switch v := val.(type) {
case int:
port = v
case int64:
port = int(v)
default:
return fmt.Errorf("port has unexpected type %T", v)
}
if port < 1024 {
return fmt.Errorf("port %d is below 1024", port)
}
return nil
}
// Valid case
cfg, err := NewBuilder().
WithDefaults(&UserConfig{Port: 8080}).
WithValidator(validator).
Build()
require.NoError(t, err)
assert.NotNil(t, cfg)
assert.True(t, validatorCalled)
// Invalid case
validatorCalled = false
cfg2, err := NewBuilder().
WithDefaults(&UserConfig{Port: 80}).
WithValidator(validator).
Build()
assert.Nil(t, cfg2)
assert.Error(t, err)
assert.Contains(t, err.Error(), "configuration validation failed")
assert.True(t, validatorCalled)
})
t.Run("BuilderErrorAccumulation", func(t *testing.T) {
// Unsupported tag name
_, err := NewBuilder().
WithTagName("xml").
WithDefaults(struct{}{}).
Build()
assert.Error(t, err)
assert.Contains(t, err.Error(), "unsupported tag name")
// Invalid target
_, err = NewBuilder().
WithTarget("not-a-pointer").
Build()
assert.Error(t, err)
assert.Contains(t, err.Error(), "requires non-nil pointer to struct")
}) })
t.Run("MustBuildPanic", func(t *testing.T) { t.Run("MustBuildPanic", func(t *testing.T) {
// Should not panic with valid config
assert.NotPanics(t, func() {
cfg := NewBuilder().
WithDefaults(struct{ Port int }{Port: 8080}).
MustBuild()
assert.NotNil(t, cfg)
})
// Should panic with error
assert.Panics(t, func() { assert.Panics(t, func() {
config.NewBuilder(). NewBuilder().
WithDefaults("not a struct"). WithTagName("invalid").
MustBuild() MustBuild()
}) })
}) })
}
t.Run("Builder with Validator", func(t *testing.T) { // TestFileDiscovery tests automatic config file discovery
type Config struct { func TestFileDiscovery(t *testing.T) {
Server struct { t.Run("DiscoveryWithCLIFlag", func(t *testing.T) {
Host string `toml:"host"` tmpDir := t.TempDir()
// Use .toml extension for TOML content
configFile := filepath.Join(tmpDir, "custom.toml")
os.WriteFile(configFile, []byte(`test = "value"`), 0644)
opts := DefaultDiscoveryOptions("myapp")
cfg, err := NewBuilder().
WithDefaults(struct {
Test string `toml:"test"`
}{Test: "default"}).
WithArgs([]string{"--config", configFile}).
WithFileDiscovery(opts).
Build()
require.NoError(t, err)
// Verify file was loaded
val, _ := cfg.Get("test")
assert.Equal(t, "value", val)
})
// Rest of test cases remain the same...
t.Run("DiscoveryWithEnvVar", func(t *testing.T) {
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "env.toml")
os.WriteFile(configFile, []byte(`test = "envvalue"`), 0644)
os.Setenv("MYAPP_CONFIG", configFile)
defer os.Unsetenv("MYAPP_CONFIG")
opts := DefaultDiscoveryOptions("myapp")
cfg, err := NewBuilder().
WithDefaults(struct {
Test string `toml:"test"`
}{Test: "default"}).
WithFileDiscovery(opts).
Build()
require.NoError(t, err)
val, _ := cfg.Get("test")
assert.Equal(t, "envvalue", val)
})
t.Run("DiscoveryInCurrentDir", func(t *testing.T) {
// Create config in current directory
cwd, _ := os.Getwd()
configFile := filepath.Join(cwd, "myapp.toml")
os.WriteFile(configFile, []byte(`test = "cwdvalue"`), 0644)
defer os.Remove(configFile)
opts := FileDiscoveryOptions{
Name: "myapp",
Extensions: []string{".toml"},
UseCurrentDir: true,
}
cfg, err := NewBuilder().
WithDefaults(struct {
Test string `toml:"test"`
}{Test: "default"}).
WithFileDiscovery(opts).
Build()
require.NoError(t, err)
val, _ := cfg.Get("test")
assert.Equal(t, "cwdvalue", val)
})
t.Run("DiscoveryPrecedence", func(t *testing.T) {
tmpDir := t.TempDir()
// Create multiple config files
cliFile := filepath.Join(tmpDir, "cli.toml")
envFile := filepath.Join(tmpDir, "env.toml")
os.WriteFile(cliFile, []byte(`test = "clifile"`), 0644)
os.WriteFile(envFile, []byte(`test = "envfile"`), 0644)
// CLI should take precedence over env
os.Setenv("MYAPP_CONFIG", envFile)
defer os.Unsetenv("MYAPP_CONFIG")
opts := DefaultDiscoveryOptions("myapp")
cfg, err := NewBuilder().
WithDefaults(struct {
Test string `toml:"test"`
}{Test: "default"}).
WithArgs([]string{"--config", cliFile}).
WithFileDiscovery(opts).
Build()
require.NoError(t, err)
val, _ := cfg.Get("test")
assert.Equal(t, "clifile", val)
})
}
func TestBuilderWithTypedValidator(t *testing.T) {
type Cfg struct {
Port int `toml:"port"` Port int `toml:"port"`
} `toml:"server"`
MaxConns int `toml:"max_conns"`
} }
defaults := Config{} // Case 1: Valid configuration
defaults.Server.Host = "localhost" t.Run("ValidTyped", func(t *testing.T) {
defaults.Server.Port = 8080 target := &Cfg{Port: 8080}
defaults.MaxConns = 100 validator := func(c *Cfg) error {
if c.Port < 1024 {
// Validator that fails return fmt.Errorf("port too low")
failingValidator := func(c *config.Config) error {
port, err := c.Int64("server.port")
if err != nil {
return err
}
if port == 8080 {
return errors.New("port 8080 is not allowed")
} }
return nil return nil
} }
// Validator that succeeds _, err := NewBuilder().
passingValidator := func(c *config.Config) error { WithTarget(target).
host, err := c.String("server.host") WithTypedValidator(validator).
if err != nil {
return err
}
if host == "" {
return errors.New("host cannot be empty")
}
return nil
}
// Test case 1: Validator fails
_, err := config.NewBuilder().
WithDefaults(defaults).
WithValidator(failingValidator).
Build() Build()
require.NoError(t, err)
})
// Case 2: Invalid configuration
t.Run("InvalidTyped", func(t *testing.T) {
target := &Cfg{Port: 80}
validator := func(c *Cfg) error {
if c.Port < 1024 {
return fmt.Errorf("port too low")
}
return nil
}
_, err := NewBuilder().
WithTarget(target).
WithTypedValidator(validator).
Build()
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "port 8080 is not allowed") assert.Contains(t, err.Error(), "typed configuration validation failed: port too low")
})
// Test case 2: Validator passes // Case 3: Mismatched validator signature
cfg, err := config.NewBuilder(). t.Run("MismatchedSignature", func(t *testing.T) {
WithDefaults(defaults). target := &Cfg{}
WithArgs([]string{"--server.port=9000"}). // Change the port so it passes validator := func(c *struct{ Name string }) error { // Different type
WithValidator(failingValidator). return nil
WithValidator(passingValidator). }
_, err := NewBuilder().
WithTarget(target).
WithTypedValidator(validator).
Build() Build()
require.NoError(t, err)
assert.NotNil(t, cfg)
port, _ := cfg.Int64("server.port")
assert.Equal(t, int64(9000), port)
// Test case 3: MustBuild panics on validation failure require.Error(t, err)
assert.PanicsWithError(t, "configuration validation failed: port 8080 is not allowed", func() { assert.Contains(t, err.Error(), "typed validator signature")
config.NewBuilder().
WithDefaults(defaults).
WithValidator(failingValidator).
MustBuild()
})
})
}
func TestQuickFunctions(t *testing.T) {
t.Run("Quick Success", func(t *testing.T) {
type Config struct {
App struct {
Name string `toml:"name"`
} `toml:"app"`
}
defaults := Config{}
defaults.App.Name = "quicktest"
cfg, err := config.Quick(defaults, "QUICK_", "")
require.NoError(t, err)
name, err := cfg.String("app.name")
require.NoError(t, err)
assert.Equal(t, "quicktest", name)
})
t.Run("QuickCustom", func(t *testing.T) {
opts := config.LoadOptions{
Sources: []config.Source{
config.SourceDefault,
config.SourceEnv,
},
EnvPrefix: "CUSTOM_",
}
cfg, err := config.QuickCustom(nil, opts, "")
require.NoError(t, err)
assert.NotNil(t, cfg)
})
t.Run("MustQuick Panic", func(t *testing.T) {
assert.Panics(t, func() {
config.MustQuick("invalid", "TEST_", "")
})
})
}
func TestConvenienceFunctions(t *testing.T) {
t.Run("Validate", func(t *testing.T) {
cfg := config.New()
cfg.Register("required1", "")
cfg.Register("required2", 0)
cfg.Register("optional", "has-default")
// Initial validation should fail
err := cfg.Validate("required1", "required2")
assert.Error(t, err, "expected validation to fail for empty values")
// Set required values
cfg.Set("required1", "value1")
cfg.Set("required2", 42)
// Now should pass
err = cfg.Validate("required1", "required2")
assert.NoError(t, err)
// Validate unregistered path
err = cfg.Validate("unregistered")
assert.Error(t, err, "expected error for unregistered path")
})
t.Run("Debug Output", func(t *testing.T) {
cfg := config.New()
cfg.Register("test.value", "default")
cfg.SetSource("test.value", config.SourceFile, "from-file")
cfg.SetSource("test.value", config.SourceEnv, "from-env")
debug := cfg.Debug()
// Should contain key information
assert.NotEmpty(t, debug)
// Should show sources - checking for actual source string values
assert.Contains(t, debug, "file")
assert.Contains(t, debug, "from-file")
assert.Contains(t, debug, "env")
assert.Contains(t, debug, "from-env")
})
t.Run("Clone", func(t *testing.T) {
cfg := config.New()
cfg.Register("original", "value")
cfg.Set("original", "modified")
// Clone configuration
clone := cfg.Clone()
// Clone should have same values
val, err := clone.String("original")
require.NoError(t, err)
assert.Equal(t, "modified", val)
// Modifying clone should not affect original
clone.Set("original", "clone-modified")
origVal, err := cfg.String("original")
require.NoError(t, err)
assert.Equal(t, "modified", origVal, "original should not be affected by clone modification")
})
t.Run("GetRegisteredPathsWithDefaults", func(t *testing.T) {
cfg := config.New()
cfg.Register("app.name", "myapp")
cfg.Register("app.version", "1.0.0")
cfg.Register("server.port", 8080)
// Get paths with defaults
paths := cfg.GetRegisteredPathsWithDefaults("app.")
assert.Len(t, paths, 2)
assert.Equal(t, "myapp", paths["app.name"])
assert.Equal(t, "1.0.0", paths["app.version"])
}) })
} }

View File

@ -1,301 +0,0 @@
// File: lixenwraith/cmd/test/main.go
// Test program for the config package
package main
import (
"fmt"
"log"
"os"
"path/filepath"
"github.com/lixenwraith/config"
)
// AppConfig represents a simple application configuration
type AppConfig struct {
Server struct {
Host string `toml:"host"`
Port int `toml:"port"`
} `toml:"server"`
Database struct {
URL string `toml:"url"`
MaxConns int `toml:"max_conns"`
} `toml:"database"`
API struct {
Key string `toml:"key" env:"CUSTOM_API_KEY"` // Custom env mapping
Timeout int `toml:"timeout"`
} `toml:"api"`
Debug bool `toml:"debug"`
LogFile string `toml:"log_file"`
}
func main() {
fmt.Println("=== Config Package Feature Test ===\n")
// Test directories
tempDir := os.TempDir()
configPath := filepath.Join(tempDir, "test_config.toml")
defer os.Remove(configPath)
// Set up test environment variables
setupEnvironment()
defer cleanupEnvironment()
// Run feature tests
testQuickStart()
testBuilder()
testSourceTracking()
testEnvironmentFeatures()
testValidation()
testUtilities()
fmt.Println("\n=== All Tests Complete ===")
}
func testQuickStart() {
fmt.Println("=== Test 1: Quick Start ===")
// Define defaults
defaults := AppConfig{}
defaults.Server.Host = "localhost"
defaults.Server.Port = 8080
defaults.Database.URL = "postgres://localhost/testdb"
defaults.Database.MaxConns = 10
defaults.Debug = false
// Quick initialization
cfg, err := config.Quick(defaults, "TEST_", "")
if err != nil {
log.Fatalf("Quick init failed: %v", err)
}
// Access values
host, _ := cfg.String("server.host")
port, _ := cfg.Int64("server.port")
fmt.Printf("Quick config - Host: %s, Port: %d\n", host, port)
// Verify env override (TEST_DEBUG=true was set)
debug, _ := cfg.Bool("debug")
fmt.Printf("Debug from env: %v (should be true)\n", debug)
}
func testBuilder() {
fmt.Println("\n=== Test 2: Builder Pattern ===")
defaults := AppConfig{}
defaults.Server.Port = 8080
defaults.API.Timeout = 30
// Custom precedence: Env > File > CLI > Default
cfg, err := config.NewBuilder().
WithDefaults(defaults).
WithEnvPrefix("APP_").
WithSources(
config.SourceEnv,
config.SourceFile,
config.SourceCLI,
config.SourceDefault,
).
WithArgs([]string{"--server.port=9999"}).
Build()
if err != nil {
log.Fatalf("Builder failed: %v", err)
}
// ENV should win over CLI due to custom precedence
port, _ := cfg.Int64("server.port")
fmt.Printf("Port with Env > CLI precedence: %d (should be 7070 from env)\n", port)
}
func testSourceTracking() {
fmt.Println("\n=== Test 3: Source Tracking ===")
cfg := config.New()
cfg.Register("test.value", "default")
// Set from multiple sources
cfg.SetSource("test.value", config.SourceFile, "from-file")
cfg.SetSource("test.value", config.SourceEnv, "from-env")
cfg.SetSource("test.value", config.SourceCLI, "from-cli")
// Show all sources
sources := cfg.GetSources("test.value")
fmt.Println("All sources for test.value:")
for source, value := range sources {
fmt.Printf(" %s: %v\n", source, value)
}
// Get from specific source
envVal, exists := cfg.GetSource("test.value", config.SourceEnv)
fmt.Printf("Value from env source: %v (exists: %v)\n", envVal, exists)
// Current value (default precedence)
current, _ := cfg.String("test.value")
fmt.Printf("Current value: %s (should be from-cli)\n", current)
}
func testEnvironmentFeatures() {
fmt.Println("\n=== Test 4: Environment Features ===")
cfg := config.New()
cfg.Register("api.key", "")
cfg.Register("api.secret", "")
cfg.Register("database.host", "localhost")
// Test 4a: Custom env transform
fmt.Println("\n4a. Custom Environment Transform:")
opts := config.LoadOptions{
Sources: []config.Source{config.SourceEnv, config.SourceDefault},
EnvTransform: func(path string) string {
switch path {
case "api.key":
return "CUSTOM_API_KEY"
case "database.host":
return "DB_HOST"
default:
return ""
}
},
}
cfg.LoadWithOptions("", nil, opts)
apiKey, _ := cfg.String("api.key")
fmt.Printf("API Key from CUSTOM_API_KEY: %s\n", apiKey)
// Test 4b: Discover environment variables
fmt.Println("\n4b. Environment Discovery:")
cfg2 := config.New()
cfg2.Register("server.port", 8080)
cfg2.Register("debug", false)
cfg2.Register("api.timeout", 30)
discovered := cfg2.DiscoverEnv("TEST_")
fmt.Println("Discovered env vars with TEST_ prefix:")
for path, envVar := range discovered {
fmt.Printf(" %s -> %s\n", path, envVar)
}
// Test 4c: Export configuration as env vars
fmt.Println("\n4c. Export as Environment:")
cfg2.Set("server.port", 3000)
cfg2.Set("debug", true)
exports := cfg2.ExportEnv("EXPORT_")
fmt.Println("Non-default values exported:")
for env, value := range exports {
fmt.Printf(" export %s=%s\n", env, value)
}
// Test 4d: RegisterWithEnv
fmt.Println("\n4d. RegisterWithEnv:")
cfg3 := config.New()
err := cfg3.RegisterWithEnv("special.value", "default", "SPECIAL_ENV_VAR")
if err != nil {
fmt.Printf("RegisterWithEnv error: %v\n", err)
}
special, _ := cfg3.String("special.value")
fmt.Printf("Value from SPECIAL_ENV_VAR: %s\n", special)
}
func testValidation() {
fmt.Println("\n=== Test 5: Validation ===")
cfg := config.New()
cfg.RegisterRequired("api.key", "")
cfg.RegisterRequired("database.url", "")
cfg.Register("optional.setting", "default")
// Should fail validation
err := cfg.Validate("api.key", "database.url")
if err != nil {
fmt.Printf("Validation failed as expected: %v\n", err)
}
// Set required values
cfg.Set("api.key", "secret-key")
cfg.Set("database.url", "postgres://localhost/db")
// Should pass validation
err = cfg.Validate("api.key", "database.url")
if err == nil {
fmt.Println("Validation passed after setting required values")
}
}
func testUtilities() {
fmt.Println("\n=== Test 6: Utility Features ===")
// Create config with some data
cfg := config.New()
cfg.Register("app.name", "testapp")
cfg.Register("app.version", "1.0.0")
cfg.Register("server.port", 8080)
cfg.SetSource("app.version", config.SourceFile, "1.1.0")
cfg.SetSource("server.port", config.SourceEnv, 9090)
// Test 6a: Debug output
fmt.Println("\n6a. Debug Output:")
debug := cfg.Debug()
fmt.Printf("Debug info (first 200 chars): %.200s...\n", debug)
// Test 6b: Clone
fmt.Println("\n6b. Clone Configuration:")
clone := cfg.Clone()
clone.Set("app.name", "cloned-app")
original, _ := cfg.String("app.name")
cloned, _ := clone.String("app.name")
fmt.Printf("Original app.name: %s, Cloned: %s\n", original, cloned)
// Test 6c: Reset source
fmt.Println("\n6c. Reset Sources:")
sources := cfg.GetSources("server.port")
fmt.Printf("Sources before reset: %v\n", sources)
cfg.ResetSource(config.SourceEnv)
sources = cfg.GetSources("server.port")
fmt.Printf("Sources after env reset: %v\n", sources)
// Test 6d: Save and load specific source
fmt.Println("\n6d. Save/Load Specific Source:")
tempFile := filepath.Join(os.TempDir(), "source_test.toml")
defer os.Remove(tempFile)
err := cfg.SaveSource(tempFile, config.SourceFile)
if err != nil {
fmt.Printf("SaveSource error: %v\n", err)
} else {
fmt.Println("Saved SourceFile values to temp file")
}
// Test 6e: GetRegisteredPaths
fmt.Println("\n6e. Registered Paths:")
paths := cfg.GetRegisteredPaths("app.")
fmt.Printf("Paths with 'app.' prefix: %v\n", paths)
pathsWithDefaults := cfg.GetRegisteredPathsWithDefaults("app.")
for path, def := range pathsWithDefaults {
fmt.Printf(" %s: %v\n", path, def)
}
}
func setupEnvironment() {
// Set test environment variables
os.Setenv("TEST_DEBUG", "true")
os.Setenv("TEST_SERVER_PORT", "6666")
os.Setenv("APP_SERVER_PORT", "7070")
os.Setenv("CUSTOM_API_KEY", "env-api-key")
os.Setenv("DB_HOST", "env-db-host")
os.Setenv("SPECIAL_ENV_VAR", "special-value")
}
func cleanupEnvironment() {
os.Unsetenv("TEST_DEBUG")
os.Unsetenv("TEST_SERVER_PORT")
os.Unsetenv("APP_SERVER_PORT")
os.Unsetenv("CUSTOM_API_KEY")
os.Unsetenv("DB_HOST")
os.Unsetenv("SPECIAL_ENV_VAR")
}

277
config.go
View File

@ -1,4 +1,4 @@
// File: lixenwraith/config/config.go // FILE: lixenwraith/config/config.go
// Package config provides thread-safe configuration management for Go applications // Package config provides thread-safe configuration management for Go applications
// with support for multiple sources: TOML files, environment variables, command-line // with support for multiple sources: TOML files, environment variables, command-line
// arguments, and default values with configurable precedence. // arguments, and default values with configurable precedence.
@ -7,9 +7,14 @@ package config
import ( import (
"errors" "errors"
"fmt" "fmt"
"reflect"
"sync" "sync"
"sync/atomic"
) )
// Max config item value size to prevent misuse
const MaxValueSize = 1024 * 1024 // 1MB
// Errors // Errors
var ( var (
// ErrConfigNotFound indicates the specified configuration file was not found. // ErrConfigNotFound indicates the specified configuration file was not found.
@ -19,65 +24,13 @@ var (
ErrCLIParse = errors.New("failed to parse command-line arguments") ErrCLIParse = errors.New("failed to parse command-line arguments")
// ErrEnvParse indicates that parsing environment variables failed. // ErrEnvParse indicates that parsing environment variables failed.
// TODO: use in loader:loadEnv or remove
ErrEnvParse = errors.New("failed to parse environment variables") ErrEnvParse = errors.New("failed to parse environment variables")
// ErrValueSize indicates a value larger than MaxValueSize
ErrValueSize = fmt.Errorf("value size exceeds maximum %d bytes", MaxValueSize)
) )
// Source represents a configuration source
type Source string
const (
SourceDefault Source = "default"
SourceFile Source = "file"
SourceEnv Source = "env"
SourceCLI Source = "cli"
)
// LoadMode defines how configuration sources are processed
type LoadMode int
const (
// LoadModeReplace completely replaces values (default behavior)
LoadModeReplace LoadMode = iota
// LoadModeMerge merges maps/structs instead of replacing
LoadModeMerge
)
// EnvTransformFunc converts a configuration path to an environment variable name
type EnvTransformFunc func(path string) string
// LoadOptions configures how configuration is loaded from multiple sources
type LoadOptions struct {
// Sources defines the precedence order (first = highest priority)
// Default: [SourceCLI, SourceEnv, SourceFile, SourceDefault]
Sources []Source
// EnvPrefix is prepended to environment variable names
// Example: "MYAPP_" transforms "server.port" to "MYAPP_SERVER_PORT"
EnvPrefix string
// EnvTransform customizes how paths map to environment variables
// If nil, uses default transformation (dots to underscores, uppercase)
EnvTransform EnvTransformFunc
// LoadMode determines how values are merged
LoadMode LoadMode
// EnvWhitelist limits which paths are checked for env vars (nil = all)
EnvWhitelist map[string]bool
// SkipValidation skips path validation during load
SkipValidation bool
}
// DefaultLoadOptions returns the standard load options
func DefaultLoadOptions() LoadOptions {
return LoadOptions{
Sources: []Source{SourceCLI, SourceEnv, SourceFile, SourceDefault},
LoadMode: LoadModeReplace,
}
}
// configItem holds configuration values from different sources // configItem holds configuration values from different sources
type configItem struct { type configItem struct {
defaultValue any defaultValue any
@ -85,20 +38,54 @@ type configItem struct {
currentValue any // Computed value based on precedence currentValue any // Computed value based on precedence
} }
// Config manages application configuration loaded from multiple sources. // structCache manages the typed representation of configuration
type structCache struct {
target any // User-provided struct pointer
targetType reflect.Type // Cached type for validation
version int64 // Version for invalidation
populated bool // Whether cache is valid
mu sync.RWMutex
}
// SecurityOptions for enhanced file loading security
type SecurityOptions struct {
PreventPathTraversal bool // Prevent ../ in paths
EnforceFileOwnership bool // Unix only: ensure file owned by current user
MaxFileSize int64 // Maximum config file size (0 = no limit)
}
// Config manages application configuration. It can be used in two primary ways:
// 1. As a dynamic key-value store, accessed via methods like Get(), String(), and Int64()
// 2. As a source for a type-safe struct, populated via BuildAndScan() or AsStruct()
type Config struct { type Config struct {
items map[string]configItem items map[string]configItem
tagName string
fileFormat string // Separate from tagName: "toml", "json", "yaml", or "auto"
securityOpts *SecurityOptions
mutex sync.RWMutex mutex sync.RWMutex
options LoadOptions // Current load options options LoadOptions // Current load options
fileData map[string]any // Cached file data fileData map[string]any // Cached file data
envData map[string]any // Cached env data envData map[string]any // Cached env data
cliData map[string]any // Cached CLI data cliData map[string]any // Cached CLI data
version atomic.Int64
structCache *structCache
// File watching support
watcher *watcher
configFilePath string // Track loaded file path
} }
// New creates and initializes a new Config instance. // New creates and initializes a new Config instance.
func New() *Config { func New() *Config {
return &Config{ return &Config{
items: make(map[string]configItem), items: make(map[string]configItem),
tagName: "toml",
fileFormat: "auto",
// securityOpts: &SecurityOptions{
// PreventPathTraversal: false,
// EnforceFileOwnership: false,
// MaxFileSize: 0,
// },
options: DefaultLoadOptions(), options: DefaultLoadOptions(),
fileData: make(map[string]any), fileData: make(map[string]any),
envData: make(map[string]any), envData: make(map[string]any),
@ -122,15 +109,86 @@ func (c *Config) SetLoadOptions(opts LoadOptions) error {
// Recompute all current values based on new precedence // Recompute all current values based on new precedence
for path, item := range c.items { for path, item := range c.items {
item.currentValue = c.computeValue(path, item) item.currentValue = c.computeValue(item)
c.items[path] = item c.items[path] = item
} }
return nil return nil
} }
// SetPrecedence updates source precedence with validation
func (c *Config) SetPrecedence(sources ...Source) error {
// Validate all required sources present
required := map[Source]bool{
SourceDefault: false,
SourceFile: false,
SourceEnv: false,
SourceCLI: false,
}
for _, s := range sources {
if _, valid := required[s]; !valid {
return fmt.Errorf("invalid source: %s", s)
}
required[s] = true
}
// Ensure SourceDefault is included
if !required[SourceDefault] {
sources = append(sources, SourceDefault)
}
c.mutex.Lock()
defer c.mutex.Unlock()
// FIXED: Check if precedence actually changed
oldPrecedence := c.options.Sources
if reflect.DeepEqual(oldPrecedence, sources) {
return nil // No change needed
}
// Track value changes before updating precedence
oldValues := make(map[string]any)
for path, item := range c.items {
oldValues[path] = item.currentValue
}
// Update precedence
c.options.Sources = sources
// Recompute values and track changes
changedPaths := make([]string, 0)
for path, item := range c.items {
item.currentValue = c.computeValue(item)
if !reflect.DeepEqual(oldValues[path], item.currentValue) {
changedPaths = append(changedPaths, path)
}
c.items[path] = item
}
// Notify watchers of precedence change
if c.watcher != nil && len(changedPaths) > 0 {
for _, path := range changedPaths {
c.watcher.notifyWatchers("precedence:" + path)
}
}
c.invalidateCache()
return nil
}
// GetPrecedence returns current source precedence
func (c *Config) GetPrecedence() []Source {
c.mutex.RLock()
defer c.mutex.RUnlock()
result := make([]Source, len(c.options.Sources))
copy(result, c.options.Sources)
return result
}
// computeValue determines the current value based on precedence // computeValue determines the current value based on precedence
func (c *Config) computeValue(path string, item configItem) any { func (c *Config) computeValue(item configItem) any {
// Check sources in precedence order // Check sources in precedence order
for _, source := range c.options.Sources { for _, source := range c.options.Sources {
if val, exists := item.values[source]; exists && val != nil { if val, exists := item.values[source]; exists && val != nil {
@ -142,9 +200,31 @@ func (c *Config) computeValue(path string, item configItem) any {
return item.defaultValue return item.defaultValue
} }
// Get retrieves a configuration value using the path. // SetFileFormat sets the expected format for configuration files.
// It returns the current value based on configured precedence. // Use "auto" to detect based on file extension.
// The second return value indicates if the path was registered. func (c *Config) SetFileFormat(format string) error {
switch format {
case "toml", "json", "yaml", "auto":
// Valid formats
default:
return fmt.Errorf("unsupported file format %q, must be one of: toml, json, yaml, auto", format)
}
c.mutex.Lock()
defer c.mutex.Unlock()
c.fileFormat = format
return nil
}
// SetSecurityOptions configures security checks for file loading
func (c *Config) SetSecurityOptions(opts SecurityOptions) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.securityOpts = &opts
}
// Get retrieves a configuration value using the path and indicator if the path was registered
func (c *Config) Get(path string) (any, bool) { func (c *Config) Get(path string) (any, bool) {
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
@ -172,14 +252,15 @@ func (c *Config) GetSource(path string, source Source) (any, bool) {
} }
// Set updates a configuration value for the given path. // Set updates a configuration value for the given path.
// It sets the value in the highest priority source (typically CLI). // It sets the value in the highest priority source from the configured Sources.
// Returns an error if the path is not registered. // By default, this is SourceCLI. Returns an error if the path is not registered.
// To set a value in a specific source, use SetSource instead.
func (c *Config) Set(path string, value any) error { func (c *Config) Set(path string, value any) error {
return c.SetSource(path, c.options.Sources[0], value) return c.SetSource(c.options.Sources[0], path, value)
} }
// SetSource sets a value for a specific source // SetSource sets a value for a specific source
func (c *Config) SetSource(path string, source Source, value any) error { func (c *Config) SetSource(source Source, path string, value any) error {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
@ -188,12 +269,16 @@ func (c *Config) SetSource(path string, source Source, value any) error {
return fmt.Errorf("path %s is not registered", path) return fmt.Errorf("path %s is not registered", path)
} }
if str, ok := value.(string); ok && len(str) > MaxValueSize {
return ErrValueSize
}
if item.values == nil { if item.values == nil {
item.values = make(map[Source]any) item.values = make(map[Source]any)
} }
item.values[source] = value item.values[source] = value
item.currentValue = c.computeValue(path, item) item.currentValue = c.computeValue(item)
c.items[path] = item c.items[path] = item
// Update source cache // Update source cache
@ -206,6 +291,7 @@ func (c *Config) SetSource(path string, source Source, value any) error {
c.cliData[path] = value c.cliData[path] = value
} }
c.invalidateCache() // Invalidate cache after changes
return nil return nil
} }
@ -242,6 +328,8 @@ func (c *Config) Reset() {
item.currentValue = item.defaultValue item.currentValue = item.defaultValue
c.items[path] = item c.items[path] = item
} }
c.invalidateCache() // Invalidate cache after changes
} }
// ResetSource clears all values from a specific source // ResetSource clears all values from a specific source
@ -262,7 +350,58 @@ func (c *Config) ResetSource(source Source) {
// Remove source values from all items // Remove source values from all items
for path, item := range c.items { for path, item := range c.items {
delete(item.values, source) delete(item.values, source)
item.currentValue = c.computeValue(path, item) item.currentValue = c.computeValue(item)
c.items[path] = item c.items[path] = item
} }
c.invalidateCache() // Invalidate cache after changes
}
// Override Set methods to invalidate cache
func (c *Config) invalidateCache() {
c.version.Add(1)
}
// AsStruct returns the populated struct if in type-aware mode
func (c *Config) AsStruct() (any, error) {
if c.structCache == nil || c.structCache.target == nil {
return nil, fmt.Errorf("no target struct configured")
}
c.structCache.mu.RLock()
currentVersion := c.version.Load()
needsUpdate := !c.structCache.populated || c.structCache.version != currentVersion
c.structCache.mu.RUnlock()
if needsUpdate {
if err := c.populateStruct(); err != nil {
return nil, err
}
}
return c.structCache.target, nil
}
// Target populates the provided struct with current configuration
func (c *Config) Target(out any) error {
return c.Scan(out)
}
// populateStruct updates the cached struct representation using unified unmarshal
func (c *Config) populateStruct() error {
c.structCache.mu.Lock()
defer c.structCache.mu.Unlock()
currentVersion := c.version.Load()
if c.structCache.populated && c.structCache.version == currentVersion {
return nil
}
if err := c.unmarshal("", c.structCache.target); err != nil {
return fmt.Errorf("failed to populate struct cache: %w", err)
}
c.structCache.version = currentVersion
c.structCache.populated = true
return nil
} }

717
config_test.go Normal file
View File

@ -0,0 +1,717 @@
// FILE: lixenwraith/config/config_test.go
package config
import (
"fmt"
"net"
"net/url"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConfigCreation tests various config creation patterns
func TestConfigCreation(t *testing.T) {
t.Run("NewWithDefaultOptions", func(t *testing.T) {
cfg := New()
require.NotNil(t, cfg)
assert.NotNil(t, cfg.items)
assert.Equal(t, []Source{SourceCLI, SourceEnv, SourceFile, SourceDefault}, cfg.options.Sources)
})
t.Run("NewWithCustomOptions", func(t *testing.T) {
opts := LoadOptions{
Sources: []Source{SourceEnv, SourceFile, SourceDefault},
EnvPrefix: "MYAPP_",
LoadMode: LoadModeReplace,
}
cfg := NewWithOptions(opts)
require.NotNil(t, cfg)
assert.Equal(t, opts.Sources, cfg.options.Sources)
assert.Equal(t, "MYAPP_", cfg.options.EnvPrefix)
})
}
// TestPathRegistration tests path registration edge cases
func TestPathRegistration(t *testing.T) {
tests := []struct {
name string
path string
defaultVal any
expectError bool
errorMsg string
}{
{"ValidSimplePath", "port", 8080, false, ""},
{"ValidNestedPath", "server.host.name", "localhost", false, ""},
{"EmptyPath", "", nil, true, "registration path cannot be empty"},
{"InvalidCharacter", "server.port!", 8080, true, "invalid path segment"},
{"InvalidDot", "server..port", 8080, true, "invalid path segment"},
{"LeadingDot", ".server.port", 8080, true, "invalid path segment"},
{"TrailingDot", "server.port.", 8080, true, "invalid path segment"},
{"ValidUnderscore", "server_config.max_connections", 100, false, ""},
{"ValidDash", "feature-flags.enable-debug", false, false, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := New()
err := cfg.Register(tt.path, tt.defaultVal)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
assert.NoError(t, err)
val, exists := cfg.Get(tt.path)
assert.True(t, exists)
assert.Equal(t, tt.defaultVal, val)
}
})
}
}
// TestComplexStructRegistration tests struct registration with various tag types
func TestComplexStructRegistration(t *testing.T) {
type DatabaseConfig struct {
Host string `toml:"host" json:"db_host" yaml:"dbHost"`
Port int `toml:"port" json:"db_port" yaml:"dbPort"`
MaxConns int `toml:"max_connections"`
Timeout time.Duration `toml:"timeout"`
EnableDebug bool `toml:"debug" env:"DB_DEBUG"`
}
type ServerConfig struct {
Name string `toml:"name" json:"name"`
Database DatabaseConfig `toml:"db" json:"db"`
Tags []string `toml:"tags" json:"tags"`
Metadata map[string]any `toml:"metadata" json:"metadata"`
}
defaultConfig := &ServerConfig{
Name: "test-server",
Database: DatabaseConfig{
Host: "localhost",
Port: 5432,
MaxConns: 100,
Timeout: 30 * time.Second,
EnableDebug: false,
},
Tags: []string{"test", "development"},
Metadata: map[string]any{"version": "1.0"},
}
t.Run("TOMLTags", func(t *testing.T) {
cfg := New()
err := cfg.RegisterStruct("", defaultConfig)
require.NoError(t, err)
// Verify paths registered with TOML tags
paths := cfg.GetRegisteredPaths("")
assert.True(t, paths["name"])
assert.True(t, paths["db.host"])
assert.True(t, paths["db.port"])
assert.True(t, paths["db.max_connections"])
assert.True(t, paths["db.timeout"])
assert.True(t, paths["db.debug"])
assert.True(t, paths["tags"])
assert.True(t, paths["metadata"])
// Verify default values
val, _ := cfg.Get("db.timeout")
assert.Equal(t, 30*time.Second, val)
})
t.Run("JSONTags", func(t *testing.T) {
cfg := New()
err := cfg.RegisterStructWithTags("", defaultConfig, "json")
require.NoError(t, err)
// JSON tags should create different paths
paths := cfg.GetRegisteredPaths("")
assert.True(t, paths["db.db_host"])
assert.True(t, paths["db.db_port"])
})
t.Run("UnsupportedTag", func(t *testing.T) {
cfg := New()
err := cfg.RegisterStructWithTags("", defaultConfig, "xml")
assert.Error(t, err)
assert.Contains(t, err.Error(), "unsupported tag name")
})
t.Run("WithPrefix", func(t *testing.T) {
cfg := New()
err := cfg.RegisterStruct("server", defaultConfig)
require.NoError(t, err)
paths := cfg.GetRegisteredPaths("server.")
assert.True(t, paths["server.name"])
assert.True(t, paths["server.db.host"])
})
}
// TestSourcePrecedence tests configuration source precedence
func TestSourcePrecedence(t *testing.T) {
cfg := New()
cfg.Register("test.value", "default")
// Set values in different sources
cfg.SetSource(SourceFile, "test.value", "from-file")
cfg.SetSource(SourceEnv, "test.value", "from-env")
cfg.SetSource(SourceCLI, "test.value", "from-cli")
// Default precedence: CLI > Env > File > Default
val, _ := cfg.Get("test.value")
assert.Equal(t, "from-cli", val)
// Remove CLI value
cfg.ResetSource(SourceCLI)
val, _ = cfg.Get("test.value")
assert.Equal(t, "from-env", val)
// Change precedence
err := cfg.SetLoadOptions(LoadOptions{
Sources: []Source{SourceFile, SourceEnv, SourceCLI, SourceDefault},
})
require.NoError(t, err)
val, _ = cfg.Get("test.value")
assert.Equal(t, "from-file", val)
// Test GetSources
sources := cfg.GetSources("test.value")
assert.Equal(t, "from-file", sources[SourceFile])
assert.Equal(t, "from-env", sources[SourceEnv])
}
// TestSetPrecedence tests runtime precedence switching
func TestSetPrecedence(t *testing.T) {
t.Run("BasicPrecedenceSwitch", func(t *testing.T) {
cfg := New()
cfg.Register("test.value", "default")
// Set different values in each source
cfg.SetSource(SourceFile, "test.value", "from-file")
cfg.SetSource(SourceEnv, "test.value", "from-env")
cfg.SetSource(SourceCLI, "test.value", "from-cli")
// Default precedence: CLI > Env > File > Default
val, _ := cfg.Get("test.value")
assert.Equal(t, "from-cli", val)
// Switch to File > CLI > Env > Default
err := cfg.SetPrecedence(SourceFile, SourceCLI, SourceEnv, SourceDefault)
require.NoError(t, err)
val, _ = cfg.Get("test.value")
assert.Equal(t, "from-file", val)
// Verify precedence was updated
precedence := cfg.GetPrecedence()
assert.Equal(t, []Source{SourceFile, SourceCLI, SourceEnv, SourceDefault}, precedence)
})
t.Run("NoPrecedenceChangeOptimization", func(t *testing.T) {
cfg := New()
cfg.Register("test.value", "default")
cfg.SetSource(SourceFile, "test.value", "from-file")
// Set same precedence
initialPrecedence := cfg.GetPrecedence()
err := cfg.SetPrecedence(initialPrecedence...)
require.NoError(t, err)
// Should be no-op, verify by checking version
version1 := cfg.version.Load()
err = cfg.SetPrecedence(initialPrecedence...)
require.NoError(t, err)
version2 := cfg.version.Load()
assert.Equal(t, version1, version2, "Version should not change on no-op")
})
t.Run("AutoAddDefaultSource", func(t *testing.T) {
cfg := New()
// Set precedence without SourceDefault
err := cfg.SetPrecedence(SourceCLI, SourceFile, SourceEnv)
require.NoError(t, err)
// SourceDefault should be auto-appended
precedence := cfg.GetPrecedence()
assert.Equal(t, []Source{SourceCLI, SourceFile, SourceEnv, SourceDefault}, precedence)
})
t.Run("InvalidSourceError", func(t *testing.T) {
cfg := New()
// Try to set invalid source
err := cfg.SetPrecedence(Source("invalid"), SourceFile)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid source")
// Precedence should remain unchanged
precedence := cfg.GetPrecedence()
assert.Equal(t, []Source{SourceCLI, SourceEnv, SourceFile, SourceDefault}, precedence)
})
t.Run("PrecedenceChangeNotifications", func(t *testing.T) {
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "test.toml")
os.WriteFile(configFile, []byte(`value = "from-file"`), 0644)
cfg := New()
cfg.Register("value", "default")
cfg.LoadFile(configFile)
cfg.SetSource(SourceCLI, "value", "from-cli")
// Enable watching
opts := WatchOptions{
PollInterval: 100 * time.Millisecond,
Debounce: 50 * time.Millisecond,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
// Start watching for changes
changes := cfg.Watch()
// Change precedence - should trigger notification
go func() {
time.Sleep(50 * time.Millisecond) // Let watcher start
cfg.SetPrecedence(SourceFile, SourceCLI, SourceEnv, SourceDefault)
}()
// Wait for precedence change notification
select {
case change := <-changes:
assert.Equal(t, "precedence:value", change)
// Verify value changed
val, _ := cfg.Get("value")
assert.Equal(t, "from-file", val)
case <-time.After(500 * time.Millisecond):
t.Error("Timeout waiting for precedence change notification")
}
})
t.Run("MultipleValuesAffected", func(t *testing.T) {
cfg := New()
paths := []string{"app.name", "app.version", "app.debug"}
for _, path := range paths {
cfg.Register(path, "default-"+path)
cfg.SetSource(SourceFile, path, "file-"+path)
cfg.SetSource(SourceEnv, path, "env-"+path)
}
// Initial state: Env wins
cfg.SetPrecedence(SourceEnv, SourceFile, SourceDefault)
for _, path := range paths {
val, _ := cfg.Get(path)
assert.Equal(t, "env-"+path, val)
}
// Switch: File wins
err := cfg.SetPrecedence(SourceFile, SourceEnv, SourceDefault)
require.NoError(t, err)
for _, path := range paths {
val, _ := cfg.Get(path)
assert.Equal(t, "file-"+path, val)
}
})
t.Run("ConcurrentPrecedenceChanges", func(t *testing.T) {
cfg := New()
cfg.Register("test", "default")
cfg.SetSource(SourceFile, "test", "file")
cfg.SetSource(SourceCLI, "test", "cli")
var wg sync.WaitGroup
errors := make(chan error, 20)
// Multiple goroutines changing precedence
for i := 0; i < 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
var sources []Source
if id%2 == 0 {
sources = []Source{SourceFile, SourceCLI, SourceDefault}
} else {
sources = []Source{SourceCLI, SourceFile, SourceDefault}
}
if err := cfg.SetPrecedence(sources...); err != nil {
errors <- err
}
}(i)
}
// Concurrent reads during precedence changes
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
val, exists := cfg.Get("test")
if !exists {
errors <- fmt.Errorf("value not found during concurrent access")
}
// Value should be either "file" or "cli"
if val != "file" && val != "cli" {
errors <- fmt.Errorf("unexpected value: %v", val)
}
}()
}
wg.Wait()
close(errors)
// Check for errors
var errs []error
for err := range errors {
errs = append(errs, err)
}
assert.Empty(t, errs, "Concurrent precedence changes should not produce errors")
})
}
// TestPrecedenceWithAutoUpdate verifies no conflicts between precedence and auto-update
func TestPrecedenceWithAutoUpdate(t *testing.T) {
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "test.toml")
// Initial file content
os.WriteFile(configFile, []byte(`
server = "file-server-1"
port = 8080
`), 0644)
cfg := New()
cfg.Register("server", "default-server")
cfg.Register("port", 0)
// Load with CLI override
cfg.LoadFile(configFile)
cfg.SetSource(SourceCLI, "server", "cli-server")
// CLI wins initially
val, _ := cfg.Get("server")
assert.Equal(t, "cli-server", val)
// Enable auto-update
opts := WatchOptions{
PollInterval: 100 * time.Millisecond,
Debounce: 50 * time.Millisecond,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
// Switch precedence to File > CLI
err := cfg.SetPrecedence(SourceFile, SourceCLI, SourceEnv, SourceDefault)
require.NoError(t, err)
// File should now win
val, _ = cfg.Get("server")
assert.Equal(t, "file-server-1", val)
// Update file
os.WriteFile(configFile, []byte(`
server = "file-server-2"
port = 9090
`), 0644)
// Wait for auto-update
time.Sleep(300 * time.Millisecond)
// File still wins with new value
val, _ = cfg.Get("server")
assert.Equal(t, "file-server-2", val)
// CLI value is preserved but not active
cliVal, exists := cfg.GetSource("server", SourceCLI)
assert.True(t, exists)
assert.Equal(t, "cli-server", cliVal)
// Switch back to CLI > File
err = cfg.SetPrecedence(SourceCLI, SourceFile, SourceEnv, SourceDefault)
require.NoError(t, err)
// CLI wins again
val, _ = cfg.Get("server")
assert.Equal(t, "cli-server", val)
}
// TestTypeConversion tests automatic type conversion through mapstructure
func TestTypeConversion(t *testing.T) {
type TestConfig struct {
IntValue int64 `toml:"int"`
FloatValue float64 `toml:"float"`
BoolValue bool `toml:"bool"`
Duration time.Duration `toml:"duration"`
Time time.Time `toml:"time"`
IP net.IP `toml:"ip"`
IPNet *net.IPNet `toml:"ipnet"`
URL *url.URL `toml:"url"`
StringSlice []string `toml:"strings"`
IntSlice []int `toml:"ints"`
}
cfg := New()
defaults := &TestConfig{
IntValue: 42,
FloatValue: 3.14,
BoolValue: true,
Duration: 5 * time.Second,
Time: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
IP: net.ParseIP("127.0.0.1"),
StringSlice: []string{"a", "b"},
IntSlice: []int{1, 2, 3},
}
err := cfg.RegisterStruct("", defaults)
require.NoError(t, err)
// Test string conversions from environment
cfg.SetSource(SourceEnv, "int", "100")
cfg.SetSource(SourceEnv, "float", "2.718")
cfg.SetSource(SourceEnv, "bool", "false")
cfg.SetSource(SourceEnv, "duration", "1m30s")
cfg.SetSource(SourceEnv, "time", "2024-12-25T10:00:00Z")
cfg.SetSource(SourceEnv, "ip", "192.168.1.1")
cfg.SetSource(SourceEnv, "ipnet", "10.0.0.0/8")
cfg.SetSource(SourceEnv, "url", "https://example.com:8080/path")
cfg.SetSource(SourceEnv, "strings", "x,y,z")
// cfg.SetSource("ints", SourceEnv, "7,8,9") // failure due to mapstructure limitation
// Scan into struct
var result TestConfig
err = cfg.Scan(&result)
require.NoError(t, err)
assert.Equal(t, int64(100), result.IntValue)
assert.Equal(t, 2.718, result.FloatValue)
assert.Equal(t, false, result.BoolValue)
assert.Equal(t, 90*time.Second, result.Duration)
assert.Equal(t, "2024-12-25T10:00:00Z", result.Time.Format(time.RFC3339))
assert.Equal(t, "192.168.1.1", result.IP.String())
assert.Equal(t, "10.0.0.0/8", result.IPNet.String())
assert.Equal(t, "https://example.com:8080/path", result.URL.String())
assert.Equal(t, []string{"x", "y", "z"}, result.StringSlice)
// Note: String to int slice conversion through env requires handling in the test
}
// TestConcurrentAccess tests thread safety
func TestConcurrentAccess(t *testing.T) {
cfg := New()
// Register paths
for i := 0; i < 100; i++ {
cfg.Register(fmt.Sprintf("path%d", i), i)
}
var wg sync.WaitGroup
errors := make(chan error, 1000)
// Concurrent readers
for i := 0; i < 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 100; j++ {
path := fmt.Sprintf("path%d", j)
if _, exists := cfg.Get(path); !exists {
errors <- fmt.Errorf("reader %d: path %s not found", id, path)
}
}
}(i)
}
// Concurrent writers
for i := 0; i < 5; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 100; j++ {
path := fmt.Sprintf("path%d", j)
value := fmt.Sprintf("writer%d-value%d", id, j)
if err := cfg.Set(path, value); err != nil {
errors <- fmt.Errorf("writer %d: %v", id, err)
}
}
}(i)
}
// Concurrent source changes
for i := 0; i < 3; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
sources := []Source{SourceFile, SourceEnv, SourceCLI}
for j := 0; j < 50; j++ {
path := fmt.Sprintf("path%d", j)
source := sources[j%len(sources)]
value := fmt.Sprintf("source%d-value%d", id, j)
if err := cfg.SetSource(source, path, value); err != nil {
errors <- fmt.Errorf("source writer %d: %v", id, err)
}
}
}(i)
}
wg.Wait()
close(errors)
// Check for errors
var errs []error
for err := range errors {
errs = append(errs, err)
}
assert.Empty(t, errs, "Concurrent access should not produce errors")
}
// TestUnregister tests path unregistration
func TestUnregister(t *testing.T) {
cfg := New()
// Register nested paths
cfg.Register("server.host", "localhost")
cfg.Register("server.port", 8080)
cfg.Register("server.tls.enabled", true)
cfg.Register("server.tls.cert", "/path/to/cert")
cfg.Register("database.host", "dbhost")
t.Run("UnregisterSinglePath", func(t *testing.T) {
err := cfg.Unregister("server.port")
assert.NoError(t, err)
_, exists := cfg.Get("server.port")
assert.False(t, exists)
// Other paths should remain
_, exists = cfg.Get("server.host")
assert.True(t, exists)
})
t.Run("UnregisterParentPath", func(t *testing.T) {
err := cfg.Unregister("server.tls")
assert.NoError(t, err)
// All child paths should be removed
_, exists := cfg.Get("server.tls.enabled")
assert.False(t, exists)
_, exists = cfg.Get("server.tls.cert")
assert.False(t, exists)
// Sibling paths should remain
_, exists = cfg.Get("server.host")
assert.True(t, exists)
})
t.Run("UnregisterNonExistentPath", func(t *testing.T) {
err := cfg.Unregister("nonexistent.path")
assert.Error(t, err)
assert.Contains(t, err.Error(), "path not registered")
})
}
// TestResetFunctionality tests reset operations
func TestResetFunctionality(t *testing.T) {
cfg := New()
cfg.Register("test1", "default1")
cfg.Register("test2", "default2")
// Set values in different sources
cfg.SetSource(SourceFile, "test1", "file1")
cfg.SetSource(SourceEnv, "test1", "env1")
cfg.SetSource(SourceCLI, "test2", "cli2")
t.Run("ResetSingleSource", func(t *testing.T) {
cfg.ResetSource(SourceEnv)
// Env value should be gone
_, exists := cfg.GetSource("test1", SourceEnv)
assert.False(t, exists)
// Other sources should remain
val, exists := cfg.GetSource("test1", SourceFile)
assert.True(t, exists)
assert.Equal(t, "file1", val)
})
t.Run("ResetAll", func(t *testing.T) {
cfg.Reset()
// All values should revert to defaults
val1, _ := cfg.Get("test1")
val2, _ := cfg.Get("test2")
assert.Equal(t, "default1", val1)
assert.Equal(t, "default2", val2)
// Source values should be cleared
sources := cfg.GetSources("test1")
assert.Empty(t, sources)
})
}
// TestValueSizeLimit tests the MaxValueSize constraint
func TestValueSizeLimit(t *testing.T) {
cfg := New()
cfg.Register("test", "")
// Create a value larger than MaxValueSize
largeValue := make([]byte, MaxValueSize+1)
for i := range largeValue {
largeValue[i] = 'x'
}
err := cfg.Set("test", string(largeValue))
assert.Error(t, err)
assert.Equal(t, ErrValueSize, err)
}
// TestGetRegisteredPaths tests path listing functionality
func TestGetRegisteredPaths(t *testing.T) {
cfg := New()
paths := []string{
"server.host",
"server.port",
"server.tls.enabled",
"database.host",
"database.port",
"cache.ttl",
}
for _, path := range paths {
cfg.Register(path, "")
}
t.Run("GetAllPaths", func(t *testing.T) {
all := cfg.GetRegisteredPaths("")
assert.Len(t, all, len(paths))
for _, path := range paths {
assert.True(t, all[path])
}
})
t.Run("GetPathsWithPrefix", func(t *testing.T) {
serverPaths := cfg.GetRegisteredPaths("server.")
assert.Len(t, serverPaths, 3)
assert.True(t, serverPaths["server.host"])
assert.True(t, serverPaths["server.port"])
assert.True(t, serverPaths["server.tls.enabled"])
})
t.Run("GetPathsWithDefaults", func(t *testing.T) {
defaults := cfg.GetRegisteredPathsWithDefaults("database.")
assert.Len(t, defaults, 2)
assert.Contains(t, defaults, "database.host")
assert.Contains(t, defaults, "database.port")
})
}

View File

@ -1,13 +1,15 @@
// File: lixenwraith/config/convenience.go // FILE: lixenwraith/config/convenience.go
package config package config
import ( import (
"flag" "flag"
"fmt" "fmt"
"github.com/BurntSushi/toml" "github.com/mitchellh/mapstructure"
"os" "os"
"reflect" "reflect"
"strings" "strings"
"github.com/BurntSushi/toml"
) )
// Quick creates a fully configured Config instance with a single call // Quick creates a fully configured Config instance with a single call
@ -86,16 +88,22 @@ func (c *Config) GenerateFlags() *flag.FlagSet {
// BindFlags updates configuration from parsed flag.FlagSet // BindFlags updates configuration from parsed flag.FlagSet
func (c *Config) BindFlags(fs *flag.FlagSet) error { func (c *Config) BindFlags(fs *flag.FlagSet) error {
var errors []error var errors []error
needsInvalidation := false
fs.Visit(func(f *flag.Flag) { fs.Visit(func(f *flag.Flag) {
value := f.Value.String() value := f.Value.String()
parsed := parseValue(value) // Let mapstructure handle type conversion
if err := c.SetSource(SourceCLI, f.Name, value); err != nil {
if err := c.SetSource(f.Name, SourceCLI, parsed); err != nil {
errors = append(errors, fmt.Errorf("flag %s: %w", f.Name, err)) errors = append(errors, fmt.Errorf("flag %s: %w", f.Name, err))
} else {
needsInvalidation = true
} }
}) })
if needsInvalidation {
c.invalidateCache() // Batch invalidation after all flags
}
if len(errors) > 0 { if len(errors) > 0 {
return fmt.Errorf("failed to bind %d flags: %w", len(errors), errors[0]) return fmt.Errorf("failed to bind %d flags: %w", len(errors), errors[0])
} }
@ -141,16 +149,6 @@ func (c *Config) Validate(required ...string) error {
return nil return nil
} }
// Watch returns a channel that receives updates when configuration values change
// This is useful for hot-reloading configurations
// Note: This is a placeholder for future enhancement
func (c *Config) Watch() <-chan string {
// TODO: Implement file watching and config reload
ch := make(chan string)
close(ch) // Close immediately for now
return ch
}
// Debug returns a formatted string showing all configuration values and their sources // Debug returns a formatted string showing all configuration values and their sources
func (c *Config) Debug() string { func (c *Config) Debug() string {
c.mutex.RLock() c.mutex.RLock()
@ -229,3 +227,59 @@ func (c *Config) Clone() *Config {
return clone return clone
} }
// QuickTyped creates a fully configured Config with a typed target
func QuickTyped[T any](target *T, envPrefix, configFile string) (*Config, error) {
return NewBuilder().
WithTarget(target).
WithEnvPrefix(envPrefix).
WithFile(configFile).
Build()
}
// GetTyped retrieves a configuration value and decodes it into the specified type T.
// It leverages the same decoding hooks as the Scan and AsStruct methods,
// providing type conversion from strings, numbers, etc.
func GetTyped[T any](c *Config, path string) (T, error) {
var zero T
rawValue, exists := c.Get(path)
if !exists {
return zero, fmt.Errorf("path %q not found", path)
}
// Prepare the input map and target struct for the decoder.
inputMap := map[string]any{"value": rawValue}
var target struct {
Value T `mapstructure:"value"`
}
// Create a new decoder configured with the same hooks as the main config.
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Result: &target,
TagName: c.tagName,
WeaklyTypedInput: true,
DecodeHook: c.getDecodeHook(),
Metadata: nil,
})
if err != nil {
return zero, fmt.Errorf("failed to create decoder for path %q: %w", path, err)
}
// Decode the single value.
if err := decoder.Decode(inputMap); err != nil {
return zero, fmt.Errorf("failed to decode value for path %q into type %T: %w", path, zero, err)
}
return target.Value, nil
}
// ScanTyped is a generic wrapper around Scan. It allocates a new instance of type T,
// populates it with configuration data from the given base path, and returns a pointer to it.
func ScanTyped[T any](c *Config, basePath ...string) (*T, error) {
var target T
if err := c.Scan(&target, basePath...); err != nil {
return nil, err
}
return &target, nil
}

327
convenience_test.go Normal file
View File

@ -0,0 +1,327 @@
// FILE: lixenwraith/config/convenience_test.go
package config
import (
"flag"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestQuickFunctions tests the convenience Quick* functions
func TestQuickFunctions(t *testing.T) {
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "quick.toml")
os.WriteFile(configFile, []byte(`
host = "quickhost"
port = 7777
`), 0644)
type QuickConfig struct {
Host string `toml:"host"`
Port int `toml:"port"`
SSL bool `toml:"ssl"`
}
defaults := &QuickConfig{
Host: "localhost",
Port: 8080,
SSL: false,
}
t.Run("Quick", func(t *testing.T) {
// Mock os.Args
oldArgs := os.Args
os.Args = []string{"cmd", "--port=9999"}
defer func() { os.Args = oldArgs }()
cfg, err := Quick(defaults, "QUICK_", configFile)
require.NoError(t, err)
// CLI should override
port, _ := cfg.Get("port")
assert.Equal(t, "9999", port)
// File value
host, _ := cfg.Get("host")
assert.Equal(t, "quickhost", host)
})
t.Run("QuickCustom", func(t *testing.T) {
opts := LoadOptions{
Sources: []Source{SourceFile, SourceDefault}, // Only file and defaults
EnvPrefix: "CUSTOM_",
}
cfg, err := QuickCustom(defaults, opts, configFile)
require.NoError(t, err)
// Should use file value
port, _ := cfg.Get("port")
assert.Equal(t, int64(7777), port)
})
t.Run("MustQuickPanic", func(t *testing.T) {
// Valid case - should not panic
assert.NotPanics(t, func() {
cfg := MustQuick(defaults, "TEST_", configFile)
assert.NotNil(t, cfg)
})
// Invalid struct - should panic
assert.Panics(t, func() {
MustQuick("not-a-struct", "TEST_", configFile)
})
})
t.Run("QuickTyped", func(t *testing.T) {
target := &QuickConfig{
Host: "typedhost",
Port: 6666,
SSL: true,
}
cfg, err := QuickTyped(target, "TYPED_", configFile)
require.NoError(t, err)
// Should populate from file
updated, err := cfg.AsStruct()
require.NoError(t, err)
typedCfg := updated.(*QuickConfig)
assert.Equal(t, "quickhost", typedCfg.Host)
assert.Equal(t, 7777, typedCfg.Port)
})
}
// TestFlagGeneration tests flag generation and binding
func TestFlagGeneration(t *testing.T) {
cfg := New()
cfg.Register("server.host", "localhost")
cfg.Register("server.port", 8080)
cfg.Register("debug.enabled", false)
cfg.Register("timeout", 30.5)
cfg.Register("name", "app")
cfg.Register("complex", map[string]any{"key": "value"})
t.Run("GenerateFlags", func(t *testing.T) {
fs := cfg.GenerateFlags()
require.NotNil(t, fs)
// Verify flags exist
hostFlag := fs.Lookup("server.host")
require.NotNil(t, hostFlag)
assert.Equal(t, "localhost", hostFlag.DefValue)
portFlag := fs.Lookup("server.port")
require.NotNil(t, portFlag)
assert.Equal(t, "8080", portFlag.DefValue)
debugFlag := fs.Lookup("debug.enabled")
require.NotNil(t, debugFlag)
assert.Equal(t, "false", debugFlag.DefValue)
timeoutFlag := fs.Lookup("timeout")
require.NotNil(t, timeoutFlag)
assert.Equal(t, "30.5", timeoutFlag.DefValue)
})
t.Run("BindFlags", func(t *testing.T) {
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.String("server.host", "default", "")
fs.Int("server.port", 8080, "")
fs.Bool("debug.enabled", false, "")
// Parse with test values
err := fs.Parse([]string{"-server.host=flaghost", "-server.port=5555", "-debug.enabled"})
require.NoError(t, err)
// Bind to config
err = cfg.BindFlags(fs)
require.NoError(t, err)
// Verify values were set
host, _ := cfg.Get("server.host")
assert.Equal(t, "flaghost", host)
port, _ := cfg.Get("server.port")
assert.Equal(t, "5555", port)
debug, _ := cfg.Get("debug.enabled")
assert.Equal(t, "true", debug)
})
t.Run("BindFlagsError", func(t *testing.T) {
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.String("unregistered.path", "value", "")
fs.Parse([]string{"-unregistered.path=test"})
err := cfg.BindFlags(fs)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to bind 1 flags")
})
}
// TestValidation tests configuration validation
func TestValidation(t *testing.T) {
cfg := New()
cfg.Register("required.host", "")
cfg.Register("required.port", 0)
cfg.Register("optional.timeout", 30)
t.Run("ValidationFails", func(t *testing.T) {
err := cfg.Validate("required.host", "required.port")
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing required configuration")
assert.Contains(t, err.Error(), "required.host")
assert.Contains(t, err.Error(), "required.port")
})
t.Run("ValidationPasses", func(t *testing.T) {
cfg.Set("required.host", "localhost")
cfg.Set("required.port", 8080)
err := cfg.Validate("required.host", "required.port")
assert.NoError(t, err)
})
t.Run("ValidationUnregisteredPath", func(t *testing.T) {
err := cfg.Validate("nonexistent.path")
assert.Error(t, err)
assert.Contains(t, err.Error(), "nonexistent.path (not registered)")
})
t.Run("ValidationWithSourceValue", func(t *testing.T) {
cfg2 := New()
cfg2.Register("test", "default")
// Value equals default but from different source
cfg2.SetSource(SourceEnv, "test", "default")
err := cfg2.Validate("test")
assert.NoError(t, err) // Should pass because env provided value
})
}
// TestDebugAndDump tests debug output functions
func TestDebugAndDump(t *testing.T) {
cfg := New()
cfg.Register("server.host", "localhost")
cfg.Register("server.port", 8080)
cfg.SetSource(SourceFile, "server.host", "filehost")
cfg.SetSource(SourceEnv, "server.host", "envhost")
cfg.SetSource(SourceCLI, "server.port", "9999")
t.Run("Debug", func(t *testing.T) {
debug := cfg.Debug()
assert.Contains(t, debug, "Configuration Debug Info")
assert.Contains(t, debug, "Precedence:")
assert.Contains(t, debug, "server.host:")
assert.Contains(t, debug, "Current: envhost")
assert.Contains(t, debug, "Default: localhost")
assert.Contains(t, debug, "file: filehost")
assert.Contains(t, debug, "env: envhost")
})
t.Run("Dump", func(t *testing.T) {
// Capture stdout
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
err := cfg.Dump()
assert.NoError(t, err)
w.Close()
os.Stdout = oldStdout
// Read output
output := make([]byte, 1024)
n, _ := r.Read(output)
outputStr := string(output[:n])
assert.Contains(t, outputStr, "[server]")
assert.Contains(t, outputStr, "host = ")
assert.Contains(t, outputStr, "port = ")
})
}
// TestClone tests configuration cloning
func TestClone(t *testing.T) {
cfg := New()
cfg.Register("original.value", "default")
cfg.Register("shared.value", "shared")
cfg.SetSource(SourceFile, "original.value", "filevalue")
cfg.SetSource(SourceEnv, "shared.value", "envvalue")
clone := cfg.Clone()
require.NotNil(t, clone)
// Verify values are copied
val, exists := clone.Get("original.value")
assert.True(t, exists)
assert.Equal(t, "filevalue", val)
val, exists = clone.Get("shared.value")
assert.True(t, exists)
assert.Equal(t, "envvalue", val)
// Modify clone should not affect original
clone.Set("original.value", "clonevalue")
originalVal, _ := cfg.Get("original.value")
cloneVal, _ := clone.Get("original.value")
assert.Equal(t, "filevalue", originalVal)
assert.Equal(t, "clonevalue", cloneVal)
// Verify source data is copied
sources := clone.GetSources("shared.value")
assert.Equal(t, "envvalue", sources[SourceEnv])
}
func TestGenericHelpers(t *testing.T) {
cfg := New()
cfg.Register("server.host", "localhost")
cfg.Register("server.port", "8080") // Note: string value
cfg.Register("features.dark_mode", true)
cfg.Register("timeouts.read", "5s")
t.Run("GetTyped", func(t *testing.T) {
port, err := GetTyped[int](cfg, "server.port")
require.NoError(t, err)
assert.Equal(t, 8080, port)
host, err := GetTyped[string](cfg, "server.host")
require.NoError(t, err)
assert.Equal(t, "localhost", host)
// Test with custom decode hook type
readTimeout, err := GetTyped[time.Duration](cfg, "timeouts.read")
require.NoError(t, err)
assert.Equal(t, 5*time.Second, readTimeout)
_, err = GetTyped[int](cfg, "nonexistent.path")
assert.Error(t, err)
})
t.Run("ScanTyped", func(t *testing.T) {
type ServerConfig struct {
Host string `toml:"host"`
Port int `toml:"port"`
}
serverConf, err := ScanTyped[ServerConfig](cfg, "server")
require.NoError(t, err)
require.NotNil(t, serverConf)
assert.Equal(t, "localhost", serverConf.Host)
assert.Equal(t, 8080, serverConf.Port)
})
}

301
decode.go Normal file
View File

@ -0,0 +1,301 @@
// FILE: lixenwraith/config/decode.go
package config
import (
"encoding/json"
"fmt"
"net"
"net/url"
"reflect"
"strings"
"time"
"github.com/mitchellh/mapstructure"
)
// unmarshal is the single authoritative function for decoding configuration
// into target structures. All public decoding methods delegate to this.
func (c *Config) unmarshal(source Source, target any, basePath ...string) error {
// Parse variadic basePath
path := ""
switch len(basePath) {
case 0:
// Use default empty path
case 1:
path = basePath[0]
default:
return fmt.Errorf("too many basePath arguments: expected 0 or 1, got %d", len(basePath))
}
// Validate target
rv := reflect.ValueOf(target)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return fmt.Errorf("unmarshal target must be non-nil pointer, got %T", target)
}
c.mutex.RLock()
defer c.mutex.RUnlock()
// Build nested map based on source selection
nestedMap := make(map[string]any)
if source == "" {
// Use current merged state
for path, item := range c.items {
setNestedValue(nestedMap, path, item.currentValue)
}
} else {
// Use specific source
for path, item := range c.items {
if val, exists := item.values[source]; exists {
setNestedValue(nestedMap, path, val)
}
}
}
// Navigate to basePath section
sectionData := navigateToPath(nestedMap, path)
// Ensure we have a map to decode, normalizing if necessary.
sectionMap, err := normalizeMap(sectionData)
if err != nil {
if sectionData == nil {
sectionMap = make(map[string]any) // Empty section is valid.
} else {
// Path points to a non-map value, which is an error for Scan.
return fmt.Errorf("path %q refers to non-map value (type %T)", path, sectionData)
}
}
// Create decoder with comprehensive hooks
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Result: target,
TagName: c.tagName,
WeaklyTypedInput: true,
DecodeHook: c.getDecodeHook(),
ZeroFields: true,
Metadata: nil,
})
if err != nil {
return fmt.Errorf("decoder creation failed: %w", err)
}
if err := decoder.Decode(sectionMap); err != nil {
return fmt.Errorf("decode failed for path %q: %w", path, err)
}
return nil
}
// normalizeMap ensures that the input data is a map[string]any for the decoder.
func normalizeMap(data any) (map[string]any, error) {
if data == nil {
return make(map[string]any), nil
}
// If it's already the correct type, return it.
if m, ok := data.(map[string]any); ok {
return m, nil
}
// Use reflection to handle other map types (e.g., map[string]bool)
v := reflect.ValueOf(data)
if v.Kind() == reflect.Map {
if v.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("map keys must be strings, but got %v", v.Type().Key())
}
// Create a new map[string]any and copy the values.
normalized := make(map[string]any, v.Len())
iter := v.MapRange()
for iter.Next() {
normalized[iter.Key().String()] = iter.Value().Interface()
}
return normalized, nil
}
return nil, fmt.Errorf("expected a map but got %T", data)
}
// getDecodeHook returns the composite decode hook for all type conversions
func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
return mapstructure.ComposeDecodeHookFunc(
// JSON Number handling
jsonNumberHookFunc(),
// Network types
stringToNetIPHookFunc(),
stringToNetIPNetHookFunc(),
stringToURLHookFunc(),
// Standard hooks
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToTimeHookFunc(time.RFC3339),
mapstructure.StringToSliceHookFunc(","),
// Custom application hooks
c.customDecodeHook(),
)
}
// jsonNumberHookFunc handles json.Number conversion to appropriate numeric types
func jsonNumberHookFunc() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
// Check if source is json.Number
if f != reflect.TypeOf(json.Number("")) {
return data, nil
}
num := data.(json.Number)
// Convert based on target type
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return num.Int64()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// Parse as int64 first, then convert
i, err := num.Int64()
if err != nil {
return nil, err
}
if i < 0 {
return nil, fmt.Errorf("cannot convert negative number to unsigned type")
}
return uint64(i), nil
case reflect.Float32, reflect.Float64:
return num.Float64()
case reflect.String:
return num.String(), nil
default:
// Return as-is for other types
return data, nil
}
}
}
// stringToNetIPHookFunc handles net.IP conversion
func stringToNetIPHookFunc() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
if f.Kind() != reflect.String {
return data, nil
}
if t != reflect.TypeOf(net.IP{}) {
return data, nil
}
// SECURITY: Validate IP string format to prevent injection
str := data.(string)
if len(str) > 45 { // Max IPv6 length
return nil, fmt.Errorf("invalid IP length: %d", len(str))
}
ip := net.ParseIP(str)
if ip == nil {
return nil, fmt.Errorf("invalid IP address: %s", str)
}
return ip, nil
}
}
// stringToNetIPNetHookFunc handles net.IPNet conversion
func stringToNetIPNetHookFunc() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
if f.Kind() != reflect.String {
return data, nil
}
isPtr := t.Kind() == reflect.Ptr
targetType := t
if isPtr {
targetType = t.Elem()
}
if targetType != reflect.TypeOf(net.IPNet{}) {
return data, nil
}
str := data.(string)
if len(str) > 49 { // Max IPv6 CIDR length
return nil, fmt.Errorf("invalid CIDR length: %d", len(str))
}
_, ipnet, err := net.ParseCIDR(str)
if err != nil {
return nil, fmt.Errorf("invalid CIDR: %w", err)
}
if isPtr {
return ipnet, nil
}
return *ipnet, nil
}
}
// stringToURLHookFunc handles url.URL conversion
func stringToURLHookFunc() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
if f.Kind() != reflect.String {
return data, nil
}
isPtr := t.Kind() == reflect.Ptr
targetType := t
if isPtr {
targetType = t.Elem()
}
if targetType != reflect.TypeOf(url.URL{}) {
return data, nil
}
str := data.(string)
if len(str) > 2048 {
return nil, fmt.Errorf("URL too long: %d bytes", len(str))
}
u, err := url.Parse(str)
if err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
if isPtr {
return u, nil
}
return *u, nil
}
}
// customDecodeHook allows for application-specific type conversions
func (c *Config) customDecodeHook() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
// SECURITY: Add custom validation for application types here
// Example: Rate limit parsing, permission validation, etc.
// Pass through by default
return data, nil
}
}
// navigateToPath traverses nested map to reach the specified path
func navigateToPath(nested map[string]any, path string) any {
if path == "" {
return nested
}
path = strings.TrimSuffix(path, ".")
if path == "" {
return nested
}
segments := strings.Split(path, ".")
current := any(nested)
for _, segment := range segments {
currentMap, ok := current.(map[string]any)
if !ok {
return nil
}
value, exists := currentMap[segment]
if !exists {
return nil
}
current = value
}
return current
}

328
decode_test.go Normal file
View File

@ -0,0 +1,328 @@
// FILE: lixenwraith/config/decode_test.go
package config
import (
"net"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestScanWithComplexTypes tests scanning with various complex types
func TestScanWithComplexTypes(t *testing.T) {
type NetworkConfig struct {
IP net.IP `toml:"ip"`
IPNet *net.IPNet `toml:"subnet"`
URL *url.URL `toml:"endpoint"`
Timeout time.Duration `toml:"timeout"`
Retry struct {
Count int `toml:"count"`
Interval time.Duration `toml:"interval"`
} `toml:"retry"`
}
type AppConfig struct {
Network NetworkConfig `toml:"network"`
Tags []string `toml:"tags"`
Ports []int `toml:"ports"`
Labels map[string]string `toml:"labels"`
}
cfg := New()
// Register with defaults
defaults := &AppConfig{
Network: NetworkConfig{
IP: net.ParseIP("127.0.0.1"),
Timeout: 30 * time.Second,
},
Tags: []string{"default"},
Ports: []int{8080},
Labels: map[string]string{
"env": "dev",
},
}
err := cfg.RegisterStruct("", defaults)
require.NoError(t, err)
// Set values from different sources
cfg.SetSource(SourceEnv, "network.ip", "192.168.1.100")
cfg.SetSource(SourceEnv, "network.subnet", "192.168.1.0/24")
cfg.SetSource(SourceEnv, "network.endpoint", "https://api.example.com:8443/v1")
cfg.SetSource(SourceFile, "network.timeout", "2m30s")
cfg.SetSource(SourceFile, "network.retry.count", int64(5))
cfg.SetSource(SourceFile, "network.retry.interval", "10s")
cfg.SetSource(SourceCLI, "tags", "prod,staging,test")
cfg.SetSource(SourceFile, "ports", []any{int64(80), int64(443), int64(8080)})
cfg.SetSource(SourceFile, "labels", map[string]any{
"env": "production",
"version": "1.2.3",
})
// Scan into struct
var result AppConfig
err = cfg.Scan(&result)
require.NoError(t, err)
// Verify conversions
assert.Equal(t, "192.168.1.100", result.Network.IP.String())
assert.Equal(t, "192.168.1.0/24", result.Network.IPNet.String())
assert.Equal(t, "https://api.example.com:8443/v1", result.Network.URL.String())
assert.Equal(t, 150*time.Second, result.Network.Timeout)
assert.Equal(t, 5, result.Network.Retry.Count)
assert.Equal(t, 10*time.Second, result.Network.Retry.Interval)
assert.Equal(t, []string{"prod", "staging", "test"}, result.Tags)
assert.Equal(t, []int{80, 443, 8080}, result.Ports)
assert.Equal(t, "production", result.Labels["env"])
assert.Equal(t, "1.2.3", result.Labels["version"])
}
// TestScanWithBasePath tests scanning from nested paths
func TestScanWithBasePath(t *testing.T) {
type ServerConfig struct {
Host string `toml:"host"`
Port int `toml:"port"`
Enabled bool `toml:"enabled"`
}
cfg := New()
cfg.Register("app.server.host", "localhost")
cfg.Register("app.server.port", 8080)
cfg.Register("app.server.enabled", true)
cfg.Register("app.database.host", "dbhost")
cfg.Set("app.server.host", "appserver")
cfg.Set("app.server.port", 9000)
// Scan only the server section
var server ServerConfig
err := cfg.Scan(&server, "app.server")
require.NoError(t, err)
assert.Equal(t, "appserver", server.Host)
assert.Equal(t, 9000, server.Port)
assert.Equal(t, true, server.Enabled)
// Test non-existent base path
var empty ServerConfig
err = cfg.Scan(&empty, "app.nonexistent")
assert.NoError(t, err) // Should not error, just empty
assert.Equal(t, "", empty.Host)
assert.Equal(t, 0, empty.Port)
}
// TestScanFromSource tests scanning from specific sources
func TestScanFromSource(t *testing.T) {
type Config struct {
Value string `toml:"value"`
}
cfg := New()
cfg.Register("value", "default")
cfg.SetSource(SourceFile, "value", "fromfile")
cfg.SetSource(SourceEnv, "value", "fromenv")
cfg.SetSource(SourceCLI, "value", "fromcli")
tests := []struct {
source Source
expected string
}{
{SourceFile, "fromfile"},
{SourceEnv, "fromenv"},
{SourceCLI, "fromcli"},
{SourceDefault, ""}, // No value in default source
}
for _, tt := range tests {
t.Run(string(tt.source), func(t *testing.T) {
var result Config
err := cfg.ScanSource(tt.source, &result)
require.NoError(t, err)
assert.Equal(t, tt.expected, result.Value)
})
}
}
// TestInvalidScanTargets tests error cases for scanning
func TestInvalidScanTargets(t *testing.T) {
cfg := New()
cfg.Register("test", "value")
tests := []struct {
name string
target any
expectErr string
}{
{"NilPointer", nil, "must be non-nil pointer"},
{"NonPointer", "not-a-pointer", "must be non-nil pointer"},
{"NilStructPointer", (*struct{})(nil), "must be non-nil pointer"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := cfg.Scan(tt.target)
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.expectErr)
})
}
}
// TestCustomTypeConversion tests edge cases in type conversion
func TestCustomTypeConversion(t *testing.T) {
cfg := New()
t.Run("InvalidIPAddress", func(t *testing.T) {
type Config struct {
IP net.IP `toml:"ip"`
}
cfg.Register("ip", net.IP{})
cfg.Set("ip", "not-an-ip")
var result Config
err := cfg.Scan(&result)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid IP address")
})
t.Run("InvalidCIDR", func(t *testing.T) {
type Config struct {
Network *net.IPNet `toml:"network"`
}
cfg.Register("network", (*net.IPNet)(nil))
cfg.Set("network", "invalid-cidr")
var result Config
err := cfg.Scan(&result)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid CIDR")
})
t.Run("InvalidURL", func(t *testing.T) {
type Config struct {
Endpoint *url.URL `toml:"endpoint"`
}
cfg.Register("endpoint", (*url.URL)(nil))
cfg.Set("endpoint", "://invalid-url")
var result Config
err := cfg.Scan(&result)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid URL")
})
t.Run("LongIPString", func(t *testing.T) {
type Config struct {
IP net.IP `toml:"ip"`
}
cfg.Register("ip", net.IP{})
// String longer than max IPv6 length
longIP := make([]byte, 50)
for i := range longIP {
longIP[i] = 'x'
}
cfg.Set("ip", string(longIP))
var result Config
err := cfg.Scan(&result)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid IP length")
})
t.Run("LongURL", func(t *testing.T) {
type Config struct {
URL *url.URL `toml:"url"`
}
cfg.Register("url", (*url.URL)(nil))
// URL longer than 2048 bytes
longURL := "https://example.com/"
for i := 0; i < 2048; i++ {
longURL += "x"
}
cfg.Set("url", longURL)
var result Config
err := cfg.Scan(&result)
assert.Error(t, err)
assert.Contains(t, err.Error(), "URL too long")
})
}
// TestZeroFields tests that ZeroFields option works correctly
func TestZeroFields(t *testing.T) {
type Config struct {
KeepValue string `toml:"keep"`
ResetValue string `toml:"reset"`
NestedValue struct {
Field string `toml:"field"`
} `toml:"nested"`
}
cfg := New()
// Register only some fields
cfg.Register("keep", "keepdefault")
cfg.Register("reset", "resetdefault")
// Don't register nested.field
cfg.Set("keep", "newvalue")
// Don't set reset, so it uses default
// Start with non-zero struct
result := Config{
KeepValue: "initial",
ResetValue: "initial",
NestedValue: struct {
Field string `toml:"field"`
}{Field: "initial"},
}
err := cfg.Scan(&result)
require.NoError(t, err)
// ZeroFields should reset all fields before decoding
assert.Equal(t, "newvalue", result.KeepValue)
assert.Equal(t, "resetdefault", result.ResetValue)
assert.Equal(t, "initial", result.NestedValue.Field) // Unregistered, so Scan should not touch it
}
// TestWeaklyTypedInput tests weak type conversion
func TestWeaklyTypedInput(t *testing.T) {
type Config struct {
IntFromString int `toml:"int_from_string"`
FloatFromString float64 `toml:"float_from_string"`
BoolFromString bool `toml:"bool_from_string"`
StringFromInt string `toml:"string_from_int"`
StringFromBool string `toml:"string_from_bool"`
}
cfg := New()
defaults := &Config{}
cfg.RegisterStruct("", defaults)
// Set string values that should convert
cfg.Set("int_from_string", "42")
cfg.Set("float_from_string", "3.14159")
cfg.Set("bool_from_string", "true")
cfg.Set("string_from_int", 12345)
cfg.Set("string_from_bool", true)
var result Config
err := cfg.Scan(&result)
require.NoError(t, err)
assert.Equal(t, 42, result.IntFromString)
assert.Equal(t, 3.14159, result.FloatFromString)
assert.Equal(t, true, result.BoolFromString)
assert.Equal(t, "12345", result.StringFromInt)
assert.Equal(t, "1", result.StringFromBool) // mapstructure converts bool(true) to "1" in weak conversion
}

128
discovery.go Normal file
View File

@ -0,0 +1,128 @@
// FILE: lixenwraith/config/discovery.go
package config
import (
"os"
"path/filepath"
"strings"
)
// FileDiscoveryOptions configures automatic config file discovery
type FileDiscoveryOptions struct {
// Base name of config file (without extension)
Name string
// Extensions to try (in order)
Extensions []string
// Custom search paths (in addition to defaults)
Paths []string
// Environment variable to check for explicit path
EnvVar string
// CLI flag to check (e.g., "--config" or "-c")
CLIFlag string
// Whether to search in XDG config directories
UseXDG bool
// Whether to search in current directory
UseCurrentDir bool
}
// DefaultDiscoveryOptions returns sensible defaults
func DefaultDiscoveryOptions(appName string) FileDiscoveryOptions {
return FileDiscoveryOptions{
Name: appName,
Extensions: []string{".toml", ".conf", ".config"},
EnvVar: strings.ToUpper(appName) + "_CONFIG",
CLIFlag: "--config",
UseXDG: true,
UseCurrentDir: true,
}
}
// WithFileDiscovery enables automatic config file discovery
func (b *Builder) WithFileDiscovery(opts FileDiscoveryOptions) *Builder {
// Check CLI args first (highest priority)
if opts.CLIFlag != "" && len(b.args) > 0 {
for i, arg := range b.args {
if arg == opts.CLIFlag && i+1 < len(b.args) {
b.file = b.args[i+1]
return b
}
if strings.HasPrefix(arg, opts.CLIFlag+"=") {
b.file = strings.TrimPrefix(arg, opts.CLIFlag+"=")
return b
}
}
}
// Check environment variable
if opts.EnvVar != "" {
if path := os.Getenv(opts.EnvVar); path != "" {
b.file = path
return b
}
}
// Build search paths
var searchPaths []string
// Custom paths first
searchPaths = append(searchPaths, opts.Paths...)
// Current directory
if opts.UseCurrentDir {
if cwd, err := os.Getwd(); err == nil {
searchPaths = append(searchPaths, cwd)
}
}
// XDG paths
if opts.UseXDG {
searchPaths = append(searchPaths, getXDGConfigPaths(opts.Name)...)
}
// Search for config file
for _, dir := range searchPaths {
for _, ext := range opts.Extensions {
path := filepath.Join(dir, opts.Name+ext)
if _, err := os.Stat(path); err == nil {
b.file = path
return b
}
}
}
// No file found is not an error - app can run with defaults/env
return b
}
// getXDGConfigPaths returns XDG-compliant config search paths
func getXDGConfigPaths(appName string) []string {
var paths []string
// XDG_CONFIG_HOME
if xdgHome := os.Getenv("XDG_CONFIG_HOME"); xdgHome != "" {
paths = append(paths, filepath.Join(xdgHome, appName))
} else if home := os.Getenv("HOME"); home != "" {
paths = append(paths, filepath.Join(home, ".config", appName))
}
// XDG_CONFIG_DIRS
if xdgDirs := os.Getenv("XDG_CONFIG_DIRS"); xdgDirs != "" {
for _, dir := range filepath.SplitList(xdgDirs) {
paths = append(paths, filepath.Join(dir, appName))
}
} else {
// Default system paths
paths = append(paths,
filepath.Join("/etc/xdg", appName),
filepath.Join("/etc", appName),
)
}
return paths
}

60
doc.go
View File

@ -1,60 +0,0 @@
// File: lixenwraith/config/doc.go
// Package config provides thread-safe configuration management for Go applications
// with support for multiple sources: TOML files, environment variables, command-line
// arguments, and default values with configurable precedence.
//
// Features:
// - Multiple configuration sources with customizable precedence
// - Thread-safe operations using sync.RWMutex
// - Automatic type conversions for common Go types
// - Struct registration with tag support
// - Environment variable auto-discovery and mapping
// - Builder pattern for easy initialization
// - Source tracking to see where values originated
// - Configuration validation
// - Zero dependencies (only stdlib + toml parser + mapstructure)
//
// Quick Start:
//
// type Config struct {
// Server struct {
// Host string `toml:"host"`
// Port int `toml:"port"`
// } `toml:"server"`
// }
//
// defaults := Config{}
// defaults.Server.Host = "localhost"
// defaults.Server.Port = 8080
//
// cfg, err := config.Quick(defaults, "MYAPP_", "config.toml")
// if err != nil {
// log.Fatal(err)
// }
//
// host, _ := cfg.String("server.host")
// port, _ := cfg.Int64("server.port")
//
// Default Precedence (highest to lowest):
// 1. Command-line arguments (--server.port=9090)
// 2. Environment variables (MYAPP_SERVER_PORT=9090)
// 3. Configuration file (config.toml)
// 4. Default values
//
// Custom Precedence:
//
// cfg, err := config.NewBuilder().
// WithDefaults(defaults).
// WithSources(
// config.SourceEnv, // Environment the highest priority
// config.SourceCLI,
// config.SourceFile,
// config.SourceDefault,
// ).
// Build()
//
// Thread Safety:
// All operations are thread-safe. The package uses read-write mutexes to allow
// concurrent reads while protecting writes.
package config

407
doc/access.md Normal file
View File

@ -0,0 +1,407 @@
# Access Patterns
This guide covers all methods for getting and setting configuration values, type conversions, and working with structured data.
**Always Register First**: Register paths before setting values
**Use Type Assertions**: After struct registration, types are guaranteed
## Getting Values
### Basic Get
```go
// Get returns (value, exists)
value, exists := cfg.Get("server.port")
if !exists {
log.Fatal("server.port not configured")
}
// Type assertion (safe after registration)
port := value.(int64)
```
### Type-Safe Access
When using struct registration, types are guaranteed:
```go
type Config struct {
Server struct {
Port int64 `toml:"port"`
Host string `toml:"host"`
} `toml:"server"`
}
cfg.RegisterStruct("", &Config{})
// After registration, type assertions are safe
port, _ := cfg.Get("server.port")
portNum := port.(int64) // Won't panic - type is enforced
```
### Get from Specific Source
```go
// Get value from specific source
envPort, exists := cfg.GetSource(config.SourceEnv, "server.port")
if exists {
log.Printf("Port from environment: %v", envPort)
}
// Check all sources
sources := cfg.GetSources("server.port")
for source, value := range sources {
log.Printf("%s: %v", source, value)
}
```
### Struct Scanning
```go
// Scan into struct
var serverConfig struct {
Host string `toml:"host"`
Port int64 `toml:"port"`
TLS struct {
Enabled bool `toml:"enabled"`
Cert string `toml:"cert"`
} `toml:"tls"`
}
if err := cfg.Scan(&serverConfig, "server"); err != nil {
log.Fatal(err)
}
// Use structured data
log.Printf("Server: %s:%d", serverConfig.Host, serverConfig.Port)
```
### Target Population
```go
// Populate entire config struct
var config AppConfig
if err := cfg.Target(&config); err != nil {
log.Fatal(err)
}
// Or with builder pattern
var config AppConfig
cfg, _ := config.NewBuilder().
WithTarget(&config).
Build()
// Access directly
fmt.Println(config.Server.Port)
```
### GetTyped
Retrieves a single configuration value and decodes it to the specified type.
```go
import "time"
// Returns an int, converting from string "9090" if necessary.
port, err := config.GetTyped[int](cfg, "server.port")
// Returns a time.Duration, converting from string "5m30s".
timeout, err := config.GetTyped[time.Duration](cfg, "server.timeout")
```
### ScanTyped
A generic wrapper around `Scan` that allocates, populates, and returns a pointer to a struct of the specified type.
```go
// Instead of:
// var dbConf DBConfig
// if err := cfg.Scan("database", &dbConf); err != nil { ... }
// You can write:
dbConf, err := config.ScanTyped[DBConfig](cfg, "database")
if err != nil {
// ...
}
// dbConf is a *DBConfig```
```
### Type-Aware Mode
```go
var conf AppConfig
cfg, _ := config.NewBuilder().
WithTarget(&conf).
Build()
// Get updated struct anytime
latest, err := cfg.AsStruct()
if err != nil {
log.Fatal(err)
}
appConfig := latest.(*AppConfig)
```
## Setting Values
### Basic Set
```go
// Set updates the highest priority source (default: CLI)
if err := cfg.Set("server.port", int64(9090)); err != nil {
log.Fatal(err) // Error if path not registered
}
```
### Set in Specific Source
```go
// Set value in specific source
cfg.SetSource(config.SourceEnv, "server.port", "8080")
cfg.SetSource(config.SourceCLI, "debug", true)
// File source typically set via LoadFile, but can be manual
cfg.SetSource(config.SourceFile, "feature.enabled", true)
```
### Batch Updates
```go
// Multiple updates
updates := map[string]any{
"server.port": int64(9090),
"server.host": "0.0.0.0",
"database.maxconns": int64(50),
}
for path, value := range updates {
if err := cfg.Set(path, value); err != nil {
log.Printf("Failed to set %s: %v", path, err)
}
}
```
## Type Conversions
The package uses mapstructure for flexible type conversion:
```go
// These all work for a string field
cfg.Set("name", "value") // Direct string
cfg.Set("name", 123) // Number → "123"
cfg.Set("name", true) // Boolean → "true"
// For int64 fields
cfg.Set("port", int64(8080)) // Direct
cfg.Set("port", "8080") // String → int64
cfg.Set("port", 8080.0) // Float → int64
cfg.Set("port", int(8080)) // int → int64
```
### Duration Handling
```go
type Config struct {
Timeout time.Duration `toml:"timeout"`
}
// All these work
cfg.Set("timeout", 30*time.Second) // Direct duration
cfg.Set("timeout", "30s") // String parsing
cfg.Set("timeout", "5m30s") // Complex duration
```
### Network Types
```go
type Config struct {
IP net.IP `toml:"ip"`
CIDR net.IPNet `toml:"cidr"`
URL url.URL `toml:"url"`
}
// Automatic parsing
cfg.Set("ip", "192.168.1.1")
cfg.Set("cidr", "10.0.0.0/8")
cfg.Set("url", "https://example.com:8080/path")
```
### Slice Handling
```go
type Config struct {
Tags []string `toml:"tags"`
Ports []int `toml:"ports"`
}
// Direct slice
cfg.Set("tags", []string{"prod", "stable"})
// Comma-separated string (from env/CLI)
cfg.Set("tags", "prod,stable,v2")
// Number arrays
cfg.Set("ports", []int{8080, 8081, 8082})
```
## Checking Configuration
### Path Registration
```go
// Check if path is registered
if _, exists := cfg.Get("server.port"); !exists {
log.Fatal("server.port not registered")
}
// Get all registered paths
paths := cfg.GetRegisteredPaths("server.")
for path := range paths {
log.Printf("Registered: %s", path)
}
// With default values
defaults := cfg.GetRegisteredPathsWithDefaults("")
for path, defaultVal := range defaults {
log.Printf("%s = %v (default)", path, defaultVal)
}
```
### Validation
```go
// Check required fields
if err := cfg.Validate("api.key", "database.url"); err != nil {
log.Fatal("Missing required config:", err)
}
// Custom validation
requiredPorts := []string{"server.port", "metrics.port"}
for _, path := range requiredPorts {
if val, exists := cfg.Get(path); exists {
if port := val.(int64); port < 1024 {
log.Fatalf("%s must be >= 1024", path)
}
}
}
```
### Source Inspection
```go
// Debug specific value
path := "server.port"
log.Printf("=== %s ===", path)
log.Printf("Current: %v", cfg.Get(path))
sources := cfg.GetSources(path)
for source, value := range sources {
log.Printf(" %s: %v", source, value)
}
```
## Advanced Patterns
### Dynamic Configuration
```go
// Change configuration at runtime
func updatePort(cfg *config.Config, port int64) error {
if port < 1 || port > 65535 {
return fmt.Errorf("invalid port: %d", port)
}
return cfg.Set("server.port", port)
}
```
### Configuration Facade
```go
type ConfigFacade struct {
cfg *config.Config
}
func (f *ConfigFacade) ServerPort() int64 {
val, _ := f.cfg.Get("server.port")
return val.(int64)
}
func (f *ConfigFacade) SetServerPort(port int64) error {
return f.cfg.Set("server.port", port)
}
func (f *ConfigFacade) DatabaseURL() string {
val, _ := f.cfg.Get("database.url")
return val.(string)
}
```
### Default Fallbacks
```go
// Helper for optional configuration
func getOrDefault(cfg *config.Config, path string, defaultVal any) any {
if val, exists := cfg.Get(path); exists {
return val
}
return defaultVal
}
// Usage
timeout := getOrDefault(cfg, "timeout", 30*time.Second).(time.Duration)
```
## Thread Safety
All access methods are thread-safe:
```go
// Safe concurrent access
var wg sync.WaitGroup
// Multiple readers
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
port, _ := cfg.Get("server.port")
log.Printf("Port: %v", port)
}()
}
// Concurrent writes are safe too
wg.Add(1)
go func() {
defer wg.Done()
cfg.Set("counter", atomic.AddInt64(&counter, 1))
}()
wg.Wait()
```
## Debugging
### View All Configuration
```go
// Debug output
fmt.Println(cfg.Debug())
// Dump as TOML
cfg.Dump() // Writes to stdout
```
### Clone for Testing
```go
// Create isolated copy for testing
testCfg := cfg.Clone()
testCfg.Set("server.port", int64(0)) // Random port for tests
```
## See Also
- [Live Reconfiguration](reconfiguration.md) - Reacting to changes
- [Builder Pattern](builder.md) - Type-aware configuration
- [Environment Variables](env.md) - Environment value access

378
doc/builder.md Normal file
View File

@ -0,0 +1,378 @@
# Builder Pattern
The builder pattern provides fine-grained control over configuration initialization and loading behavior.
## Basic Builder Usage
```go
cfg, err := config.NewBuilder().
WithDefaults(defaultStruct).
WithEnvPrefix("MYAPP_").
WithFile("config.toml").
Build()
```
## Builder Methods
### WithDefaults
Register a struct containing default values:
```go
type Config struct {
Host string `toml:"host"`
Port int `toml:"port"`
}
defaults := &Config{
Host: "localhost",
Port: 8080,
}
cfg, _ := config.NewBuilder().
WithDefaults(defaults).
Build()
```
### WithTarget
Enable type-aware mode with automatic struct population:
```go
var appConfig Config
cfg, _ := config.NewBuilder().
WithTarget(&appConfig). // Registers struct and enables AsStruct()
WithFile("config.toml").
Build()
// Access populated struct
populated, _ := cfg.AsStruct()
config := populated.(*Config)
```
### WithTagName
Use different struct tags for field mapping:
```go
type Config struct {
Server struct {
Host string `json:"host"` // Using JSON tags
Port int `json:"port"`
} `json:"server"`
}
cfg, _ := config.NewBuilder().
WithDefaults(&Config{}).
WithTagName("json"). // Use json tags instead of toml
Build()
```
Supported tag names: `toml` (default), `json`, `yaml`
### WithPrefix
Add a prefix to all registered paths:
```go
cfg, _ := config.NewBuilder().
WithDefaults(serverConfig).
WithPrefix("server"). // All paths prefixed with "server."
Build()
// Access as "server.host" instead of just "host"
host, _ := cfg.Get("server.host")
```
### WithEnvPrefix
Set environment variable prefix:
```go
cfg, err := config.NewBuilder().
WithEnvPrefix("MYAPP_").
Build()
// Reads from MYAPP_SERVER_PORT for "server.port"
```
### WithSources
Configure source precedence order:
```go
// Environment variables take highest priority
cfg, _ := config.NewBuilder().
WithSources(
config.SourceEnv,
config.SourceFile,
config.SourceCLI,
config.SourceDefault,
).
Build()
```
### WithEnvTransform
Custom environment variable name mapping:
```go
cfg, _ := config.NewBuilder().
WithEnvTransform(func(path string) string {
// Custom mapping logic
switch path {
case "server.port":
return "PORT" // Use $PORT instead of $MYAPP_SERVER_PORT
case "database.url":
return "DATABASE_URL"
default:
// Default transformation
return "MYAPP_" + strings.ToUpper(
strings.ReplaceAll(path, ".", "_"),
)
}
}).
Build()
```
### WithEnvWhitelist
Limit which configuration paths check environment variables:
```go
cfg, _ := config.NewBuilder().
WithEnvWhitelist(
"server.port",
"database.url",
"api.key",
). // Only these paths read from env
Build()
```
### WithValidator
Add validation functions that run *before* the target struct is populated. These validators operate on the raw `*config.Config` object and are suitable for checking required paths or formats before type conversion.
```go
// Validator runs on raw, pre-decoded values.
cfg, _ := config.NewBuilder().
WithDefaults(defaults).
WithValidator(func(c *config.Config) error {
// Validate port range
port, _ := c.Get("server.port")
if p := port.(int64); p < 1024 || p > 65535 {
return fmt.Errorf("port must be between 1024-65535")
}
return nil
}).
WithValidator(func(c *config.Config) error {
// Validate required fields
return c.Validate("api.key", "database.url")
}).
Build()
```
For type-safe validation, see `WithTypedValidator`.
### WithTypedValidator
Add a type-safe validation function that runs *after* the configuration has been fully loaded and decoded into the target struct (set by `WithTarget`). This is the recommended approach for most validation logic.
The validation function must accept a single argument: a pointer to the same struct type that was passed to `WithTarget`.
```go
type AppConfig struct {
Server struct {
Port int64 `toml:"port"`
} `toml:"server"`
}
var target AppConfig
cfg, err := config.NewBuilder().
WithTarget(&target).
WithFile("config.toml").
WithTypedValidator(func(conf *AppConfig) error {
if conf.Server.Port < 1024 || conf.Server.Port > 65535 {
return fmt.Errorf("port %d is outside the valid range", conf.Server.Port)
}
return nil
}).
Build()
```
### WithFile
Set configuration file path:
```go
cfg, _ := config.NewBuilder().
WithFile("/etc/myapp/config.toml").
Build()
```
### WithArgs
Override command-line arguments (default is os.Args[1:]):
```go
cfg, _ := config.NewBuilder().
WithArgs([]string{"--debug", "--server.port=9090"}).
Build()
```
### WithFileDiscovery
Enable automatic configuration file discovery:
```go
cfg, _ := config.NewBuilder().
WithFileDiscovery(config.FileDiscoveryOptions{
Name: "myapp",
Extensions: []string{".toml", ".conf"},
EnvVar: "MYAPP_CONFIG",
CLIFlag: "--config",
UseXDG: true,
}).
Build()
```
This searches for configuration files in:
1. Path specified by `--config` flag
2. Path in `$MYAPP_CONFIG` environment variable
3. Current directory
4. XDG config directories (`~/.config/myapp/`, `/etc/myapp/`)
## Method Interaction and Precedence
While most builder methods can be chained in any order, it's important to understand how `WithDefaults` and `WithTarget` interact to define the default configuration values.
### `WithDefaults` Has Precedence
**Rule:** If `WithDefaults()` is used anywhere in the chain, it will **always** be the definitive source for default values.
This is the recommended approach for clarity and explicitness. It cleanly separates the struct that defines the defaults from the struct that will be populated.
**Example (Recommended Pattern):**
```go
// initialData contains the fallback values.
initialData := &AppConfig{
Server: ServerConfig{Port: 8080},
}
// target is an empty shell for population.
var target AppConfig
// WithDefaults explicitly sets the defaults.
// WithTarget sets up the config for type-safe decoding.
cfg, err := config.NewBuilder().
WithTarget(&target).
WithDefaults(initialData).
WithFile("config.toml").
Build()
```
In this scenario, the `target` struct is *only* used for type information and `AsStruct()` functionality; its initial (zero) values are not used as defaults as per below.
### Using `WithTarget` for Defaults
**Rule:** If `WithDefaults()` is **not** used, the struct passed to `WithTarget()` will serve as the source of default values.
This provides a convenient shorthand for simpler cases where the initial state of your application's config struct *is* the desired default state. The unit tests for the package rely on this behavior.
**Example (Convenience Pattern):**
```go
// The initial state of this struct will be used as the defaults.
target := &AppConfig{
Server: ServerConfig{Port: 8080},
}
// Since WithDefaults() is absent, the builder uses `target`
// for both defaults and for type-safe decoding.
cfg, err := config.NewBuilder().
WithTarget(&target).
WithFile("config.toml").
Build()
```
## Usage Patterns
### Type-Safe Configuration Access
```go
type AppConfig struct {
Server ServerConfig `toml:"server"`
DB DBConfig `toml:"database"`
}
var conf AppConfig
cfg, _ := config.NewBuilder().
WithTarget(&conf).
WithFile("config.toml").
Build()
// Direct struct access after building
fmt.Printf("Port: %d\n", conf.Server.Port)
// Or get updated struct anytime
latest, _ := cfg.AsStruct()
appConf := latest.(*AppConfig)
```
### Multi-Stage Validation
```go
cfg, err := config.NewBuilder().
WithDefaults(defaults).
// Stage 1: Validate structure
WithValidator(validateStructure).
// Stage 2: Validate values
WithValidator(validateRanges).
// Stage 3: Validate relationships
WithValidator(validateRelationships).
Build()
func validateStructure(c *config.Config) error {
required := []string{"server.host", "server.port", "database.url"}
return c.Validate(required...)
}
func validateRanges(c *config.Config) error {
port, _ := c.Get("server.port")
if p := port.(int64); p < 1 || p > 65535 {
return fmt.Errorf("invalid port: %d", p)
}
return nil
}
func validateRelationships(c *config.Config) error {
// Validate that related values make sense together
// e.g., if SSL is enabled, ensure cert paths are set
return nil
}
```
### Error Handling
The builder accumulates errors and returns them on `Build()`:
```go
cfg, err := config.NewBuilder().
WithTarget(nil). // Error: nil target
WithTagName("invalid"). // Error: unsupported tag
Build()
if err != nil {
// err contains first error encountered
}
```
For panic on error use `MustBuild()`
## See Also
- [Environment Variables](env.md) - Environment configuration details
- [Live Reconfiguration](reconfiguration.md) - File watching with builder

193
doc/cli.md Normal file
View File

@ -0,0 +1,193 @@
# Command Line Arguments
The config package supports command-line argument parsing with flexible formats and automatic type conversion.
## Argument Formats
### Key-Value Pairs
```bash
# Space-separated
./myapp --server.port 8080 --database.url "postgres://localhost/db"
# Equals-separated
./myapp --server.port=8080 --database.url=postgres://localhost/db
# Mixed formats
./myapp --server.port 8080 --debug=true
```
### Boolean Flags
```bash
# Boolean flags don't require a value (assumed true)
./myapp --debug --verbose
# Explicit boolean values
./myapp --debug=true --verbose=false
```
### Nested Paths
Use dot notation for nested configuration:
```bash
./myapp --server.host=0.0.0.0 --server.port=9090 --server.tls.enabled=true
```
## Type Conversion
Command-line values are automatically converted to match registered types:
```go
type Config struct {
Port int64 `toml:"port"`
Timeout time.Duration `toml:"timeout"`
Ratio float64 `toml:"ratio"`
Enabled bool `toml:"enabled"`
Tags []string `toml:"tags"`
}
// All these are parsed correctly:
// --port=8080 → int64(8080)
// --timeout=30s → time.Duration(30 * time.Second)
// --ratio=0.95 → float64(0.95)
// --enabled=true → bool(true)
// --tags=prod,stable → []string{"prod", "stable"}
```
## Integration with flag Package
### Generate flag.FlagSet
```go
// Generate flags from registered configuration
fs := cfg.GenerateFlags()
// Parse command line
if err := fs.Parse(os.Args[1:]); err != nil {
log.Fatal(err)
}
// Apply parsed flags to configuration
if err := cfg.BindFlags(fs); err != nil {
log.Fatal(err)
}
```
### Custom Flag Registration
```go
fs := flag.NewFlagSet("myapp", flag.ContinueOnError)
// Add custom flags
verbose := fs.Bool("v", false, "verbose output")
configFile := fs.String("config", "config.toml", "config file path")
// Parse
fs.Parse(os.Args[1:])
// Use custom flags
if *verbose {
log.SetLevel(log.DebugLevel)
}
// Load config with custom file path
cfg, _ := config.NewBuilder().
WithFile(*configFile).
Build()
// Bind remaining flags
cfg.BindFlags(fs)
```
## Precedence and Overrides
Command-line arguments have the highest precedence by default:
```go
// Default precedence: CLI > Env > File > Default
cfg, _ := config.Quick(defaults, "APP_", "config.toml")
// Even if config.toml sets port=8080 and APP_PORT=9090,
// --port=7070 will win
```
Change precedence if needed:
```go
cfg, _ := config.NewBuilder().
WithSources(
config.SourceEnv, // Env highest
config.SourceCLI, // Then CLI
config.SourceFile, // Then file
config.SourceDefault, // Finally defaults
).
Build()
```
## Argument Parsing Details
### Validation
- Paths must use valid identifiers (letters, numbers, underscore, dash)
- No leading/trailing dots in paths
- Empty segments not allowed (no `..` in paths)
### Special Cases
```bash
# Double dash stops flag parsing
./myapp --port=8080 -- --not-a-flag
# Single dash flags are ignored (not GNU-style)
./myapp -p 8080 # Ignored, use --port
# Quoted values preserve spaces
./myapp --message="Hello World" --name='John Doe'
# Escape quotes in values
./myapp --json="{\"key\": \"value\"}"
```
### Value Parsing Rules
1. **Booleans**: `true`, `false` (case-sensitive)
2. **Numbers**: Standard decimal notation
3. **Strings**: Quoted or unquoted (quotes removed if present)
4. **Lists**: Comma-separated (when target type is slice)
## Override Arguments
```go
// Parse custom arguments instead of os.Args
customArgs := []string{"--debug", "--port=9090"}
cfg, _ := config.NewBuilder().
WithArgs(customArgs).
Build()
```
## Error Handling
CLI parsing errors are returned from `Build()` or `LoadCLI()`:
```go
cfg, err := config.NewBuilder().
WithDefaults(&Config{}).
Build()
if err != nil {
switch {
case errors.Is(err, config.ErrCLIParse):
log.Fatal("Invalid command line arguments:", err)
default:
log.Fatal("Configuration error:", err)
}
}
```
## See Also
- [Environment Variables](env.md) - Environment variable handling
- [Access Patterns](access.md) - Retrieving parsed values

195
doc/env.md Normal file
View File

@ -0,0 +1,195 @@
# Environment Variables
The config package provides flexible environment variable support with automatic name transformation, custom mappings, and whitelist capabilities.
## Basic Usage
Environment variables are automatically mapped from configuration paths:
```go
cfg, _ := config.Quick(defaults, "MYAPP_", "config.toml")
// These environment variables are automatically loaded:
// MYAPP_SERVER_PORT → server.port
// MYAPP_DATABASE_URL → database.url
// MYAPP_LOG_LEVEL → log.level
// MYAPP_FEATURES_ENABLED → features.enabled
```
## Name Transformation
### Default Transformation
By default, paths are transformed as follows:
1. Dots (`.`) become underscores (`_`)
2. Converted to uppercase
3. Prefix is prepended
```go
// Path transformations:
// server.port → MYAPP_SERVER_PORT
// database.url → MYAPP_DATABASE_URL
// tls.cert.path → MYAPP_TLS_CERT_PATH
// maxRetries → MYAPP_MAXRETRIES
```
### Custom Transformation
Define custom environment variable mappings:
```go
cfg, _ := config.NewBuilder().
WithEnvTransform(func(path string) string {
switch path {
case "server.port":
return "PORT" // Use $PORT directly
case "database.url":
return "DATABASE_URL"
case "api.key":
return "API_KEY"
default:
// Fallback to default transformation
return "MYAPP_" + strings.ToUpper(
strings.ReplaceAll(path, ".", "_"),
)
}
}).
Build()
```
### No Transformation
Return empty string to skip environment lookup:
```go
cfg, _ := config.NewBuilder().
WithEnvTransform(func(path string) string {
// Only allow specific env vars
allowed := map[string]string{
"port": "PORT",
"database": "DATABASE_URL",
}
return allowed[path] // Empty string if not in map
}).
Build()
```
## Explicit Environment Variable Mapping
Use the `env` struct tag for explicit mappings:
```go
type Config struct {
Port int `toml:"port" env:"PORT"`
Database string `toml:"database" env:"DATABASE_URL"`
APIKey string `toml:"api_key" env:"API_KEY"`
}
// These use the explicit env tag names, ignoring prefix
cfg, _ := config.NewBuilder().
WithDefaults(&Config{}).
WithEnvPrefix("MYAPP_"). // Not used for tagged fields
Build()
```
Or register with explicit environment variable:
```go
cfg.RegisterWithEnv("server.port", 8080, "PORT")
cfg.RegisterWithEnv("database.url", "localhost", "DATABASE_URL")
```
## Environment Variable Whitelist
Limit which paths can be set via environment:
```go
cfg, _ := config.NewBuilder().
WithEnvWhitelist(
"server.port",
"database.url",
"api.key",
"log.level",
). // Only these paths read from environment
Build()
```
## Type Conversion
Environment variables (strings) are automatically converted to the registered type:
```bash
# Booleans
export MYAPP_DEBUG=true
export MYAPP_VERBOSE=false
# Numbers
export MYAPP_PORT=8080
export MYAPP_TIMEOUT=30
export MYAPP_RATIO=0.95
# Durations
export MYAPP_TIMEOUT=30s
export MYAPP_INTERVAL=5m
# Lists (comma-separated)
export MYAPP_TAGS=prod,stable,v2
```
## Manual Environment Loading
Load environment variables at any time:
```go
cfg := config.New()
cfg.RegisterStruct("", &Config{})
// Load with prefix
if err := cfg.LoadEnv("MYAPP_"); err != nil {
log.Fatal(err)
}
// Or use existing options
if err := cfg.LoadWithOptions("", nil, config.LoadOptions{
EnvPrefix: "MYAPP_",
Sources: []config.Source{config.SourceEnv},
}); err != nil {
log.Fatal(err)
}
```
## Discovering Environment Variables
Find which environment variables are set:
```go
// Discover all env vars matching registered paths
discovered := cfg.DiscoverEnv("MYAPP_")
for path, envVar := range discovered {
log.Printf("%s is set via %s", path, envVar)
}
```
## Precedence Examples
Default precedence: CLI > Env > File > Default
Custom precedence (Env > File > CLI > Default):
```go
cfg, _ := config.NewBuilder().
WithSources(
config.SourceEnv,
config.SourceFile,
config.SourceCLI,
config.SourceDefault,
).
Build()
```
## See Also
- [Command Line](cli.md) - CLI argument handling
- [File Configuration](file.md) - Configuration file formats
- [Access Patterns](access.md) - Retrieving values

310
doc/file.md Normal file
View File

@ -0,0 +1,310 @@
# Configuration Files
The config package supports TOML configuration files with automatic loading, discovery, and atomic saving.
## TOML Format
TOML (Tom's Obvious, Minimal Language) is the supported configuration format:
```toml
# Basic values
host = "localhost"
port = 8080
debug = false
# Nested sections
[server]
host = "0.0.0.0"
port = 9090
timeout = "30s"
[database]
url = "postgres://localhost/mydb"
max_conns = 25
timeout = "5s"
# Arrays
[features]
enabled = ["auth", "api", "metrics"]
# Inline tables
tls = { enabled = true, cert = "/path/to/cert", key = "/path/to/key" }
```
## Loading Configuration Files
### Basic Loading
```go
cfg := config.New()
cfg.RegisterStruct("", &Config{})
if err := cfg.LoadFile("config.toml"); err != nil {
if errors.Is(err, config.ErrConfigNotFound) {
log.Println("Config file not found, using defaults")
} else {
log.Fatal("Failed to load config:", err)
}
}
```
### With Builder
```go
cfg, err := config.NewBuilder().
WithDefaults(&Config{}).
WithFile("/etc/myapp/config.toml").
Build()
```
### Multiple File Attempts
```go
// Try multiple locations
locations := []string{
"./config.toml",
"~/.config/myapp/config.toml",
"/etc/myapp/config.toml",
}
var cfg *config.Config
var err error
for _, path := range locations {
cfg, err = config.NewBuilder().
WithDefaults(&Config{}).
WithFile(path).
Build()
if err == nil || !errors.Is(err, config.ErrConfigNotFound) {
break
}
}
```
## Automatic File Discovery
Use file discovery to find configuration automatically:
```go
cfg, _ := config.NewBuilder().
WithDefaults(&Config{}).
WithFileDiscovery(config.FileDiscoveryOptions{
Name: "myapp",
Extensions: []string{".toml", ".conf"},
EnvVar: "MYAPP_CONFIG",
CLIFlag: "--config",
UseXDG: true,
UseCurrentDir: true,
Paths: []string{"/opt/myapp"},
}).
Build()
```
Search order:
1. CLI flag: `--config=/path/to/config.toml`
2. Environment variable: `$MYAPP_CONFIG`
3. Current directory: `./myapp.toml`, `./myapp.conf`
4. XDG config: `~/.config/myapp/myapp.toml`
5. System paths: `/etc/myapp/myapp.toml`
6. Custom paths: `/opt/myapp/myapp.toml`
## Saving Configuration
### Save Current State
```go
// Save all current values atomically
if err := cfg.Save("config.toml"); err != nil {
log.Fatal("Failed to save config:", err)
}
```
The save operation is atomic - it writes to a temporary file then renames it.
### Save Specific Source
```go
// Save only values from environment variables
if err := cfg.SaveSource("env-config.toml", config.SourceEnv); err != nil {
log.Fatal(err)
}
// Save only file-loaded values
if err := cfg.SaveSource("file-only.toml", config.SourceFile); err != nil {
log.Fatal(err)
}
```
### Generate Default Configuration
```go
// Create a default config file
defaults := &Config{}
// ... set default values ...
cfg, _ := config.NewBuilder().
WithDefaults(defaults).
Build()
// Save defaults as config template
if err := cfg.SaveSource("config.toml.example", config.SourceDefault); err != nil {
log.Fatal(err)
}
```
## File Structure Mapping
TOML structure maps directly to dot-notation paths:
```toml
# Maps to "debug"
debug = true
[server]
# Maps to "server.host"
host = "localhost"
# Maps to "server.port"
port = 8080
[server.tls]
# Maps to "server.tls.enabled"
enabled = true
# Maps to "server.tls.cert"
cert = "/path/to/cert"
[[users]]
# Array elements: "users.0.name", "users.0.role"
name = "admin"
role = "administrator"
[[users]]
# Array elements: "users.1.name", "users.1.role"
name = "user"
role = "standard"
```
## Type Handling
TOML types map to Go types:
```toml
# Strings
name = "myapp"
multiline = """
Line one
Line two
"""
# Numbers
port = 8080 # int64
timeout = 30 # int64
ratio = 0.95 # float64
max_size = 1_000_000 # int64 (underscores allowed)
# Booleans
enabled = true
debug = false
# Dates/Times (RFC 3339)
created_at = 2024-01-15T09:30:00Z
expires = 2024-12-31
# Arrays
ports = [8080, 8081, 8082]
tags = ["production", "stable"]
# Tables (objects)
[database]
host = "localhost"
port = 5432
# Array of tables
[[servers]]
name = "web1"
host = "10.0.0.1"
[[servers]]
name = "web2"
host = "10.0.0.2"
```
## Error Handling
File loading can produce several error types:
```go
err := cfg.LoadFile("config.toml")
if err != nil {
switch {
case errors.Is(err, config.ErrConfigNotFound):
// File doesn't exist - often not fatal
log.Println("No config file, using defaults")
case strings.Contains(err.Error(), "failed to parse TOML"):
// TOML syntax error
log.Fatal("Invalid TOML syntax:", err)
case strings.Contains(err.Error(), "failed to read"):
// Permission or I/O error
log.Fatal("Cannot read config file:", err)
default:
log.Fatal("Config error:", err)
}
}
```
## Security Considerations
### File Permissions
```go
// After saving, verify permissions
info, err := os.Stat("config.toml")
if err == nil {
mode := info.Mode()
if mode&0077 != 0 {
log.Warn("Config file is world/group readable")
// Fix permissions
os.Chmod("config.toml", 0600)
}
}
```
### Size Limits
Files and values have size limits:
- Maximum file size: ~10MB (10 * MaxValueSize)
- Maximum value size: 1MB
## Partial Loading
Load only specific sections:
```go
var serverCfg ServerConfig
if err := cfg.Scan("server", &serverCfg); err != nil {
log.Fatal(err)
}
var dbCfg DatabaseConfig
if err := cfg.Scan("database", &dbCfg); err != nil {
log.Fatal(err)
}
```
## Best Practices
1. **Use Example Files**: Generate `.example` files with defaults
2. **Check Permissions**: Ensure config files aren't world-readable
3. **Validate After Load**: Add validators to check loaded values
4. **Handle Missing Files**: Missing config files often aren't fatal
5. **Use Atomic Saves**: The built-in Save method is atomic
6. **Document Structure**: Comment your TOML files thoroughly
## See Also
- [Live Reconfiguration](reconfiguration.md) - Automatic file reloading
- [Builder Pattern](builder.md) - File discovery options
- [Access Patterns](access.md) - Working with loaded values

302
doc/llm-guide.md Normal file
View File

@ -0,0 +1,302 @@
# lixenwraith/config LLM Usage Guide
Thread-safe configuration management for Go applications with multi-source support, type safety, and live reconfiguration.
Use default configuration and behavior if applicable, unless explicitly required.
## Core Types
### Config
```go
// Primary configuration manager. All operations are thread-safe.
type Config struct {
// Internal fields - thread-safe configuration store
}
```
### Source
```go
// Represents a configuration source, used to define load precedence.
type Source string
const (
SourceDefault Source = "default"
SourceFile Source = "file"
SourceEnv Source = "env"
SourceCLI Source = "cli"
)
```
### LoadOptions
```go
type LoadOptions struct {
Sources []Source // Precedence order (first = highest)
EnvPrefix string // Prepended to env var names
EnvTransform EnvTransformFunc // Custom path→env mapping
LoadMode LoadMode // Uses default behavior, do not configure
EnvWhitelist map[string]bool // Limit env paths (nil = all)
SkipValidation bool // Skip path validation
}
type EnvTransformFunc func(path string) string
type LoadMode int // LoadModeReplace (default) or LoadModeMerge
```
## Error Types
```go
var (
ErrConfigNotFound = errors.New("configuration file not found")
ErrCLIParse = errors.New("failed to parse command-line arguments")
ErrEnvParse = errors.New("failed to parse environment variables")
ErrValueSize = fmt.Errorf("value size exceeds maximum %d bytes", MaxValueSize)
)
const MaxValueSize = 1024 * 1024 // 1MB
```
## Core Methods
### Creation
```go
// New creates a new Config instance with default options.
func New() *Config
// NewWithOptions creates a new Config instance with custom load options.
func NewWithOptions(opts LoadOptions) *Config
func DefaultLoadOptions() LoadOptions
```
### Registration
```go
// Register makes a configuration path known with a default value; required before use.
func (c *Config) Register(path string, defaultValue any) error
// RegisterStruct recursively registers fields from a struct using `toml` tags by default.
func (c *Config) RegisterStruct(prefix string, structWithDefaults any) error
// RegisterStructWithTags is like RegisterStruct but allows custom tag names ("json", "yaml").
func (c *Config) RegisterStructWithTags(prefix string, structWithDefaults any, tagName string) error
// RegisterWithEnv registers a path with an explicit environment variable mapping.
func (c *Config) RegisterWithEnv(path string, defaultValue any, envVar string) error
// Unregister removes a configuration path and all its children.
func (c *Config) Unregister(path string) error
```
Only default `toml` tags must be used unless support of other types are explicitly requested.
Path registration is required before setting values. Paths use dot notation (e.g., "server.port").
### Value Access
```go
// Get retrieves the final merged value; the bool indicates if the path was registered.
func (c *Config) Get(path string) (any, bool)
// GetSource retrieves a value from a specific source layer.
func (c *Config) GetSource(path string, source Source) (any, bool)
// GetSources returns all sources that have a value for the given path.
func (c *Config) GetSources(path string) map[Source]any
```
The returned `any` type requires type assertion, e.g., `port := val.(int64)`.
### Value Modification
```go
// Set updates a value in the highest priority source (default: CLI). Path must be registered.
func (c *Config) Set(path string, value any) error
// SetSource sets a value for a specific source layer.
func (c *Config) SetSource(path string, source Source, value any) error
// SetLoadOptions updates the load options, recomputing all current values.
func (c *Config) SetLoadOptions(opts LoadOptions) error
```
### Loading
```go
// Load reads configuration from a TOML file and merges overrides from command-line arguments.
func (c *Config) Load(filePath string, args []string) error
// LoadWithOptions loads configuration from multiple sources with custom options.
func (c *Config) LoadWithOptions(filePath string, args []string, opts LoadOptions) error
// LoadFile loads configuration values from a TOML file into the File source.
func (c *Config) LoadFile(path string) error
// LoadEnv loads values from environment variables into the Env source.
func (c *Config) LoadEnv(prefix string) error
// LoadCLI loads values from command-line arguments into the CLI source.
func (c *Config) LoadCLI(args []string) error
```
### Scanning & Population
```go
// Scan populates a struct from a specific config path (e.g., "server").
func (c *Config) Scan(basePath string, target any) error
// ScanSource decodes configuration from specific source
func (c *Config) ScanSource(basePath string, source Source, target any) error
// Target populates a struct from the root of the config; alias for Scan("", target).
func (c *Config) Target(out any) error
// AsStruct retrieves the pre-configured target struct (see Builder.WithTarget).
func (c *Config) AsStruct() (any, error)
```
Populates structs using mapstructure with automatic type conversion.
### Persistence
```go
// Save atomically saves the current merged configuration state to a TOML file.
func (c *Config) Save(path string) error
// SaveSource atomically saves values from only a specific source to a TOML file.
func (c *Config) SaveSource(path string, source Source) error
```
Atomic file writes in TOML format.
### State Management
```go
// Reset clears all non-default values from all sources.
func (c *Config) Reset()
// ResetSource clears all values from a specific source.
func (c *Config) ResetSource(source Source)
// Clone creates a deep copy of the configuration state.
func (c *Config) Clone() *Config
```
### Inspection
```go
// GetRegisteredPaths returns all registered paths matching a prefix.
func (c *Config) GetRegisteredPaths(prefix string) map[string]bool
// Validate checks that all specified required paths have been set.
func (c *Config) Validate(required ...string) error
// Debug returns a formatted string of all values and their sources for debugging.
func (c *Config) Debug() string
```
### Environment
```go
// DiscoverEnv discovers environment variables matching a prefix.
func (c *Config) DiscoverEnv(prefix string) map[string]string
// ExportEnv exports the current configuration as environment variables
func (c *Config) ExportEnv(prefix string) map[string]string
```
## Builder Pattern
### Builder
```go
type Builder struct {
// Internal builder state
}
type ValidatorFunc func(c *Config) error
```
### Builder Methods
```go
// NewBuilder creates a new configuration builder.
func NewBuilder() *Builder
// Build finalizes configuration; returns the first of any accumulated errors.
func (b *Builder) Build() (*Config, error)
// WithDefaults sets the struct containing default values.
func (b *Builder) WithDefaults(defaults any) *Builder
// WithTarget enables type-aware mode for AsStruct() and registers struct fields.
func (b *Builder) WithTarget(target any) *Builder
// WithTagName sets the primary struct tag for field mapping: "toml", "json", "yaml".
func (b *Builder) WithTagName(tagName string) *Builder
// WithSources sets the precedence order for configuration sources.
func (b *Builder) WithSources(sources ...Source) *Builder
// WithPrefix adds a prefix to all registered paths from a struct.
func (b *Builder) WithPrefix(prefix string) *Builder
// WithEnvPrefix sets the global environment variable prefix.
func (b *Builder) WithEnvPrefix(prefix string) *Builder
// WithFile sets the configuration file path to be loaded.
func (b *Builder) WithFile(path string) *Builder
// WithArgs sets the command-line arguments to be parsed.
func (b *Builder) WithArgs(args []string) *Builder
// WithValidator adds a validation function that runs after loading.
func (b *Builder) WithValidator(fn ValidatorFunc) *Builder
// WithEnvTransform sets a custom environment variable mapping function.
func (b *Builder) WithSources(sources ...Source) *Builder
// WithEnvTransform sets a custom environment variable mapping function.
func (b *Builder) WithEnvTransform(fn EnvTransformFunc) *Builder
// WithFileDiscovery enables automatic config file discovery
func (b *Builder) WithFileDiscovery(opts FileDiscoveryOptions) *Builder
```
### FileDiscoveryOptions
```go
type FileDiscoveryOptions struct {
Name string // Base name without extension
Extensions []string // Extensions to try in order
Paths []string // Custom search paths
EnvVar string // Environment variable for path
CLIFlag string // CLI flag for path
UseXDG bool // Search XDG directories
UseCurrentDir bool // Search current directory
}
func DefaultDiscoveryOptions(appName string) FileDiscoveryOptions
```
## Live Reconfiguration
### AutoUpdate
```go
// AutoUpdate enables automatic configuration reloading on file changes with default options.
func (c *Config) AutoUpdate()
// AutoUpdateWithOptions enables reloading with custom options.
func (c *Config) AutoUpdateWithOptions(opts WatchOptions)
// StopAutoUpdate stops the file watcher and cleans up resources.
func (c *Config) StopAutoUpdate()
// IsWatching returns true if the file watcher is active.
func (c *Config) IsWatching() bool
```
### Watch
```go
// Watch returns a channel that receives paths of changed values.
func (c *Config) Watch() <-chan string
// WatcherCount returns the number of active watch subscribers.
func (c *Config) WatcherCount() int
```
Channel receives paths of changed values or special notifications: `"file_deleted"`, `"permissions_changed"`, `"reload_error:*"`.
### WatchOptions
```go
type WatchOptions struct {
PollInterval time.Duration // File check interval (min 100ms)
Debounce time.Duration // Delay after changes
MaxWatchers int // Concurrent watch limit
ReloadTimeout time.Duration // Reload operation timeout
VerifyPermissions bool // Check permission changes
}
func DefaultWatchOptions() WatchOptions
```
## Type System
### Supported Types
- Basic: `bool`, `int64`, `float64`, `string`
- Time: `time.Duration`, `time.Time`
- Network: `net.IP`, `net.IPNet`, `url.URL`
- Slices: Any slice type with comma-separated parsing
- Complex: Any type via mapstructure decode hooks
### Type Conversion
All integer types are stored as `int64`, and floats as `float64`. String inputs from sources like environment variables or CLI arguments are automatically parsed to the target registered type. Custom types supported via decode hooks.
### Struct Tags
The `WithTagName` builder method sets the primary tag used for mapping paths.
```go
type Config struct {
// Uses the tag set by WithTagName (default "toml") for path name.
// The `env` tag provides an explicit environment variable override.
Port int64 `toml:"port" env:"PORT"`
Timeout time.Duration `toml:"timeout"`
// Slices are populated from comma-separated strings (env/CLI) or arrays (file).
Tags []string `toml:"tags"`
}
```
## Thread Safety
All methods are thread-safe. Concurrent reads and writes are synchronized internally.
## Path Validation
- Paths use dot notation: "server.port", "database.connections.max"
- Segments must be valid identifiers: `[A-Za-z0-9_-]+`
- No leading/trailing dots or empty segments
## Source Precedence
Default order (highest to lowest):
1. CLI arguments
2. Environment variables
3. Configuration file
4. Default values
Precedence is configurable via `Builder.WithSources()` or `LoadOptions.Sources`.

176
doc/quick-start.md Normal file
View File

@ -0,0 +1,176 @@
# Quick Start Guide
This guide gets you up and running with the config package in minutes.
## Basic Usage
The simplest way to use the config package is with the `Quick` function:
```go
package main
import (
"log"
"github.com/lixenwraith/config"
)
// Define your configuration structure
type Config struct {
Server struct {
Host string `toml:"host"`
Port int `toml:"port"`
} `toml:"server"`
Database struct {
URL string `toml:"url"`
MaxConns int `toml:"max_conns"`
} `toml:"database"`
Debug bool `toml:"debug"`
}
func main() {
// Create defaults
defaults := &Config{}
defaults.Server.Host = "localhost"
defaults.Server.Port = 8080
defaults.Database.URL = "postgres://localhost/mydb"
defaults.Database.MaxConns = 10
defaults.Debug = false
// Initialize configuration
cfg, err := config.Quick(
defaults, // Default values from struct
"MYAPP_", // Environment variable prefix
"config.toml", // Configuration file path
)
if err != nil {
log.Fatal(err)
}
// Access values
port, _ := cfg.Get("server.port")
dbURL, _ := cfg.Get("database.url")
log.Printf("Server running on port %d", port.(int64))
log.Printf("Database URL: %s", dbURL.(string))
}
```
## Configuration Sources
The package loads configuration from multiple sources in this default order (highest to lowest priority):
1. **Command-line arguments** - Override everything
2. **Environment variables** - Override file and defaults
3. **Configuration file** - Override defaults
4. **Default values** - Base configuration
### Command-Line Arguments
```bash
./myapp --server.port=9090 --debug
```
### Environment Variables
```bash
export MYAPP_SERVER_PORT=9090
export MYAPP_DATABASE_URL="postgres://prod/mydb"
export MYAPP_DEBUG=true
```
### Configuration File (config.toml)
```toml
[server]
host = "0.0.0.0"
port = 8080
[database]
url = "postgres://localhost/mydb"
max_conns = 25
debug = false
```
## Type Safety
The package uses struct tags to ensure type safety. When you register a struct, the types are enforced:
```go
// This struct defines the expected types
type Config struct {
Port int64 `toml:"port"` // Must be a number
Host string `toml:"host"` // Must be a string
Debug bool `toml:"debug"` // Must be a boolean
}
// Type assertions are safe after registration
port, _ := cfg.Get("port")
portNum := port.(int64) // Safe - type is guaranteed
```
## Error Handling
The package validates types during loading:
```go
cfg, err := config.Quick(defaults, "APP_", "config.toml")
if err != nil {
// Handle errors like:
// - Invalid TOML syntax
// - Type mismatches (e.g., string value for int field)
// - File permissions issues
log.Fatal(err)
}
```
## Common Patterns
### Required Fields
```go
// Register required configuration
cfg.RegisterRequired("api.key", "")
cfg.RegisterRequired("database.url", "")
// Validate all required fields are set
if err := cfg.Validate("api.key", "database.url"); err != nil {
log.Fatal("Missing required configuration:", err)
}
```
### Using Different Struct Tags
```go
// Use JSON tags instead of TOML
type Config struct {
Server struct {
Host string `json:"host"`
Port int `json:"port"`
} `json:"server"`
}
cfg, _ := config.NewBuilder().
WithTarget(&Config{}).
WithTagName("json").
WithFile("config.toml").
Build()
```
### Checking Value Sources
```go
// See which source provided a value
port, _ := cfg.Get("server.port")
sources := cfg.GetSources("server.port")
for source, value := range sources {
log.Printf("server.port from %s: %v", source, value)
}
```
## Next Steps
- [Builder Pattern](builder.md) - Advanced configuration options
- [Environment Variables](env.md) - Detailed environment variable handling
- [Access Patterns](access.md) - All ways to get and set values

355
doc/reconfiguration.md Normal file
View File

@ -0,0 +1,355 @@
# Live Reconfiguration
The config package supports automatic configuration reloading when files change, enabling zero-downtime reconfiguration.
## Basic File Watching
### Enable Auto-Update
```go
cfg, _ := config.NewBuilder().
WithDefaults(&Config{}).
WithFile("config.toml").
Build()
// Enable automatic reloading
cfg.AutoUpdate()
// Your application continues running
// Config reloads automatically when file changes
// Stop watching when done
defer cfg.StopAutoUpdate()
```
### Watch for Changes
```go
// Get notified of configuration changes
changes := cfg.Watch()
go func() {
for path := range changes {
log.Printf("Configuration changed: %s", path)
// React to specific changes
switch path {
case "server.port":
// Restart server with new port
restartServer()
case "log.level":
// Update log level
updateLogLevel()
}
}
}()
```
## Watch Options
### Custom Watch Configuration
```go
opts := config.WatchOptions{
PollInterval: 500 * time.Millisecond, // Check every 500ms
Debounce: 200 * time.Millisecond, // Wait 200ms after changes
MaxWatchers: 50, // Limit concurrent watchers
ReloadTimeout: 10 * time.Second, // Timeout for reload
VerifyPermissions: true, // Security check
}
cfg.AutoUpdateWithOptions(opts)
```
### Watch Without Auto-Update
```go
// Just watch, don't auto-reload
changes := cfg.WatchWithOptions(config.WatchOptions{
PollInterval: time.Second,
})
// Manually reload when desired
go func() {
for range changes {
if shouldReload() {
cfg.LoadFile("config.toml")
}
}
}()
```
## Change Detection
### Value Changes
The watcher detects and notifies about:
- New values added
- Existing values modified
- Values removed
- Type changes
```go
changes := cfg.Watch()
for path := range changes {
newVal, exists := cfg.Get(path)
if !exists {
log.Printf("Removed: %s", path)
continue
}
sources := cfg.GetSources(path)
fileVal, hasFile := sources[config.SourceFile]
log.Printf("Changed: %s = %v (from file: %v)",
path, newVal, hasFile)
}
```
### Special Notifications
```go
changes := cfg.Watch()
for notification := range changes {
switch notification {
case "file_deleted":
log.Warn("Config file was deleted")
case "permissions_changed":
log.Error("Config file permissions changed - potential security issue")
case "reload_timeout":
log.Error("Config reload timed out")
default:
if strings.HasPrefix(notification, "reload_error:") {
log.Error("Reload error:", notification)
} else {
// Normal path change
handleConfigChange(notification)
}
}
}
```
## Debouncing
Rapid file changes are automatically debounced:
```go
// Multiple rapid saves to config.toml
// Only triggers one reload after debounce period
opts := config.WatchOptions{
PollInterval: 100 * time.Millisecond,
Debounce: 500 * time.Millisecond, // Wait 500ms
}
cfg.AutoUpdateWithOptions(opts)
```
## Permission Monitoring
```go
opts := config.WatchOptions{
VerifyPermissions: true, // Enabled by default
}
cfg.AutoUpdateWithOptions(opts)
// Detects if file becomes world-writable
changes := cfg.Watch()
for change := range changes {
if change == "permissions_changed" {
// File permissions changed
// Possible security breach
alert("Config file permissions modified!")
}
}
```
## Pattern: Reconfiguration
```go
type Server struct {
cfg *config.Config
listener net.Listener
mu sync.RWMutex
}
func (s *Server) watchConfig() {
changes := s.cfg.Watch()
for path := range changes {
switch {
case strings.HasPrefix(path, "server."):
s.scheduleRestart()
case path == "log.level":
s.updateLogLevel()
case strings.HasPrefix(path, "feature."):
s.reloadFeatures()
}
}
}
func (s *Server) scheduleRestart() {
s.mu.Lock()
defer s.mu.Unlock()
// Graceful restart logic
log.Info("Scheduling server restart for config changes")
// ... drain connections, restart listener ...
}
```
## Pattern: Feature Flags
```go
type FeatureFlags struct {
cfg *config.Config
mu sync.RWMutex
}
func (ff *FeatureFlags) Watch() {
changes := ff.cfg.Watch()
for path := range changes {
if strings.HasPrefix(path, "features.") {
feature := strings.TrimPrefix(path, "features.")
enabled, _ := ff.cfg.Get(path)
log.Printf("Feature %s: %v", feature, enabled)
ff.notifyFeatureChange(feature, enabled.(bool))
}
}
}
func (ff *FeatureFlags) IsEnabled(feature string) bool {
ff.mu.RLock()
defer ff.mu.RUnlock()
val, exists := ff.cfg.Get("features." + feature)
return exists && val.(bool)
}
```
## Pattern: Multi-Stage Reload
```go
func watchConfigWithValidation(cfg *config.Config) {
changes := cfg.Watch()
for range changes {
// Stage 1: Snapshot current config
backup := cfg.Clone()
// Stage 2: Validate new configuration
if err := validateNewConfig(cfg); err != nil {
log.Error("Invalid configuration:", err)
continue
}
// Stage 3: Apply changes
if err := applyConfigChanges(cfg, backup); err != nil {
log.Error("Failed to apply changes:", err)
// Could restore from backup here
continue
}
log.Info("Configuration successfully reloaded")
}
}
```
## Monitoring
### Watch Status
```go
// Check if watching is active
if cfg.IsWatching() {
log.Printf("Auto-update is enabled")
log.Printf("Active watchers: %d", cfg.WatcherCount())
}
```
### Resource Management
```go
// Limit watchers to prevent resource exhaustion
opts := config.WatchOptions{
MaxWatchers: 10, // Max 10 concurrent watch channels
}
// Watchers beyond limit receive closed channels
cfg.AutoUpdateWithOptions(opts)
```
## Best Practices
1. **Always Stop Watching**: Use `defer cfg.StopAutoUpdate()` to clean up
2. **Handle All Notifications**: Check for special error notifications
3. **Validate After Reload**: Ensure new config is valid before applying
4. **Use Debouncing**: Prevent reload storms from rapid edits
5. **Monitor Permissions**: Enable permission verification for security
6. **Graceful Updates**: Plan how your app handles config changes
7. **Log Changes**: Audit configuration modifications
## Limitations
- File watching uses polling (not inotify/kqueue)
- No support for watching multiple files
- Changes only detected for registered paths
- Reloads entire file (no partial updates)
## Common Issues
### Changes Not Detected
```go
// Ensure path is registered before watching
cfg.Register("new.value", "default")
// Now changes to new.value will be detected
```
### Rapid Reloads
```go
// Increase debounce to prevent rapid reloads
opts := config.WatchOptions{
Debounce: 2 * time.Second, // Wait 2s after changes stop
}
```
### Memory Leaks
```go
// Always stop watching to prevent goroutine leaks
watcher := cfg.Watch()
// Use context for cancellation
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
for {
select {
case change := <-watcher:
handleChange(change)
case <-ctx.Done():
return
}
}
}()
```
## See Also
- [File Configuration](file.md) - File format and loading
- [Access Patterns](access.md) - Reacting to changed values
- [Builder Pattern](builder.md) - Setting up watching with builder

418
dynamic_test.go Normal file
View File

@ -0,0 +1,418 @@
// FILE: lixenwraith/config/dynamic_test.go
package config
import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMultiFormatLoading tests loading different config formats
func TestMultiFormatLoading(t *testing.T) {
tmpDir := t.TempDir()
// Create test config in different formats
tomlConfig := `
[server]
host = "toml-host"
port = 8080
[database]
url = "postgres://localhost/toml"
`
jsonConfig := `{
"server": {
"host": "json-host",
"port": 9090
},
"database": {
"url": "postgres://localhost/json"
}
}`
yamlConfig := `
server:
host: yaml-host
port: 7070
database:
url: postgres://localhost/yaml
`
// Write config files
tomlPath := filepath.Join(tmpDir, "config.toml")
jsonPath := filepath.Join(tmpDir, "config.json")
yamlPath := filepath.Join(tmpDir, "config.yaml")
require.NoError(t, os.WriteFile(tomlPath, []byte(tomlConfig), 0644))
require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644))
require.NoError(t, os.WriteFile(yamlPath, []byte(yamlConfig), 0644))
t.Run("AutoDetectFormats", func(t *testing.T) {
cfg := New()
cfg.Register("server.host", "")
cfg.Register("server.port", 0)
cfg.Register("database.url", "")
// Test TOML
cfg.SetFileFormat("auto")
require.NoError(t, cfg.LoadFile(tomlPath))
host, _ := cfg.Get("server.host")
assert.Equal(t, "toml-host", host)
// Test JSON
require.NoError(t, cfg.LoadFile(jsonPath))
host, _ = cfg.Get("server.host")
assert.Equal(t, "json-host", host)
port, _ := cfg.Get("server.port")
// JSON number should be preserved as json.Number but convertible
switch v := port.(type) {
case json.Number:
// Expected for raw value
assert.Equal(t, json.Number("9090"), v)
case int64:
// Expected after decode hook conversion
assert.Equal(t, int64(9090), v)
case float64:
// Alternative conversion
assert.Equal(t, float64(9090), v)
default:
t.Errorf("Unexpected type for port: %T", port)
}
// Test YAML
require.NoError(t, cfg.LoadFile(yamlPath))
host, _ = cfg.Get("server.host")
assert.Equal(t, "yaml-host", host)
})
t.Run("ExplicitFormat", func(t *testing.T) {
cfg := New()
cfg.Register("server.host", "")
// Force JSON parsing on .conf file
confPath := filepath.Join(tmpDir, "config.conf")
require.NoError(t, os.WriteFile(confPath, []byte(jsonConfig), 0644))
cfg.SetFileFormat("json")
require.NoError(t, cfg.LoadFile(confPath))
host, _ := cfg.Get("server.host")
assert.Equal(t, "json-host", host)
})
t.Run("ContentDetection", func(t *testing.T) {
cfg := New()
cfg.Register("server.host", "")
// Ambiguous extension
ambigPath := filepath.Join(tmpDir, "config.conf")
require.NoError(t, os.WriteFile(ambigPath, []byte(yamlConfig), 0644))
cfg.SetFileFormat("auto")
require.NoError(t, cfg.LoadFile(ambigPath))
host, _ := cfg.Get("server.host")
assert.Equal(t, "yaml-host", host)
})
}
// TestDynamicFormatSwitching tests runtime format changes
func TestDynamicFormatSwitching(t *testing.T) {
tmpDir := t.TempDir()
// Create configs in different formats with same structure
configs := map[string]string{
"toml": `value = "from-toml"`,
"json": `{"value": "from-json"}`,
"yaml": `value: from-yaml`,
}
cfg := New()
cfg.Register("value", "default")
for format, content := range configs {
t.Run(format, func(t *testing.T) {
filePath := filepath.Join(tmpDir, "config."+format)
require.NoError(t, os.WriteFile(filePath, []byte(content), 0644))
// Set format and load
require.NoError(t, cfg.SetFileFormat(format))
require.NoError(t, cfg.LoadFile(filePath))
val, _ := cfg.Get("value")
assert.Equal(t, "from-"+format, val)
})
}
}
// TestWatchFileFormatSwitch tests watching different file formats
func TestWatchFileFormatSwitch(t *testing.T) {
tmpDir := t.TempDir()
tomlPath := filepath.Join(tmpDir, "config.toml")
jsonPath := filepath.Join(tmpDir, "config.json")
require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-1"`), 0644))
require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-1"}`), 0644))
cfg := New()
cfg.Register("value", "default")
// Configure fast polling for test
opts := WatchOptions{
PollInterval: testPollInterval, // Fast polling for tests
Debounce: testDebounce, // Short debounce
MaxWatchers: 10,
}
// Start watching TOML
cfg.SetFileFormat("auto")
require.NoError(t, cfg.LoadFile(tomlPath))
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
// Wait for watcher to start
require.Eventually(t, func() bool {
return cfg.IsWatching()
}, 4*testDebounce, 2*SpinWaitInterval)
val, _ := cfg.Get("value")
assert.Equal(t, "toml-1", val)
// Switch to JSON with format hint
require.NoError(t, cfg.WatchFile(jsonPath, "json"))
// Wait for new watcher to start
require.Eventually(t, func() bool {
return cfg.IsWatching()
}, 4*testDebounce, 2*SpinWaitInterval)
// Get watch channel AFTER switching files
changes := cfg.Watch()
val, _ = cfg.Get("value")
assert.Equal(t, "json-1", val)
// Update JSON file
require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-2"}`), 0644))
// Wait for change notification
select {
case path := <-changes:
assert.Equal(t, "value", path)
// Wait a bit for value to be updated
require.Eventually(t, func() bool {
val, _ := cfg.Get("value")
return val == "json-2"
}, testEventuallyTimeout, 2*SpinWaitInterval)
case <-time.After(testWatchTimeout):
t.Error("Timeout waiting for JSON file change")
}
// Update old TOML file - should NOT trigger notification
require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-2"`), 0644))
// Should not receive notification from old file
select {
case <-changes:
t.Error("Should not receive changes from old TOML file")
case <-time.After(testPollWindow):
// Expected - no change notification
}
}
// TestSecurityOptions tests security features
func TestSecurityOptions(t *testing.T) {
tmpDir := t.TempDir()
t.Run("PathTraversal", func(t *testing.T) {
cfg := New()
cfg.SetSecurityOptions(SecurityOptions{
PreventPathTraversal: true,
})
// Test various malicious paths
maliciousPaths := []string{
"../../../etc/passwd",
"./../etc/passwd",
"config/../../../etc/passwd",
filepath.Join("..", "..", "etc", "passwd"),
}
for _, malPath := range maliciousPaths {
err := cfg.LoadFile(malPath)
assert.Error(t, err, "Should reject path: %s", malPath)
assert.Contains(t, err.Error(), "path traversal")
}
// Valid paths should work
validPath := filepath.Join(tmpDir, "config.toml")
os.WriteFile(validPath, []byte(`test = "value"`), 0644)
cfg.Register("test", "")
err := cfg.LoadFile(validPath)
assert.NoError(t, err, "Should accept valid absolute path")
})
t.Run("FileSizeLimit", func(t *testing.T) {
cfg := New()
cfg.SetSecurityOptions(SecurityOptions{
MaxFileSize: 100, // 100 bytes limit
})
// Create large file
largePath := filepath.Join(tmpDir, "large.toml")
largeContent := make([]byte, 1024)
for i := range largeContent {
largeContent[i] = 'a'
}
require.NoError(t, os.WriteFile(largePath, largeContent, 0644))
err := cfg.LoadFile(largePath)
assert.Error(t, err)
assert.Contains(t, err.Error(), "exceeds maximum size")
})
t.Run("FileOwnership", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping ownership test on Windows")
}
cfg := New()
cfg.SetSecurityOptions(SecurityOptions{
EnforceFileOwnership: true,
})
// Create file owned by current user (should succeed)
ownedPath := filepath.Join(tmpDir, "owned.toml")
require.NoError(t, os.WriteFile(ownedPath, []byte(`test = "value"`), 0644))
cfg.Register("test", "")
err := cfg.LoadFile(ownedPath)
assert.NoError(t, err)
})
}
// waitForWatchingState waits for watcher state, preventing race conditions of goroutine start and test check
func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) {
t.Helper()
require.Eventually(t, func() bool {
return cfg.IsWatching() == expected
}, testEventuallyTimeout, 2*SpinWaitInterval, msgAndArgs...)
}
// TestBuilderWithFormat tests Builder integration
func TestBuilderWithFormat(t *testing.T) {
tmpDir := t.TempDir()
jsonPath := filepath.Join(tmpDir, "config.json")
jsonConfig := `{
"server": {
"host": "builder-host",
"port": 8080
}
}`
require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644))
type Config struct {
Server struct {
Host string `json:"host" toml:"host"`
Port int `json:"port" toml:"port"`
} `json:"server" toml:"server"`
}
defaults := &Config{}
defaults.Server.Host = "default-host"
defaults.Server.Port = 3000
cfg, err := NewBuilder().
WithDefaults(defaults).
WithFile(jsonPath).
WithFileFormat("json").
WithTagName("toml"). // Use toml tags for registration
WithSecurityOptions(SecurityOptions{
PreventPathTraversal: true,
MaxFileSize: 1024 * 1024, // 1MB
}).
Build()
require.NoError(t, err)
// Check the value was loaded
host, exists := cfg.Get("server.host")
assert.True(t, exists, "server.host should exist")
assert.Equal(t, "builder-host", host)
port, exists := cfg.Get("server.port")
assert.True(t, exists, "server.port should exist")
// Handle json.Number or converted int
switch v := port.(type) {
case json.Number:
p, _ := v.Int64()
assert.Equal(t, int64(8080), p)
case int64:
assert.Equal(t, int64(8080), v)
case float64:
assert.Equal(t, float64(8080), v)
default:
t.Errorf("Unexpected type for port: %T", port)
}
}
// BenchmarkFormatParsing benchmarks different format parsing speeds
func BenchmarkFormatParsing(b *testing.B) {
tmpDir := b.TempDir()
// Create test data
configs := map[string]string{
"toml": `
[server]
host = "localhost"
port = 8080
[database]
url = "postgres://localhost/db"
[cache]
ttl = 300
`,
"json": `{
"server": {"host": "localhost", "port": 8080},
"database": {"url": "postgres://localhost/db"},
"cache": {"ttl": 300}
}`,
"yaml": `
server:
host: localhost
port: 8080
database:
url: postgres://localhost/db
cache:
ttl: 300
`,
}
for format, content := range configs {
b.Run(format, func(b *testing.B) {
path := filepath.Join(tmpDir, "bench."+format)
os.WriteFile(path, []byte(content), 0644)
cfg := New()
cfg.Register("server.host", "")
cfg.Register("server.port", 0)
cfg.Register("database.url", "")
cfg.Register("cache.ttl", 0)
cfg.SetFileFormat(format)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cfg.LoadFile(path)
}
})
}
}

View File

@ -1,225 +0,0 @@
// File: lixenwraith/config/env_test.go
package config_test
import (
"os"
"testing"
"github.com/lixenwraith/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEnvironmentVariables(t *testing.T) {
t.Run("Basic Environment Loading", func(t *testing.T) {
// Set up environment
envVars := map[string]string{
"TEST_SERVER_HOST": "env-host",
"TEST_SERVER_PORT": "9999",
"TEST_DEBUG": "true",
}
for k, v := range envVars {
os.Setenv(k, v)
defer os.Unsetenv(k)
}
cfg := config.New()
cfg.Register("server.host", "default-host")
cfg.Register("server.port", 8080)
cfg.Register("debug", false)
// Load environment variables
err := cfg.LoadEnv("TEST_")
require.NoError(t, err)
// Verify values
host, _ := cfg.String("server.host")
assert.Equal(t, "env-host", host)
port, _ := cfg.Int64("server.port")
assert.Equal(t, int64(9999), port)
debug, _ := cfg.Bool("debug")
assert.True(t, debug)
})
t.Run("Custom Environment Transform", func(t *testing.T) {
os.Setenv("PORT", "3000")
os.Setenv("DATABASE_URL", "postgres://localhost/test")
defer func() {
os.Unsetenv("PORT")
os.Unsetenv("DATABASE_URL")
}()
cfg := config.New()
cfg.Register("server.port", 8080)
cfg.Register("database.url", "sqlite://memory")
opts := config.LoadOptions{
Sources: []config.Source{config.SourceEnv, config.SourceDefault},
EnvTransform: func(path string) string {
mapping := map[string]string{
"server.port": "PORT",
"database.url": "DATABASE_URL",
}
return mapping[path]
},
}
err := cfg.LoadWithOptions("", nil, opts)
require.NoError(t, err)
port, _ := cfg.Int64("server.port")
assert.Equal(t, int64(3000), port)
dbURL, _ := cfg.String("database.url")
assert.Equal(t, "postgres://localhost/test", dbURL)
})
t.Run("Environment Discovery", func(t *testing.T) {
// Set up various env vars
envVars := map[string]string{
"APP_SERVER_HOST": "discovered",
"APP_SERVER_PORT": "4444",
"APP_UNREGISTERED": "ignored",
}
for k, v := range envVars {
os.Setenv(k, v)
defer os.Unsetenv(k)
}
cfg := config.New()
cfg.Register("server.host", "default")
cfg.Register("server.port", 8080)
cfg.Register("server.timeout", 30)
// Discover which registered paths have env vars
discovered := cfg.DiscoverEnv("APP_")
// Should find 2 env vars
assert.Len(t, discovered, 2)
assert.Equal(t, "APP_SERVER_HOST", discovered["server.host"])
assert.Equal(t, "APP_SERVER_PORT", discovered["server.port"])
assert.NotContains(t, discovered, "unregistered")
})
t.Run("Environment Whitelist", func(t *testing.T) {
envVars := map[string]string{
"SECRET_API_KEY": "secret-value",
"SECRET_DATABASE_PASSWORD": "db-pass",
"SECRET_SERVER_PORT": "5555",
}
for k, v := range envVars {
os.Setenv(k, v)
defer os.Unsetenv(k)
}
cfg := config.New()
cfg.Register("api.key", "")
cfg.Register("database.password", "")
cfg.Register("server.port", 8080)
opts := config.LoadOptions{
Sources: []config.Source{config.SourceEnv, config.SourceDefault},
EnvPrefix: "SECRET_",
EnvWhitelist: map[string]bool{
"api.key": true,
"database.password": true,
// server.port is NOT whitelisted
},
}
cfg.LoadWithOptions("", nil, opts)
// Whitelisted values should load
apiKey, _ := cfg.String("api.key")
assert.Equal(t, "secret-value", apiKey)
dbPass, _ := cfg.String("database.password")
assert.Equal(t, "db-pass", dbPass)
// Non-whitelisted should use default
port, _ := cfg.Int64("server.port")
assert.Equal(t, int64(8080), port)
})
t.Run("RegisterWithEnv", func(t *testing.T) {
os.Setenv("CUSTOM_PORT", "6666")
defer os.Unsetenv("CUSTOM_PORT")
cfg := config.New()
// Register with explicit env mapping
err := cfg.RegisterWithEnv("server.port", 8080, "CUSTOM_PORT")
require.NoError(t, err)
// Should immediately have env value
port, _ := cfg.Int64("server.port")
assert.Equal(t, int64(6666), port)
})
t.Run("Export Environment", func(t *testing.T) {
cfg := config.New()
cfg.Register("app.name", "myapp")
cfg.Register("app.version", "1.0.0")
cfg.Register("server.port", 8080)
// Set some non-default values
cfg.Set("app.version", "2.0.0")
cfg.Set("server.port", 9090)
// Export as env vars
exports := cfg.ExportEnv("EXPORT_")
// Should export non-default values
assert.Equal(t, "2.0.0", exports["EXPORT_APP_VERSION"])
assert.Equal(t, "9090", exports["EXPORT_SERVER_PORT"])
// Should not export defaults
assert.NotContains(t, exports, "EXPORT_APP_NAME")
})
t.Run("Type Parsing from Environment", func(t *testing.T) {
envVars := map[string]string{
"TYPES_STRING": "hello world",
"TYPES_INT": "42",
"TYPES_FLOAT": "3.14159",
"TYPES_BOOL_TRUE": "true",
"TYPES_BOOL_FALSE": "false",
"TYPES_QUOTED": `"quoted string"`,
}
for k, v := range envVars {
os.Setenv(k, v)
defer os.Unsetenv(k)
}
cfg := config.New()
cfg.Register("string", "")
cfg.Register("int", 0)
cfg.Register("float", 0.0)
cfg.Register("bool.true", false)
cfg.Register("bool.false", true)
cfg.Register("quoted", "")
cfg.LoadEnv("TYPES_")
// Verify type conversions
s, _ := cfg.String("string")
assert.Equal(t, "hello world", s)
i, _ := cfg.Int64("int")
assert.Equal(t, int64(42), i)
f, _ := cfg.Float64("float")
assert.Equal(t, 3.14159, f)
bt, _ := cfg.Bool("bool.true")
assert.True(t, bt)
bf, _ := cfg.Bool("bool.false")
assert.False(t, bf)
q, _ := cfg.String("quoted")
assert.Equal(t, "quoted string", q)
})
}

202
example/main.go Normal file
View File

@ -0,0 +1,202 @@
// FILE: lixenwraith/config/example/main.go
package main
import (
"fmt"
"log"
"os"
"sync"
"time"
"github.com/lixenwraith/config"
)
// AppConfig defines a richer configuration structure to showcase more features.
type AppConfig struct {
Server struct {
Host string `toml:"host"`
Port int64 `toml:"port"`
LogLevel string `toml:"log_level"`
} `toml:"server"`
FeatureFlags map[string]bool `toml:"feature_flags"`
}
const configFilePath = "config.toml"
func main() {
// =========================================================================
// PART 1: INITIAL SETUP
// Create a clean config.toml file on disk for our program to read.
// =========================================================================
log.Println("---")
log.Println("➡️ PART 1: Creating initial configuration file...")
// Defer cleanup to run at the very end of the program.
defer func() {
log.Println("---")
log.Println("🧹 Cleaning up...")
os.Remove(configFilePath)
// Unset the environment variable we use for testing.
os.Unsetenv("APP_SERVER_PORT")
log.Printf("Removed %s and unset APP_SERVER_PORT.", configFilePath)
}()
initialData := &AppConfig{}
initialData.Server.Host = "localhost"
initialData.Server.Port = 8080
initialData.Server.LogLevel = "info"
initialData.FeatureFlags = map[string]bool{"enable_metrics": true}
if err := createInitialConfigFile(initialData); err != nil {
log.Fatalf("❌ Failed during initial file creation: %v", err)
}
log.Printf("✅ Initial configuration saved to %s.", configFilePath)
// =========================================================================
// PART 2: RECOMMENDED CONFIGURATION USING THE BUILDER
// This demonstrates source precedence, validation, and type-safe targets.
// =========================================================================
log.Println("---")
log.Println("➡️ PART 2: Configuring manager with the Builder...")
// Set an environment variable to demonstrate source precedence (Env > File).
os.Setenv("APP_SERVER_PORT", "8888")
log.Println(" (Set environment variable APP_SERVER_PORT=8888)")
// Create a "target" struct. The builder will automatically populate this
// and keep it updated when using `AsStruct()`.
target := &AppConfig{}
// Use the builder to chain multiple configuration options.
builder := config.NewBuilder().
WithTarget(target). // Enables type-safe `AsStruct()` and auto-registration.
WithDefaults(initialData). // Explicitly set the source of defaults.
WithFile(configFilePath). // Specifies the config file to read.
WithEnvPrefix("APP_"). // Sets prefix for environment variables (e.g., APP_SERVER_PORT).
WithTypedValidator(func(cfg *AppConfig) error { // <-- NEW METHOD
// No type assertion needed! `cfg.Server.Port` is guaranteed to be an int64
// because the validator runs *after* the target struct is populated.
if cfg.Server.Port < 1024 || cfg.Server.Port > 65535 {
return fmt.Errorf("port %d is outside the recommended range (1024-65535)", cfg.Server.Port)
}
return nil
})
// Build the final config object.
cfg, err := builder.Build()
if err != nil {
log.Fatalf("❌ Builder failed: %v", err)
}
log.Println("✅ Builder finished successfully. Initial values loaded.")
initialTarget, _ := cfg.AsStruct()
printCurrentState(initialTarget.(*AppConfig), "Initial State (Env overrides File)")
// =========================================================================
// PART 3: DYNAMIC RELOADING WITH THE WATCHER
// We'll now modify the file and verify the watcher updates the config.
// =========================================================================
log.Println("---")
log.Println("➡️ PART 3: Testing the file watcher...")
// Use WithOptions to demonstrate customizing the watcher.
watchOpts := config.WatchOptions{
PollInterval: 250 * time.Millisecond,
Debounce: 100 * time.Millisecond,
}
cfg.AutoUpdateWithOptions(watchOpts)
changes := cfg.Watch()
log.Println("✅ Watcher is now active with custom options.")
// Start a goroutine to modify the file after a short delay.
var wg sync.WaitGroup
wg.Add(1)
go modifyFileOnDiskStructurally(&wg)
log.Println(" (Modifier goroutine dispatched to change file in 1 second...)")
log.Println(" (Waiting for watcher notification...)")
select {
case path := <-changes:
log.Printf("✅ Watcher detected a change for path: '%s'", path)
log.Println(" Verifying in-memory config using AsStruct()...")
// Retrieve the updated, type-safe struct.
updatedTarget, err := cfg.AsStruct()
if err != nil {
log.Fatalf("❌ AsStruct() failed after update: %v", err)
}
// Type-assert and verify the new values.
typedCfg := updatedTarget.(*AppConfig)
expectedLevel := "debug"
if typedCfg.Server.LogLevel != expectedLevel {
log.Fatalf("❌ VERIFICATION FAILED: Expected log_level '%s', but got '%s'.", expectedLevel, typedCfg.Server.LogLevel)
}
log.Println("✅ VERIFICATION SUCCESSFUL: In-memory config was updated by the watcher.")
printCurrentState(typedCfg, "Final State (Updated by Watcher)")
case <-time.After(5 * time.Second):
log.Fatalf("❌ TEST FAILED: Timed out waiting for watcher notification.")
}
wg.Wait()
}
// createInitialConfigFile is a helper to set up the initial file state.
func createInitialConfigFile(data *AppConfig) error {
cfg := config.New()
if err := cfg.RegisterStruct("", data); err != nil {
return err
}
return cfg.Save(configFilePath)
}
// modifyFileOnDiskStructurally simulates an external program that changes the config file.
func modifyFileOnDiskStructurally(wg *sync.WaitGroup) {
defer wg.Done()
time.Sleep(1 * time.Second)
log.Println(" (Modifier goroutine: now changing file on disk...)")
// Create a new, independent config instance to simulate an external process.
modifierCfg := config.New()
// Register the struct shape so the loader knows what paths are valid.
if err := modifierCfg.RegisterStruct("", &AppConfig{}); err != nil {
log.Fatalf("❌ Modifier failed to register struct: %v", err)
}
// Load the current state from disk.
if err := modifierCfg.LoadFile(configFilePath); err != nil {
log.Fatalf("❌ Modifier failed to load file: %v", err)
}
// Change the log level.
modifierCfg.Set("server.log_level", "debug")
// Use the generic GetTyped function. This is safe because modifierCfg has loaded the file.
featureFlags, err := config.GetTyped[map[string]bool](modifierCfg, "feature_flags")
if err != nil {
log.Fatalf("❌ Modifier failed to get typed feature_flags: %v", err)
}
// Modify the typed map and set it back.
featureFlags["enable_metrics"] = false
modifierCfg.Set("feature_flags", featureFlags)
// Save the changes back to disk, which will trigger the watcher in the main goroutine.
if err := modifierCfg.Save(configFilePath); err != nil {
log.Fatalf("❌ Modifier failed to save file: %v", err)
}
log.Println(" (Modifier goroutine: finished.)")
}
// printCurrentState is a helper to display the typed config state.
func printCurrentState(cfg *AppConfig, title string) {
fmt.Println(" --------------------------------------------------")
fmt.Printf(" %s\n", title)
fmt.Println(" --------------------------------------------------")
fmt.Printf(" Server Host: %s\n", cfg.Server.Host)
fmt.Printf(" Server Port: %d\n", cfg.Server.Port)
fmt.Printf(" Server Log Level: %s\n", cfg.Server.LogLevel)
fmt.Printf(" Feature Flags: %v\n", cfg.FeatureFlags)
fmt.Println(" --------------------------------------------------")
}

6
go.mod
View File

@ -1,15 +1,17 @@
module github.com/lixenwraith/config module github.com/lixenwraith/config
go 1.24.5 go 1.25.1
require ( require (
github.com/BurntSushi/toml v1.5.0 github.com/BurntSushi/toml v1.5.0
github.com/mitchellh/mapstructure v1.5.0 github.com/mitchellh/mapstructure v1.5.0
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )
replace github.com/mitchellh/mapstructure => github.com/go-viper/mapstructure v1.6.0

4
go.sum
View File

@ -2,8 +2,8 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/go-viper/mapstructure v1.6.0 h1:0WdPOF2rmmQDN1xo8qIgxyugvLp71HrZSWyGLxofobw=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/go-viper/mapstructure v1.6.0/go.mod h1:FcbLReH7/cjaC0RVQR+LHFIrBhHF3s1e/ud1KMDoBVw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=

View File

@ -1,4 +1,4 @@
// File: lixenwraith/config/helper.go // FILE: lixenwraith/config/helper.go
package config package config
import "strings" import "strings"
@ -86,15 +86,3 @@ func isValidKeySegment(s string) bool {
} }
return true return true
} }
// isAlpha checks if a character is a letter (A-Z, a-z)
// Note: not used, potential future use.
func isAlpha(c rune) bool {
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
}
// isNumeric checks if a character is a digit (0-9)
// Note: not used, potential future use.
func isNumeric(c rune) bool {
return c >= '0' && c <= '9'
}

View File

@ -1,18 +1,83 @@
// File: lixenwraith/config/io.go // FILE: lixenwraith/config/loader.go
package config package config
import ( import (
"bytes" "bytes"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "runtime"
"strings" "strings"
"syscall"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"gopkg.in/yaml.v3"
) )
// Source represents a configuration source, used to define load precedence
type Source string
const (
// SourceDefault represents use of registered default values
SourceDefault Source = "default"
// SourceFile represents values loaded from a configuration file
SourceFile Source = "file"
// SourceEnv represents values loaded from environment variables
SourceEnv Source = "env"
// SourceCLI represents values loaded from command-line arguments
SourceCLI Source = "cli"
)
// LoadMode defines how configuration sources are processed
type LoadMode int
const (
// LoadModeReplace completely replaces values (default behavior)
LoadModeReplace LoadMode = iota
// LoadModeMerge merges maps/structs instead of replacing
// TODO: future implementation
LoadModeMerge
)
// EnvTransformFunc converts a configuration path to an environment variable name
type EnvTransformFunc func(path string) string
// LoadOptions configures how configuration is loaded from multiple sources
type LoadOptions struct {
// Sources defines the precedence order (first = highest priority)
// Default: [SourceCLI, SourceEnv, SourceFile, SourceDefault]
Sources []Source
// EnvPrefix is prepended to environment variable names
// Example: "MYAPP_" transforms "server.port" to "MYAPP_SERVER_PORT"
EnvPrefix string
// EnvTransform customizes how paths map to environment variables
// If nil, uses default transformation (dots to underscores, uppercase)
EnvTransform EnvTransformFunc
// LoadMode determines how values are merged
LoadMode LoadMode
// EnvWhitelist limits which paths are checked for env vars (nil = all)
EnvWhitelist map[string]bool
// SkipValidation skips path validation during load
SkipValidation bool
}
// DefaultLoadOptions returns the standard load options
func DefaultLoadOptions() LoadOptions {
return LoadOptions{
Sources: []Source{SourceCLI, SourceEnv, SourceFile, SourceDefault},
LoadMode: LoadModeReplace,
}
}
// Load reads configuration from a TOML file and merges overrides from command-line arguments. // Load reads configuration from a TOML file and merges overrides from command-line arguments.
// This is a convenience method that maintains backward compatibility. // This is a convenience method that maintains backward compatibility.
func (c *Config) Load(filePath string, args []string) error { func (c *Config) Load(filePath string, args []string) error {
@ -83,120 +148,250 @@ func (c *Config) LoadFile(filePath string) error {
// loadFile reads and parses a TOML configuration file // loadFile reads and parses a TOML configuration file
func (c *Config) loadFile(path string) error { func (c *Config) loadFile(path string) error {
fileData, err := os.ReadFile(path) // Security: Path traversal check
if c.securityOpts != nil && c.securityOpts.PreventPathTraversal {
// Clean the path and check for traversal attempts
cleanPath := filepath.Clean(path)
// Check if cleaned path tries to go outside current directory
if strings.HasPrefix(cleanPath, ".."+string(filepath.Separator)) || cleanPath == ".." {
return fmt.Errorf("potential path traversal detected in config path: %s", path)
}
// Also check for absolute paths that might escape jail
if filepath.IsAbs(cleanPath) && filepath.IsAbs(path) {
// Absolute paths are OK if that's what was provided
} else if filepath.IsAbs(cleanPath) && !filepath.IsAbs(path) {
// Relative path became absolute after cleaning - suspicious
return fmt.Errorf("potential path traversal detected in config path: %s", path)
}
}
// Read file with size limit
fileInfo, err := os.Stat(path)
if err != nil { if err != nil {
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
return ErrConfigNotFound return ErrConfigNotFound
} }
return fmt.Errorf("failed to stat config file '%s': %w", path, err)
}
// Security: File size check
if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 {
if fileInfo.Size() > c.securityOpts.MaxFileSize {
return fmt.Errorf("config file '%s' exceeds maximum size %d bytes", path, c.securityOpts.MaxFileSize)
}
}
// Security: File ownership check (Unix only)
if c.securityOpts != nil && c.securityOpts.EnforceFileOwnership && runtime.GOOS != "windows" {
if stat, ok := fileInfo.Sys().(*syscall.Stat_t); ok {
if stat.Uid != uint32(os.Geteuid()) {
return fmt.Errorf("config file '%s' is not owned by current user (file UID: %d, process UID: %d)",
path, stat.Uid, os.Geteuid())
}
}
}
// 1. Read and parse file data
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("failed to open config file '%s': %w", path, err)
}
defer file.Close()
// Use LimitedReader for additional safety
var reader io.Reader = file
if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 {
reader = io.LimitReader(file, c.securityOpts.MaxFileSize)
}
fileData, err := io.ReadAll(reader)
if err != nil {
return fmt.Errorf("failed to read config file '%s': %w", path, err) return fmt.Errorf("failed to read config file '%s': %w", path, err)
} }
// Determine format
format := c.fileFormat
if format == "" || format == "auto" {
// Try extension first
format = detectFileFormat(path)
if format == "" {
// Fall back to content detection
format = detectFormatFromContent(fileData)
if format == "" {
// Last resort: use tagName as hint
format = c.tagName
}
}
}
// Parse based on detected/specified format
fileConfig := make(map[string]any) fileConfig := make(map[string]any)
switch format {
case "toml":
if err := toml.Unmarshal(fileData, &fileConfig); err != nil { if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err) return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err)
} }
case "json":
decoder := json.NewDecoder(bytes.NewReader(fileData))
decoder.UseNumber() // Preserve number precision
if err := decoder.Decode(&fileConfig); err != nil {
return fmt.Errorf("failed to parse JSON config file '%s': %w", path, err)
}
case "yaml":
if err := yaml.Unmarshal(fileData, &fileConfig); err != nil {
return fmt.Errorf("failed to parse YAML config file '%s': %w", path, err)
}
default:
return fmt.Errorf("unable to determine config format for file '%s'", path)
}
// Flatten and apply file data // 2. Prepare New State (Read-Lock Only)
flattenedFileConfig := flattenMap(fileConfig, "") newFileData := make(map[string]any)
// Briefly acquire a read-lock to safely get the list of registered paths.
c.mutex.RLock()
registeredPaths := make(map[string]bool, len(c.items))
for p := range c.items {
registeredPaths[p] = true
}
c.mutex.RUnlock()
// Define a recursive function to populate newFileData. This runs without any lock.
var apply func(prefix string, data map[string]any)
apply = func(prefix string, data map[string]any) {
for key, value := range data {
fullPath := key
if prefix != "" {
fullPath = prefix + "." + key
}
if registeredPaths[fullPath] {
newFileData[fullPath] = value
} else if subMap, isMap := value.(map[string]any); isMap {
apply(fullPath, subMap)
}
}
}
apply("", fileConfig)
// 3. Atomically Update Config (Write-Lock)
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
// Store in cache c.configFilePath = path
c.fileData = flattenedFileConfig c.fileData = newFileData
// Apply to registered paths // Apply the new state to the main config items.
for path, value := range flattenedFileConfig { for path, item := range c.items {
if item, exists := c.items[path]; exists { if value, exists := newFileData[path]; exists {
if item.values == nil { if item.values == nil {
item.values = make(map[Source]any) item.values = make(map[Source]any)
} }
item.values[SourceFile] = value item.values[SourceFile] = value
item.currentValue = c.computeValue(path, item) } else {
// Key was not in the new file, so remove its old file-sourced value.
delete(item.values, SourceFile)
}
// Recompute the current value based on new source precedence.
item.currentValue = c.computeValue(item)
c.items[path] = item c.items[path] = item
} }
// Ignore unregistered paths from file
}
c.invalidateCache()
return nil return nil
} }
// loadEnv loads configuration from environment variables // loadEnv loads configuration from environment variables
func (c *Config) loadEnv(opts LoadOptions) error { func (c *Config) loadEnv(opts LoadOptions) error {
// Default transform function
transform := opts.EnvTransform transform := opts.EnvTransform
if transform == nil { if transform == nil {
transform = func(path string) string { transform = defaultEnvTransform(opts.EnvPrefix)
// Convert dots to underscores and uppercase
env := strings.ReplaceAll(path, ".", "_")
env = strings.ToUpper(env)
if opts.EnvPrefix != "" {
env = opts.EnvPrefix + env
}
return env
}
} }
c.mutex.Lock() // -- 1. Prepare data (Read-Lock to get paths)
defer c.mutex.Unlock() c.mutex.RLock()
paths := make([]string, 0, len(c.items))
for p := range c.items {
paths = append(paths, p)
}
c.mutex.RUnlock()
// Clear previous env data // -- 2. Process env vars (No Lock)
c.envData = make(map[string]any) foundEnvVars := make(map[string]string)
for _, path := range paths {
// Check each registered path for corresponding env var
for path, item := range c.items {
// Skip if whitelisted and not in whitelist
if opts.EnvWhitelist != nil && !opts.EnvWhitelist[path] { if opts.EnvWhitelist != nil && !opts.EnvWhitelist[path] {
continue continue
} }
envVar := transform(path) envVar := transform(path)
if value, exists := os.LookupEnv(envVar); exists { if value, exists := os.LookupEnv(envVar); exists {
// Parse the string value if len(value) > MaxValueSize {
parsedValue := parseValue(value) return ErrValueSize
}
foundEnvVars[path] = value
}
}
// If no relevant env vars were found, we are done.
if len(foundEnvVars) == 0 {
return nil
}
// -- 3. Atomically update config (Write-Lock)
c.mutex.Lock()
defer c.mutex.Unlock()
c.envData = make(map[string]any, len(foundEnvVars))
for path, value := range foundEnvVars {
// Store raw string value - mapstructure will handle conversion later.
if item, exists := c.items[path]; exists {
if item.values == nil { if item.values == nil {
item.values = make(map[Source]any) item.values = make(map[Source]any)
} }
item.values[SourceEnv] = parsedValue item.values[SourceEnv] = value // Store as string
item.currentValue = c.computeValue(path, item) item.currentValue = c.computeValue(item)
c.items[path] = item c.items[path] = item
c.envData[path] = parsedValue c.envData[path] = value
} }
} }
c.invalidateCache()
return nil return nil
} }
// loadCLI loads configuration from command-line arguments // loadCLI loads configuration from command-line arguments
func (c *Config) loadCLI(args []string) error { func (c *Config) loadCLI(args []string) error {
// -- 1. Prepare data (No Lock)
parsedCLI, err := parseArgs(args) parsedCLI, err := parseArgs(args)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrCLIParse, err) return fmt.Errorf("%w: %w", ErrCLIParse, err)
} }
// Flatten CLI data
flattenedCLI := flattenMap(parsedCLI, "") flattenedCLI := flattenMap(parsedCLI, "")
if len(flattenedCLI) == 0 {
return nil // No CLI args to process.
}
// 2. Atomically update config (Write-Lock)
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
// Store in cache
c.cliData = flattenedCLI c.cliData = flattenedCLI
// Apply to registered paths
for path, value := range flattenedCLI { for path, value := range flattenedCLI {
if item, exists := c.items[path]; exists { if item, exists := c.items[path]; exists {
if item.values == nil { if item.values == nil {
item.values = make(map[Source]any) item.values = make(map[Source]any)
} }
item.values[SourceCLI] = value item.values[SourceCLI] = value
item.currentValue = c.computeValue(path, item) item.currentValue = c.computeValue(item)
c.items[path] = item c.items[path] = item
} }
// Ignore unregistered paths from CLI
} }
c.invalidateCache()
return nil return nil
} }
@ -260,20 +455,13 @@ func defaultEnvTransform(prefix string) EnvTransformFunc {
} }
// parseValue attempts to parse a string into appropriate types // parseValue attempts to parse a string into appropriate types
// Only basic parse, complex parsing is deferred to mapstructure's decode hooks
func parseValue(s string) any { func parseValue(s string) any {
// Try boolean if s == "true" {
if v, err := strconv.ParseBool(s); err == nil { return true
return v
} }
if s == "false" {
// Try int64 return false
if v, err := strconv.ParseInt(s, 10, 64); err == nil {
return v
}
// Try float64
if v, err := strconv.ParseFloat(s, 64); err == nil {
return v
} }
// Remove quotes if present // Remove quotes if present
@ -281,7 +469,7 @@ func parseValue(s string) any {
return s[1 : len(s)-1] return s[1 : len(s)-1]
} }
// Return as string // Return as string - mapstructure will convert as needed
return s return s
} }
@ -370,14 +558,13 @@ func (c *Config) SaveSource(path string, source Source) error {
c.mutex.RUnlock() c.mutex.RUnlock()
// Use the same atomic save logic // Marshal using BurntSushi/toml
var buf bytes.Buffer var buf bytes.Buffer
encoder := toml.NewEncoder(&buf) encoder := toml.NewEncoder(&buf)
if err := encoder.Encode(nestedData); err != nil { if err := encoder.Encode(nestedData); err != nil {
return fmt.Errorf("failed to marshal config data to TOML: %w", err) return fmt.Errorf("failed to marshal %s source data to TOML: %w", source, err)
} }
// ... (rest of atomic save logic same as Save method)
return atomicWriteFile(path, buf.Bytes()) return atomicWriteFile(path, buf.Bytes())
} }
@ -433,7 +620,6 @@ func parseArgs(args []string) (map[string]any, error) {
continue continue
} }
// Remove the leading "--"
argContent := strings.TrimPrefix(arg, "--") argContent := strings.TrimPrefix(arg, "--")
if argContent == "" { if argContent == "" {
// Skip "--" argument if used as a separator // Skip "--" argument if used as a separator
@ -453,20 +639,22 @@ func parseArgs(args []string) (map[string]any, error) {
} else { } else {
// Handle "--key value" or "--booleanflag" // Handle "--key value" or "--booleanflag"
keyPath = argContent keyPath = argContent
// Check if it's potentially a boolean flag // Check if it's a boolean flag (next arg is another flag or end of args)
isBoolFlag := i+1 >= len(args) || strings.HasPrefix(args[i+1], "--") if i+1 >= len(args) || strings.HasPrefix(args[i+1], "--") {
if isBoolFlag {
// Assume boolean flag is true if no value follows
valueStr = "true" valueStr = "true"
i++ // Consume only the flag argument i++ // Consume only the flag argument
} else { } else {
// Potential key-value pair with space separation // It's a key-value pair with a space
valueStr = args[i+1] valueStr = args[i+1]
i += 2 // Consume flag and value arguments i += 2 // Consume both flag and value arguments
} }
} }
if keyPath == "" {
// Skip invalid flags like --=value
continue
}
// Validate keyPath segments // Validate keyPath segments
segments := strings.Split(keyPath, ".") segments := strings.Split(keyPath, ".")
for _, segment := range segments { for _, segment := range segments {
@ -475,10 +663,50 @@ func parseArgs(args []string) (map[string]any, error) {
} }
} }
// Parse the value // Always store as a string. Let Scan handle final type conversion.
value := parseValue(valueStr) setNestedValue(result, keyPath, valueStr)
setNestedValue(result, keyPath, value)
} }
return result, nil return result, nil
} }
// detectFileFormat determines format from file extension
func detectFileFormat(path string) string {
ext := strings.ToLower(filepath.Ext(path))
switch ext {
case ".toml", ".tml":
return "toml"
case ".json":
return "json"
case ".yaml", ".yml":
return "yaml"
case ".conf", ".config":
// Try to detect from content
return ""
default:
return ""
}
}
// detectFormatFromContent attempts to detect format by parsing
func detectFormatFromContent(data []byte) string {
// Try JSON first (strict format)
var jsonTest any
if err := json.Unmarshal(data, &jsonTest); err == nil {
return "json"
}
// Try YAML (superset of JSON, so check after JSON)
var yamlTest any
if err := yaml.Unmarshal(data, &yamlTest); err == nil {
return "yaml"
}
// Try TOML last
var tomlTest any
if err := toml.Unmarshal(data, &tomlTest); err == nil {
return "toml"
}
return ""
}

434
loader_test.go Normal file
View File

@ -0,0 +1,434 @@
// FILE: lixenwraith/config/loader_test.go
package config
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestFileLoading tests TOML file loading
func TestFileLoading(t *testing.T) {
tmpDir := t.TempDir()
t.Run("ValidTOMLFile", func(t *testing.T) {
configFile := filepath.Join(tmpDir, "valid.toml")
content := `
# Server configuration
[server]
host = "example.com"
port = 9000
enabled = true
[server.tls]
cert = "/path/to/cert.pem"
key = "/path/to/key.pem"
[database]
connections = [1, 2, 3]
tags = ["primary", "replica"]
`
os.WriteFile(configFile, []byte(content), 0644)
cfg := New()
// Register all paths
cfg.Register("server.host", "localhost")
cfg.Register("server.port", 8080)
cfg.Register("server.enabled", false)
cfg.Register("server.tls.cert", "")
cfg.Register("server.tls.key", "")
cfg.Register("database.connections", []int{})
cfg.Register("database.tags", []string{})
err := cfg.LoadFile(configFile)
require.NoError(t, err)
// Verify loaded values
host, _ := cfg.Get("server.host")
assert.Equal(t, "example.com", host)
port, _ := cfg.Get("server.port")
assert.Equal(t, int64(9000), port)
enabled, _ := cfg.Get("server.enabled")
assert.Equal(t, true, enabled)
cert, _ := cfg.Get("server.tls.cert")
assert.Equal(t, "/path/to/cert.pem", cert)
// Arrays are loaded as []any
connections, _ := cfg.Get("database.connections")
assert.Equal(t, []any{int64(1), int64(2), int64(3)}, connections)
})
t.Run("InvalidTOMLFile", func(t *testing.T) {
configFile := filepath.Join(tmpDir, "invalid.toml")
os.WriteFile(configFile, []byte(`invalid = toml content`), 0644)
cfg := New()
err := cfg.LoadFile(configFile)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse TOML")
})
t.Run("NonExistentFile", func(t *testing.T) {
cfg := New()
err := cfg.LoadFile("/non/existent/file.toml")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrConfigNotFound)
})
t.Run("UnregisteredPathsIgnored", func(t *testing.T) {
configFile := filepath.Join(tmpDir, "extra.toml")
os.WriteFile(configFile, []byte(`
registered = "value"
unregistered = "ignored"
`), 0644)
cfg := New()
cfg.Register("registered", "")
err := cfg.LoadFile(configFile)
require.NoError(t, err)
val, exists := cfg.Get("registered")
assert.True(t, exists)
assert.Equal(t, "value", val)
_, exists = cfg.Get("unregistered")
assert.False(t, exists)
})
}
// TestEnvironmentLoading tests environment variable loading
func TestEnvironmentLoading(t *testing.T) {
// Save and restore environment
originalEnv := os.Environ()
defer func() {
os.Clearenv()
for _, e := range originalEnv {
parts := splitEnvVar(e)
if len(parts) == 2 {
os.Setenv(parts[0], parts[1])
}
}
}()
t.Run("DefaultEnvTransform", func(t *testing.T) {
cfg := New()
cfg.Register("server.host", "localhost")
cfg.Register("server.port", 8080)
cfg.Register("enable_debug", false)
os.Setenv("APP_SERVER_HOST", "envhost")
os.Setenv("APP_SERVER_PORT", "9090")
os.Setenv("APP_ENABLE_DEBUG", "true")
err := cfg.LoadEnv("APP_")
require.NoError(t, err)
host, _ := cfg.Get("server.host")
assert.Equal(t, "envhost", host)
port, _ := cfg.Get("server.port")
assert.Equal(t, "9090", port) // String from env
debug, _ := cfg.Get("enable_debug")
assert.Equal(t, "true", debug) // String from env
})
t.Run("CustomEnvTransform", func(t *testing.T) {
cfg := New()
cfg.Register("db.host", "localhost")
os.Setenv("DATABASE_HOSTNAME", "customhost")
opts := LoadOptions{
Sources: []Source{SourceEnv, SourceDefault},
EnvTransform: func(path string) string {
if path == "db.host" {
return "DATABASE_HOSTNAME"
}
return path
},
}
err := cfg.LoadWithOptions("", nil, opts)
require.NoError(t, err)
host, _ := cfg.Get("db.host")
assert.Equal(t, "customhost", host)
})
t.Run("EnvWhitelist", func(t *testing.T) {
cfg := New()
cfg.Register("allowed.path", "default1")
cfg.Register("blocked.path", "default2")
os.Setenv("ALLOWED_PATH", "env1")
os.Setenv("BLOCKED_PATH", "env2")
opts := LoadOptions{
Sources: []Source{SourceEnv, SourceDefault},
EnvWhitelist: map[string]bool{"allowed.path": true},
}
err := cfg.LoadWithOptions("", nil, opts)
require.NoError(t, err)
allowed, _ := cfg.Get("allowed.path")
assert.Equal(t, "env1", allowed)
blocked, _ := cfg.Get("blocked.path")
assert.Equal(t, "default2", blocked) // Should not load from env
})
t.Run("DiscoverEnv", func(t *testing.T) {
cfg := New()
cfg.Register("test.one", "")
cfg.Register("test.two", "")
cfg.Register("other.value", "")
os.Setenv("PREFIX_TEST_ONE", "value1")
os.Setenv("PREFIX_TEST_TWO", "value2")
os.Setenv("PREFIX_OTHER_VALUE", "value3")
os.Setenv("UNRELATED_VAR", "ignored")
discovered := cfg.DiscoverEnv("PREFIX_")
assert.Len(t, discovered, 3)
assert.Equal(t, "PREFIX_TEST_ONE", discovered["test.one"])
assert.Equal(t, "PREFIX_TEST_TWO", discovered["test.two"])
assert.Equal(t, "PREFIX_OTHER_VALUE", discovered["other.value"])
})
}
// TestCLIParsing tests command-line argument parsing
func TestCLIParsing(t *testing.T) {
tests := []struct {
name string
args []string
expected map[string]any
}{
{
name: "KeyValueWithEquals",
args: []string{"--server.host=example.com", "--server.port=9000"},
expected: map[string]any{
"server.host": "example.com",
"server.port": "9000",
},
},
{
name: "KeyValueWithSpace",
args: []string{"--server.host", "example.com", "--server.port", "9000"},
expected: map[string]any{
"server.host": "example.com",
"server.port": "9000",
},
},
{
name: "BooleanFlags",
args: []string{"--enable.debug", "--disable.cache", "false"},
expected: map[string]any{
"enable.debug": "true",
"disable.cache": "false",
},
},
{
name: "MixedFormats",
args: []string{
"--server.host=localhost",
"--server.port", "8080",
"--enable.tls",
"--database.pool.size=10",
},
expected: map[string]any{
"server.host": "localhost",
"server.port": "8080",
"enable.tls": "true",
"database.pool.size": "10",
},
},
{
name: "EmptyAndInvalidArgs",
args: []string{"", "--", "---", "--=value"},
expected: map[string]any{
"": "value",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := New()
// Register expected paths
for path := range tt.expected {
if path != "" { // Skip empty path
cfg.Register(path, "")
}
}
err := cfg.LoadCLI(tt.args)
require.NoError(t, err)
// Verify values
for path, expected := range tt.expected {
if path != "" {
val, exists := cfg.Get(path)
assert.True(t, exists, "Path %s should exist", path)
assert.Equal(t, expected, val)
}
}
})
}
t.Run("InvalidKeySegment", func(t *testing.T) {
result, err := parseArgs([]string{"--invalid!key=value"})
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid command-line key segment")
assert.Nil(t, result)
})
}
// TestLoadWithOptions tests complete loading with multiple sources
func TestLoadWithOptions(t *testing.T) {
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.toml")
os.WriteFile(configFile, []byte(`
[server]
host = "filehost"
port = 8080
`), 0644)
os.Setenv("TEST_SERVER_HOST", "envhost")
os.Setenv("TEST_SERVER_PORT", "9090")
defer func() {
os.Unsetenv("TEST_SERVER_HOST")
os.Unsetenv("TEST_SERVER_PORT")
}()
cfg := New()
cfg.Register("server.host", "defaulthost")
cfg.Register("server.port", 3000)
args := []string{"--server.port=7070"}
opts := LoadOptions{
Sources: []Source{SourceCLI, SourceEnv, SourceFile, SourceDefault},
EnvPrefix: "TEST_",
}
err := cfg.LoadWithOptions(configFile, args, opts)
require.NoError(t, err)
// CLI should win
port, _ := cfg.Get("server.port")
assert.Equal(t, "7070", port)
// ENV should win over file
host, _ := cfg.Get("server.host")
assert.Equal(t, "envhost", host)
// Test source inspection
sources := cfg.GetSources("server.port")
assert.Equal(t, "7070", sources[SourceCLI])
assert.Equal(t, "9090", sources[SourceEnv])
assert.Equal(t, int64(8080), sources[SourceFile])
}
// TestAtomicSave tests atomic file saving
func TestAtomicSave(t *testing.T) {
tmpDir := t.TempDir()
cfg := New()
cfg.Register("server.host", "localhost")
cfg.Register("server.port", 8080)
cfg.Register("database.url", "postgres://localhost/db")
// Set some values
cfg.Set("server.host", "savehost")
cfg.Set("server.port", 9999)
t.Run("SaveCurrentState", func(t *testing.T) {
savePath := filepath.Join(tmpDir, "saved.toml")
err := cfg.Save(savePath)
require.NoError(t, err)
// Verify file exists and is readable
content, err := os.ReadFile(savePath)
require.NoError(t, err)
assert.Contains(t, string(content), "savehost")
assert.Contains(t, string(content), "9999")
// Load into new config to verify
cfg2 := New()
cfg2.Register("server.host", "")
cfg2.Register("server.port", 0)
err = cfg2.LoadFile(savePath)
require.NoError(t, err)
host, _ := cfg2.Get("server.host")
assert.Equal(t, "savehost", host)
})
t.Run("SaveSpecificSource", func(t *testing.T) {
cfg.SetSource(SourceEnv, "server.host", "envhost")
cfg.SetSource(SourceEnv, "server.port", "7777")
cfg.SetSource(SourceFile, "server.port", "6666")
savePath := filepath.Join(tmpDir, "env-only.toml")
err := cfg.SaveSource(savePath, SourceEnv)
require.NoError(t, err)
content, err := os.ReadFile(savePath)
require.NoError(t, err)
assert.Contains(t, string(content), "envhost")
assert.Contains(t, string(content), "7777")
assert.NotContains(t, string(content), "6666")
})
t.Run("SaveToNonExistentDirectory", func(t *testing.T) {
savePath := filepath.Join(tmpDir, "new", "dir", "config.toml")
err := cfg.Save(savePath)
require.NoError(t, err)
// Verify file was created
_, err = os.Stat(savePath)
assert.NoError(t, err)
})
}
// TestExportEnv tests environment variable export
func TestExportEnv(t *testing.T) {
cfg := New()
cfg.Register("server.host", "defaulthost")
cfg.Register("server.port", 8080)
cfg.Register("feature.enabled", false)
// Only export non-default values
cfg.Set("server.host", "exporthost")
cfg.Set("feature.enabled", true)
exports := cfg.ExportEnv("APP_")
assert.Len(t, exports, 2)
assert.Equal(t, "exporthost", exports["APP_SERVER_HOST"])
assert.Equal(t, "true", exports["APP_FEATURE_ENABLED"])
assert.NotContains(t, exports, "APP_SERVER_PORT") // Still default
}
// splitEnvVar splits environment variable into key and value
func splitEnvVar(env string) []string {
parts := make([]string, 2)
for i := 0; i < len(env); i++ {
if env[i] == '=' {
parts[0] = env[:i]
parts[1] = env[i+1:]
return parts
}
}
return []string{env}
}

View File

@ -1,4 +1,4 @@
// File: lixenwraith/config/register.go // FILE: lixenwraith/config/register.go
package config package config
import ( import (
@ -6,8 +6,6 @@ import (
"os" "os"
"reflect" "reflect"
"strings" "strings"
"github.com/mitchellh/mapstructure"
) )
// Register makes a configuration path known to the Config instance. // Register makes a configuration path known to the Config instance.
@ -48,7 +46,7 @@ func (c *Config) RegisterWithEnv(path string, defaultValue any, envVar string) e
// Check if the environment variable exists and load it // Check if the environment variable exists and load it
if value, exists := os.LookupEnv(envVar); exists { if value, exists := os.LookupEnv(envVar); exists {
parsed := parseValue(value) parsed := parseValue(value)
return c.SetSource(path, SourceEnv, parsed) return c.SetSource(SourceEnv, path, parsed)
} }
return nil return nil
@ -102,24 +100,37 @@ func (c *Config) Unregister(path string) error {
// It uses struct tags (`toml:"..."`) to determine the configuration paths. // It uses struct tags (`toml:"..."`) to determine the configuration paths.
// The prefix is prepended to all paths (e.g., "log."). An empty prefix is allowed. // The prefix is prepended to all paths (e.g., "log."). An empty prefix is allowed.
func (c *Config) RegisterStruct(prefix string, structWithDefaults any) error { func (c *Config) RegisterStruct(prefix string, structWithDefaults any) error {
return c.RegisterStructWithTags(prefix, structWithDefaults, "toml")
}
// RegisterStructWithTags is like RegisterStruct but allows custom tag names
func (c *Config) RegisterStructWithTags(prefix string, structWithDefaults any, tagName string) error {
v := reflect.ValueOf(structWithDefaults) v := reflect.ValueOf(structWithDefaults)
// Handle pointer or direct struct value // Handle pointer or direct struct value
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
if v.IsNil() { if v.IsNil() {
return fmt.Errorf("RegisterStruct requires a non-nil struct pointer or value") return fmt.Errorf("RegisterStructWithTags requires a non-nil struct pointer or value")
} }
v = v.Elem() v = v.Elem()
} }
if v.Kind() != reflect.Struct { if v.Kind() != reflect.Struct {
return fmt.Errorf("RegisterStruct requires a struct or struct pointer, got %T", structWithDefaults) return fmt.Errorf("RegisterStructWithTags requires a struct or struct pointer, got %T", structWithDefaults)
}
// Validate tag name
switch tagName {
case "toml", "json", "yaml":
// Supported tags
default:
return fmt.Errorf("unsupported tag name %q, must be one of: toml, json, yaml", tagName)
} }
var errors []string var errors []string
// Use a helper function for recursive registration // Use helper function for recursive registration with specified tag
c.registerFields(v, prefix, "", &errors) c.registerFields(v, prefix, "", &errors, tagName)
if len(errors) > 0 { if len(errors) > 0 {
return fmt.Errorf("failed to register %d field(s): %s", len(errors), strings.Join(errors, "; ")) return fmt.Errorf("failed to register %d field(s): %s", len(errors), strings.Join(errors, "; "))
@ -128,22 +139,8 @@ func (c *Config) RegisterStruct(prefix string, structWithDefaults any) error {
return nil return nil
} }
// RegisterStructWithTags is like RegisterStruct but allows custom tag names
func (c *Config) RegisterStructWithTags(prefix string, structWithDefaults any, tagName string) error {
// Save current tag preference
oldTag := "toml"
// Temporarily use custom tag
// Note: This would require modifying registerFields to accept tagName parameter
// For now, we'll keep using "toml" tag
_ = oldTag
_ = tagName
return c.RegisterStruct(prefix, structWithDefaults)
}
// registerFields is a helper function that handles the recursive field registration. // registerFields is a helper function that handles the recursive field registration.
func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, errors *[]string) { func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, errors *[]string, tagName string) {
t := v.Type() t := v.Type()
for i := 0; i < v.NumField(); i++ { for i := 0; i < v.NumField(); i++ {
@ -154,16 +151,13 @@ func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, e
continue continue
} }
// Get tag value or use field name // Get tag value based on tagName parameter
tag := field.Tag.Get("toml") tag := field.Tag.Get(tagName)
if tag == "-" { if tag == "-" {
continue // Skip this field continue
} }
// Check for additional tags // Fall back to field name if no tag
envTag := field.Tag.Get("env") // Explicit env var name
required := field.Tag.Get("required") == "true"
key := field.Name key := field.Name
if tag != "" { if tag != "" {
parts := strings.Split(tag, ",") parts := strings.Split(tag, ",")
@ -172,6 +166,10 @@ func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, e
} }
} }
// Check for additional tags
envTag := field.Tag.Get("env") // Explicit env var name
required := field.Tag.Get("required") == "true"
// Build full path // Build full path
currentPath := key currentPath := key
if pathPrefix != "" { if pathPrefix != "" {
@ -181,27 +179,38 @@ func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, e
currentPath = pathPrefix + key currentPath = pathPrefix + key
} }
// TODO: use mapstructure instead of logic with reflection
// Handle nested structs recursively // Handle nested structs recursively
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
isStruct := fieldValue.Kind() == reflect.Struct isStruct := fieldValue.Kind() == reflect.Struct
isPtrToStruct := fieldValue.Kind() == reflect.Ptr && fieldType.Elem().Kind() == reflect.Struct isPtrToStruct := fieldValue.Kind() == reflect.Ptr && fieldType.Elem().Kind() == reflect.Struct
if isStruct || isPtrToStruct { if isStruct || isPtrToStruct {
// Dereference pointer if necessary // Check if the field's TYPE is one that should be treated as a single value,
// even though it's a struct. These types have custom decode hooks.
fieldType := fieldValue.Type()
isAtomicStruct := false
switch fieldType.String() {
case "time.Time", "*net.IPNet", "*url.URL", "net.IP": // Match the exact type names
isAtomicStruct = true
}
// Only recurse if it's a "normal" struct, not an atomic one.
if !isAtomicStruct {
nestedValue := fieldValue nestedValue := fieldValue
if isPtrToStruct { if isPtrToStruct {
if fieldValue.IsNil() { if fieldValue.IsNil() {
// Skip nil pointers continue // Skip nil pointers in the default struct
continue
} }
nestedValue = fieldValue.Elem() nestedValue = fieldValue.Elem()
} }
// For nested structs, append a dot and continue recursion
nestedPrefix := currentPath + "." nestedPrefix := currentPath + "."
c.registerFields(nestedValue, nestedPrefix, fieldPath+field.Name+".", errors) c.registerFields(nestedValue, nestedPrefix, fieldPath+field.Name+".", errors, tagName)
continue continue
} }
// If it is an atomic struct, we fall through and register it as a single value.
}
// Register non-struct fields // Register non-struct fields
defaultValue := fieldValue.Interface() defaultValue := fieldValue.Interface()
@ -221,7 +230,7 @@ func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, e
if envTag != "" && err == nil { if envTag != "" && err == nil {
if value, exists := os.LookupEnv(envTag); exists { if value, exists := os.LookupEnv(envTag); exists {
parsed := parseValue(value) parsed := parseValue(value)
if setErr := c.SetSource(currentPath, SourceEnv, parsed); setErr != nil { if setErr := c.SetSource(SourceEnv, currentPath, parsed); setErr != nil {
*errors = append(*errors, fmt.Sprintf("field %s%s env %s: %v", fieldPath, field.Name, envTag, setErr)) *errors = append(*errors, fmt.Sprintf("field %s%s env %s: %v", fieldPath, field.Name, envTag, setErr))
} }
} }
@ -230,13 +239,18 @@ func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, e
} }
// GetRegisteredPaths returns all registered configuration paths with the specified prefix. // GetRegisteredPaths returns all registered configuration paths with the specified prefix.
func (c *Config) GetRegisteredPaths(prefix string) map[string]bool { func (c *Config) GetRegisteredPaths(prefix ...string) map[string]bool {
p := ""
if len(prefix) > 0 {
p = prefix[0]
}
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
result := make(map[string]bool) result := make(map[string]bool)
for path := range c.items { for path := range c.items {
if strings.HasPrefix(path, prefix) { if strings.HasPrefix(path, p) {
result[path] = true result[path] = true
} }
} }
@ -245,13 +259,18 @@ func (c *Config) GetRegisteredPaths(prefix string) map[string]bool {
} }
// GetRegisteredPathsWithDefaults returns paths with their default values // GetRegisteredPathsWithDefaults returns paths with their default values
func (c *Config) GetRegisteredPathsWithDefaults(prefix string) map[string]any { func (c *Config) GetRegisteredPathsWithDefaults(prefix ...string) map[string]any {
p := ""
if len(prefix) > 0 {
p = prefix[0]
}
c.mutex.RLock() c.mutex.RLock()
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
result := make(map[string]any) result := make(map[string]any)
for path, item := range c.items { for path, item := range c.items {
if strings.HasPrefix(path, prefix) { if strings.HasPrefix(path, p) {
result[path] = item.defaultValue result[path] = item.defaultValue
} }
} }
@ -259,164 +278,12 @@ func (c *Config) GetRegisteredPathsWithDefaults(prefix string) map[string]any {
return result return result
} }
// Scan decodes the configuration data under a specific base path // Scan decodes configuration into target using the unified unmarshal function
// into the target struct or map. It operates on the current, merged configuration state. func (c *Config) Scan(target any, basePath ...string) error {
// The target must be a non-nil pointer to a struct or map. return c.unmarshal("", target, basePath...)
// It uses the "toml" struct tag for mapping fields.
func (c *Config) Scan(basePath string, target any) error {
// Validate target
rv := reflect.ValueOf(target)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return fmt.Errorf("target of Scan must be a non-nil pointer, got %T", target)
} }
c.mutex.RLock() // Read lock is sufficient // ScanSource decodes configuration from specific source using unified unmarshal
func (c *Config) ScanSource(source Source, target any, basePath ...string) error {
// Build the full nested map from the current state of registered items return c.unmarshal(source, target, basePath...)
fullNestedMap := make(map[string]any)
for path, item := range c.items {
setNestedValue(fullNestedMap, path, item.currentValue)
}
c.mutex.RUnlock() // Unlock before decoding
var sectionData any = fullNestedMap
// Navigate to the specific section if basePath is provided
if basePath != "" {
// Allow trailing dot for convenience
basePath = strings.TrimSuffix(basePath, ".")
if basePath == "" { // Handle case where input was just "."
// Use the full map
} else {
segments := strings.Split(basePath, ".")
current := any(fullNestedMap)
found := true
for _, segment := range segments {
currentMap, ok := current.(map[string]any)
if !ok {
// Path segment does not lead to a map/table
found = false
break
}
value, exists := currentMap[segment]
if !exists {
// The requested path segment does not exist in the current config
found = false
break
}
current = value
}
if !found {
// If the path doesn't fully exist, decode an empty map into the target.
sectionData = make(map[string]any)
} else {
sectionData = current
}
}
}
// Ensure the final data we are decoding from is actually a map
sectionMap, ok := sectionData.(map[string]any)
if !ok {
// This can happen if the basePath points to a non-map value
return fmt.Errorf("configuration path %q does not refer to a scannable section (map), but to type %T", basePath, sectionData)
}
// Use mapstructure to decode the relevant section map into the target
decoderConfig := &mapstructure.DecoderConfig{
Result: target,
TagName: "toml", // Use the same tag name for consistency
WeaklyTypedInput: true, // Allow conversions
DecodeHook: mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
),
}
decoder, err := mapstructure.NewDecoder(decoderConfig)
if err != nil {
return fmt.Errorf("failed to create mapstructure decoder: %w", err)
}
err = decoder.Decode(sectionMap)
if err != nil {
return fmt.Errorf("failed to scan section %q into %T: %w", basePath, target, err)
}
return nil
}
// ScanSource scans configuration from a specific source
func (c *Config) ScanSource(basePath string, source Source, target any) error {
// Validate target
rv := reflect.ValueOf(target)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return fmt.Errorf("target of ScanSource must be a non-nil pointer, got %T", target)
}
c.mutex.RLock()
// Build nested map from specific source only
nestedMap := make(map[string]any)
for path, item := range c.items {
if val, exists := item.values[source]; exists {
setNestedValue(nestedMap, path, val)
}
}
c.mutex.RUnlock()
// Rest of the logic is similar to Scan
var sectionData any = nestedMap
if basePath != "" {
basePath = strings.TrimSuffix(basePath, ".")
if basePath != "" {
segments := strings.Split(basePath, ".")
current := any(nestedMap)
for _, segment := range segments {
currentMap, ok := current.(map[string]any)
if !ok {
sectionData = make(map[string]any)
break
}
value, exists := currentMap[segment]
if !exists {
sectionData = make(map[string]any)
break
}
current = value
}
sectionData = current
}
}
sectionMap, ok := sectionData.(map[string]any)
if !ok {
return fmt.Errorf("path %q does not refer to a map in source %s", basePath, source)
}
decoderConfig := &mapstructure.DecoderConfig{
Result: target,
TagName: "toml",
WeaklyTypedInput: true,
DecodeHook: mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
),
}
decoder, err := mapstructure.NewDecoder(decoderConfig)
if err != nil {
return fmt.Errorf("failed to create decoder: %w", err)
}
return decoder.Decode(sectionMap)
} }

View File

@ -1,219 +0,0 @@
// File: lixenwraith/config/source_test.go
package config_test
import (
"os"
"path/filepath"
"testing"
"github.com/lixenwraith/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMultiSourceConfiguration(t *testing.T) {
t.Run("Source Precedence", func(t *testing.T) {
cfg := config.New()
cfg.Register("test.value", "default")
// Set values in different sources
cfg.SetSource("test.value", config.SourceFile, "from-file")
cfg.SetSource("test.value", config.SourceEnv, "from-env")
cfg.SetSource("test.value", config.SourceCLI, "from-cli")
// Default precedence: CLI > Env > File > Default
val, _ := cfg.String("test.value")
assert.Equal(t, "from-cli", val)
// Change precedence
opts := config.LoadOptions{
Sources: []config.Source{
config.SourceEnv,
config.SourceCLI,
config.SourceFile,
config.SourceDefault,
},
}
cfg.SetLoadOptions(opts)
// Now env should win
val, _ = cfg.String("test.value")
assert.Equal(t, "from-env", val)
})
t.Run("Source Tracking", func(t *testing.T) {
cfg := config.New()
cfg.Register("server.port", 8080)
// Set from multiple sources
cfg.SetSource("server.port", config.SourceFile, 9090)
cfg.SetSource("server.port", config.SourceEnv, 7070)
// Get all sources
sources := cfg.GetSources("server.port")
// Should have 2 sources
assert.Len(t, sources, 2)
assert.Equal(t, 9090, sources[config.SourceFile])
assert.Equal(t, 7070, sources[config.SourceEnv])
})
t.Run("GetSource", func(t *testing.T) {
cfg := config.New()
cfg.Register("api.key", "default-key")
cfg.SetSource("api.key", config.SourceEnv, "env-key")
// Get from specific source
envVal, exists := cfg.GetSource("api.key", config.SourceEnv)
assert.True(t, exists)
assert.Equal(t, "env-key", envVal)
// Get from missing source
_, exists = cfg.GetSource("api.key", config.SourceFile)
assert.False(t, exists)
})
t.Run("Reset Sources", func(t *testing.T) {
cfg := config.New()
cfg.Register("test1", "default1")
cfg.Register("test2", "default2")
// Set values
cfg.SetSource("test1", config.SourceFile, "file1")
cfg.SetSource("test1", config.SourceEnv, "env1")
cfg.SetSource("test2", config.SourceCLI, "cli2")
// Reset specific source
cfg.ResetSource(config.SourceEnv)
// Env value should be gone
_, exists := cfg.GetSource("test1", config.SourceEnv)
assert.False(t, exists)
// Other sources remain
fileVal, _ := cfg.GetSource("test1", config.SourceFile)
assert.Equal(t, "file1", fileVal)
// Reset all
cfg.Reset()
// All values should be defaults
val1, _ := cfg.String("test1")
val2, _ := cfg.String("test2")
assert.Equal(t, "default1", val1)
assert.Equal(t, "default2", val2)
})
t.Run("LoadWithOptions Integration", func(t *testing.T) {
// Create temp config file
tmpdir := t.TempDir()
configFile := filepath.Join(tmpdir, "test.toml")
configContent := `
[server]
host = "file-host"
port = 8080
[feature]
enabled = true
`
require.NoError(t, os.WriteFile(configFile, []byte(configContent), 0644))
// Set environment
os.Setenv("TEST_SERVER_PORT", "9090")
os.Setenv("TEST_FEATURE_ENABLED", "false")
t.Cleanup(func() {
os.Unsetenv("TEST_SERVER_PORT")
os.Unsetenv("TEST_FEATURE_ENABLED")
})
cfg := config.New()
cfg.Register("server.host", "default-host")
cfg.Register("server.port", 7070)
cfg.Register("feature.enabled", false)
// Load with custom precedence (File highest)
opts := config.LoadOptions{
Sources: []config.Source{
config.SourceFile,
config.SourceEnv,
config.SourceCLI,
config.SourceDefault,
},
EnvPrefix: "TEST_",
}
err := cfg.LoadWithOptions(configFile, []string{"--server.host=cli-host"}, opts)
require.NoError(t, err)
// File should win for all values
host, _ := cfg.String("server.host")
assert.Equal(t, "file-host", host)
port, _ := cfg.Int64("server.port")
assert.Equal(t, int64(8080), port)
enabled, _ := cfg.Bool("feature.enabled")
assert.True(t, enabled)
})
t.Run("ScanSource", func(t *testing.T) {
type ServerConfig struct {
Host string `toml:"host"`
Port int `toml:"port"`
}
cfg := config.New()
cfg.Register("server.host", "default")
cfg.Register("server.port", 8080)
// Set different values in different sources
cfg.SetSource("server.host", config.SourceFile, "file-host")
cfg.SetSource("server.port", config.SourceFile, 8080)
cfg.SetSource("server.host", config.SourceEnv, "env-host")
cfg.SetSource("server.port", config.SourceEnv, 9090)
// Scan from specific source
var fileConfig ServerConfig
err := cfg.ScanSource("server", config.SourceFile, &fileConfig)
require.NoError(t, err)
assert.Equal(t, "file-host", fileConfig.Host)
assert.Equal(t, 8080, fileConfig.Port)
var envConfig ServerConfig
err = cfg.ScanSource("server", config.SourceEnv, &envConfig)
require.NoError(t, err)
assert.Equal(t, "env-host", envConfig.Host)
assert.Equal(t, 9090, envConfig.Port)
})
t.Run("SaveSource", func(t *testing.T) {
cfg := config.New()
cfg.Register("app.name", "myapp")
cfg.Register("app.version", "1.0.0")
cfg.Register("server.port", 8080)
// Set values in different sources
cfg.SetSource("app.name", config.SourceFile, "fileapp")
cfg.SetSource("app.version", config.SourceEnv, "2.0.0")
cfg.SetSource("server.port", config.SourceCLI, 9090)
// Save only env source
tmpfile := filepath.Join(t.TempDir(), "config-source.toml")
err := cfg.SaveSource(tmpfile, config.SourceEnv)
require.NoError(t, err)
// Load saved file and verify
newCfg := config.New()
newCfg.Register("app.version", "")
newCfg.LoadFile(tmpfile)
version, _ := newCfg.String("app.version")
assert.Equal(t, "2.0.0", version)
// Should not have other source values
name, _ := newCfg.String("app.name")
assert.Empty(t, name)
})
}

26
timing.go Normal file
View File

@ -0,0 +1,26 @@
// FILE: lixenwraith/config/timing.go
package config
import "time"
// Core timing constants for production use.
// These define the fundamental timing behavior of the config package.
const (
// File watching intervals (ordered by frequency)
SpinWaitInterval = 5 * time.Millisecond // CPU-friendly busy-wait quantum
MinPollInterval = 100 * time.Millisecond // Hard floor for file stat polling
ShutdownTimeout = 100 * time.Millisecond // Graceful watcher termination window
DefaultDebounce = 500 * time.Millisecond // File change coalescence period
DefaultPollInterval = time.Second // Standard file monitoring frequency
DefaultReloadTimeout = 5 * time.Second // Maximum duration for reload operations
)
// Derived timing relationships for internal use.
// These maintain consistent ratios between related timers.
const (
// shutdownPollCycles defines how many spin-wait cycles comprise a shutdown timeout
shutdownPollCycles = ShutdownTimeout / SpinWaitInterval // = 20 cycles
// debounceSettleMultiplier ensures sufficient time for debounce to complete
debounceSettleMultiplier = 3 // Wait 3x debounce period for value stabilization
)

162
type.go
View File

@ -1,162 +0,0 @@
// File: lixenwraith/config/type.go
package config
import (
"fmt"
"reflect"
"strconv"
)
// String retrieves a string configuration value using the path.
// Attempts conversion from common types if the stored value isn't already a string.
func (c *Config) String(path string) (string, error) {
val, found := c.Get(path)
if !found {
return "", fmt.Errorf("path not registered: %s", path)
}
if val == nil {
return "", nil // Treat nil as empty string for convenience
}
if strVal, ok := val.(string); ok {
return strVal, nil
}
// Attempt conversion for common types
switch v := val.(type) {
case fmt.Stringer:
return v.String(), nil
case []byte:
return string(v), nil
case int, int8, int16, int32, int64:
return strconv.FormatInt(reflect.ValueOf(val).Int(), 10), nil
case uint, uint8, uint16, uint32, uint64:
return strconv.FormatUint(reflect.ValueOf(val).Uint(), 10), nil
case float32, float64:
return strconv.FormatFloat(reflect.ValueOf(val).Float(), 'f', -1, 64), nil
case bool:
return strconv.FormatBool(v), nil
case error:
return v.Error(), nil
default:
return "", fmt.Errorf("cannot convert type %T to string for path %s", val, path)
}
}
// Int64 retrieves an int64 configuration value using the path.
// Attempts conversion from numeric types, parsable strings, and booleans.
func (c *Config) Int64(path string) (int64, error) {
val, found := c.Get(path)
if !found {
return 0, fmt.Errorf("path not registered: %s", path)
}
if val == nil {
return 0, fmt.Errorf("value for path %s is nil, cannot convert to int64", path)
}
// Use reflection for broader compatibility with numeric types
v := reflect.ValueOf(val)
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
u := v.Uint()
// Check for potential overflow converting uint64 to int64
maxInt64 := int64(^uint64(0) >> 1)
if u > uint64(maxInt64) {
return 0, fmt.Errorf("cannot convert unsigned integer %d (type %T) to int64 for path %s: overflow", u, val, path)
}
return int64(u), nil
case reflect.Float32, reflect.Float64:
// Truncate float to int
return int64(v.Float()), nil
case reflect.String:
s := v.String()
if i, err := strconv.ParseInt(s, 0, 64); err == nil { // Use base 0 for auto-detection (e.g., "0xFF")
return i, nil
} else {
if f, ferr := strconv.ParseFloat(s, 64); ferr == nil {
return int64(f), nil // Truncate
}
// Return the original integer parsing error if float also fails
return 0, fmt.Errorf("cannot convert string %q to int64 for path %s: %w", s, path, err)
}
case reflect.Bool:
if v.Bool() {
return 1, nil
}
return 0, nil
}
return 0, fmt.Errorf("cannot convert type %T to int64 for path %s", val, path)
}
// Bool retrieves a boolean configuration value using the path.
// Attempts conversion from numeric types (0=false, non-zero=true) and parsable strings.
func (c *Config) Bool(path string) (bool, error) {
val, found := c.Get(path)
if !found {
return false, fmt.Errorf("path not registered: %s", path)
}
if val == nil {
return false, fmt.Errorf("value for path %s is nil, cannot convert to bool", path)
}
v := reflect.ValueOf(val)
switch v.Kind() {
case reflect.Bool:
return v.Bool(), nil
case reflect.String:
s := v.String()
if b, err := strconv.ParseBool(s); err == nil {
return b, nil
} else {
return false, fmt.Errorf("cannot convert string %q to bool for path %s: %w", s, path, err)
}
// Numeric interpretation: 0 is false, non-zero is true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() != 0, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return v.Uint() != 0, nil
case reflect.Float32, reflect.Float64:
return v.Float() != 0, nil
}
return false, fmt.Errorf("cannot convert type %T to bool for path %s", val, path)
}
// Float64 retrieves a float64 configuration value using the path.
// Attempts conversion from numeric types, parsable strings, and booleans.
func (c *Config) Float64(path string) (float64, error) {
val, found := c.Get(path)
if !found {
return 0.0, fmt.Errorf("path not registered: %s", path)
}
if val == nil {
return 0.0, fmt.Errorf("value for path %s is nil, cannot convert to float64", path)
}
v := reflect.ValueOf(val)
switch v.Kind() {
case reflect.Float32, reflect.Float64:
return v.Float(), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return float64(v.Int()), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return float64(v.Uint()), nil
case reflect.String:
s := v.String()
if f, err := strconv.ParseFloat(s, 64); err == nil {
return f, nil
} else {
return 0.0, fmt.Errorf("cannot convert string %q to float64 for path %s: %w", s, path, err)
}
case reflect.Bool:
if v.Bool() {
return 1.0, nil
}
return 0.0, nil
}
return 0.0, fmt.Errorf("cannot convert type %T to float64 for path %s", val, path)
}

123
validator.go Normal file
View File

@ -0,0 +1,123 @@
// FILE: lixenwraith/config/validator.go
package config
import (
"fmt"
"net"
"regexp"
"strings"
)
// Common validators for configuration values
// Port validates TCP/UDP port range
func Port(p int64) error {
if p < 1 || p > 65535 {
return fmt.Errorf("must be 1-65535, got %d", p)
}
return nil
}
// Positive validates positive numbers
func Positive[T int64 | float64](n T) error {
if n <= 0 {
return fmt.Errorf("must be positive, got %v", n)
}
return nil
}
// NonNegative validates non-negative numbers
func NonNegative[T int64 | float64](n T) error {
if n < 0 {
return fmt.Errorf("must be non-negative, got %v", n)
}
return nil
}
// IPAddress validates IP address format
func IPAddress(s string) error {
if s == "" || s == "0.0.0.0" || s == "::" {
return nil // Allow common defaults
}
if net.ParseIP(s) == nil {
return fmt.Errorf("invalid IP address: %s", s)
}
return nil
}
// IPv4Address validates IPv4 address format
func IPv4Address(s string) error {
if s == "" || s == "0.0.0.0" {
return nil // Allow common defaults
}
ip := net.ParseIP(s)
if ip == nil || ip.To4() == nil {
return fmt.Errorf("invalid IPv4 address: %s", s)
}
return nil
}
// IPv6Address validates IPv6 address format
func IPv6Address(s string) error {
if s == "" || s == "::" {
return nil // Allow common defaults
}
ip := net.ParseIP(s)
if ip == nil {
return fmt.Errorf("invalid IPv6 address: %s", s)
}
// Valid net.ParseIP with nil ip.To4 indicates IPv6
if ip.To4() != nil {
return fmt.Errorf("invalid IPv6 address (is an IPv4 address): %s", s)
}
return nil
}
// URLPath validates URL path format
func URLPath(s string) error {
if s != "" && !strings.HasPrefix(s, "/") {
return fmt.Errorf("must start with /: %s", s)
}
return nil
}
// OneOf creates a validator for allowed values
func OneOf[T comparable](allowed ...T) func(T) error {
return func(val T) error {
for _, a := range allowed {
if val == a {
return nil
}
}
return fmt.Errorf("must be one of %v, got %v", allowed, val)
}
}
// Range creates a min/max validator
func Range[T int64 | float64](min, max T) func(T) error {
return func(val T) error {
if val < min || val > max {
return fmt.Errorf("must be %v-%v, got %v", min, max, val)
}
return nil
}
}
// Pattern creates a regex validator
func Pattern(pattern string) func(string) error {
re := regexp.MustCompile(pattern)
return func(s string) error {
if !re.MatchString(s) {
return fmt.Errorf("must match pattern %s", pattern)
}
return nil
}
}
// NonEmpty validates non-empty strings
func NonEmpty(s string) error {
if strings.TrimSpace(s) == "" {
return fmt.Errorf("must not be empty")
}
return nil
}

163
validator_test.go Normal file
View File

@ -0,0 +1,163 @@
// FILE: lixenwraith/config/validator_test.go
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestPortValidator tests the Port validator
func TestPortValidator(t *testing.T) {
tests := []struct {
name string
port int64
wantErr bool
}{
{"ValidLowPort", 1, false},
{"ValidCommonPort", 8080, false},
{"ValidHighPort", 65535, false},
{"InvalidZeroPort", 0, true},
{"InvalidNegativePort", -1, true},
{"InvalidTooHighPort", 65536, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := Port(tt.port)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestPositiveValidator tests the Positive validator
func TestPositiveValidator(t *testing.T) {
t.Run("Int64", func(t *testing.T) {
assert.NoError(t, Positive(int64(1)))
assert.Error(t, Positive(int64(0)))
assert.Error(t, Positive(int64(-1)))
})
t.Run("Float64", func(t *testing.T) {
assert.NoError(t, Positive(0.001))
assert.Error(t, Positive(0.0))
assert.Error(t, Positive(-0.001))
})
}
// TestNonNegativeValidator tests the NonNegative validator
func TestNonNegativeValidator(t *testing.T) {
t.Run("Int64", func(t *testing.T) {
assert.NoError(t, NonNegative(int64(1)))
assert.NoError(t, NonNegative(int64(0)))
assert.Error(t, NonNegative(int64(-1)))
})
t.Run("Float64", func(t *testing.T) {
assert.NoError(t, NonNegative(0.001))
assert.NoError(t, NonNegative(0.0))
assert.Error(t, NonNegative(-0.001))
})
}
// TestIPAddressValidators tests all IP-related validators
func TestIPAddressValidators(t *testing.T) {
t.Run("IPAddress", func(t *testing.T) {
assert.NoError(t, IPAddress("192.168.1.1"))
assert.NoError(t, IPAddress("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
assert.NoError(t, IPAddress(""))
assert.NoError(t, IPAddress("0.0.0.0"))
assert.NoError(t, IPAddress("::"))
assert.Error(t, IPAddress("not-an-ip"))
assert.Error(t, IPAddress("192.168.1.256"))
})
t.Run("IPv4Address", func(t *testing.T) {
assert.NoError(t, IPv4Address("192.168.1.1"))
assert.NoError(t, IPv4Address(""))
assert.NoError(t, IPv4Address("0.0.0.0"))
assert.Error(t, IPv4Address("::1")) // Is not IPv4
assert.Error(t, IPv4Address("not-an-ip"))
})
t.Run("IPv6Address", func(t *testing.T) {
assert.NoError(t, IPv6Address("2001:db8::1"))
assert.NoError(t, IPv6Address(""))
assert.NoError(t, IPv6Address("::"))
assert.Error(t, IPv6Address("127.0.0.1")) // Is not IPv6
assert.Error(t, IPv6Address("not-an-ip"))
})
}
// TestURLPathValidator tests the URLPath validator
func TestURLPathValidator(t *testing.T) {
assert.NoError(t, URLPath("/api/v1"))
assert.NoError(t, URLPath("/"))
assert.NoError(t, URLPath(""))
assert.Error(t, URLPath("api/v1"))
assert.Error(t, URLPath("no-slash"))
}
// TestOneOfValidator tests the OneOf validator
func TestOneOfValidator(t *testing.T) {
t.Run("String", func(t *testing.T) {
validator := OneOf("prod", "dev", "staging")
assert.NoError(t, validator("prod"))
assert.NoError(t, validator("dev"))
err := validator("test")
assert.Error(t, err)
assert.Contains(t, err.Error(), "must be one of")
})
t.Run("Int", func(t *testing.T) {
validator := OneOf(200, 404, 500)
assert.NoError(t, validator(404))
err := validator(302)
assert.Error(t, err)
assert.Contains(t, err.Error(), "must be one of")
})
}
// TestRangeValidator tests the Range validator
func TestRangeValidator(t *testing.T) {
t.Run("Int64", func(t *testing.T) {
validator := Range[int64](10, 100)
assert.NoError(t, validator(10))
assert.NoError(t, validator(50))
assert.NoError(t, validator(100))
assert.Error(t, validator(9))
assert.Error(t, validator(101))
})
t.Run("Float64", func(t *testing.T) {
validator := Range[float64](-1.5, 1.5)
assert.NoError(t, validator(-1.5))
assert.NoError(t, validator(0.0))
assert.NoError(t, validator(1.5))
assert.Error(t, validator(-1.51))
assert.Error(t, validator(1.51))
})
}
// TestPatternValidator tests the Pattern validator
func TestPatternValidator(t *testing.T) {
// Simple email regex
validator := Pattern(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
assert.NoError(t, validator("test@example.com"))
assert.NoError(t, validator("user.name+alias@domain.co.uk"))
assert.Error(t, validator("not-an-email"))
assert.Error(t, validator("test@example"))
}
// TestNonEmptyValidator tests the NonEmpty validator
func TestNonEmptyValidator(t *testing.T) {
assert.NoError(t, NonEmpty("hello"))
assert.NoError(t, NonEmpty(" a "))
assert.Error(t, NonEmpty(""))
assert.Error(t, NonEmpty(" "))
assert.Error(t, NonEmpty(" \t\n "))
}

434
watch.go Normal file
View File

@ -0,0 +1,434 @@
// FILE: lixenwraith/config/watch.go
package config
import (
"context"
"fmt"
"os"
"reflect"
"sync"
"sync/atomic"
"time"
)
const DefaultMaxWatchers = 100 // Prevent resource exhaustion
// WatchOptions configures file watching behavior
type WatchOptions struct {
// PollInterval for file stat checks (minimum 100ms)
PollInterval time.Duration
// Debounce duration to avoid rapid reloads
Debounce time.Duration
// MaxWatchers limits concurrent watch channels
MaxWatchers int
// ReloadTimeout for file reload operations
ReloadTimeout time.Duration
// VerifyPermissions checks file hasn't been replaced with different permissions
VerifyPermissions bool
}
// DefaultWatchOptions returns sensible defaults for file watching
func DefaultWatchOptions() WatchOptions {
return WatchOptions{
PollInterval: DefaultPollInterval,
Debounce: DefaultDebounce,
MaxWatchers: DefaultMaxWatchers,
ReloadTimeout: DefaultReloadTimeout,
VerifyPermissions: true,
}
}
// watcher manages file watching state
type watcher struct {
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
opts WatchOptions
filePath string
lastModTime time.Time
lastSize int64
lastMode os.FileMode
watching atomic.Bool
reloadInProgress atomic.Bool
watchers map[int64]chan string // subscriber channels
watcherID atomic.Int64
debounceTimer *time.Timer
}
// configWatcher extends Config with watching capabilities
type configWatcher struct {
*Config
watcher *watcher
}
// AutoUpdate enables automatic configuration reloading when the file changes
func (c *Config) AutoUpdate() {
c.AutoUpdateWithOptions(DefaultWatchOptions())
}
// AutoUpdateWithOptions enables automatic configuration reloading with custom options
func (c *Config) AutoUpdateWithOptions(opts WatchOptions) {
// Validate options
if opts.PollInterval < MinPollInterval {
opts.PollInterval = MinPollInterval
}
if opts.MaxWatchers <= 0 {
opts.MaxWatchers = 100
}
if opts.ReloadTimeout <= 0 {
opts.ReloadTimeout = DefaultReloadTimeout
}
c.mutex.Lock()
defer c.mutex.Unlock()
// Get path of current file to watch
filePath := c.getConfigFilePath()
if filePath == "" {
// No file configured, nothing to watch
return
}
// Stop existing watcher if path changed
if c.watcher != nil && c.watcher.filePath != filePath {
c.watcher.stop()
c.watcher = nil
}
// Initialize watcher if needed
if c.watcher == nil {
ctx, cancel := context.WithCancel(context.Background())
c.watcher = &watcher{
ctx: ctx,
cancel: cancel,
opts: opts,
filePath: filePath,
watchers: make(map[int64]chan string),
}
// Get initial file state
if info, err := os.Stat(filePath); err == nil {
c.watcher.lastModTime = info.ModTime()
c.watcher.lastSize = info.Size()
c.watcher.lastMode = info.Mode()
}
// Start watching
go c.watcher.watchLoop(c)
}
}
// StopAutoUpdate stops automatic configuration reloading
func (c *Config) StopAutoUpdate() {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.watcher != nil {
c.watcher.stop()
c.watcher = nil
}
}
// Watch returns a channel that receives paths of changed configuration values
func (c *Config) Watch() <-chan string {
return c.WatchWithOptions(DefaultWatchOptions())
}
// WatchFile stops any existing file watcher, loads a new configuration file,
// and starts a new watcher on that file path. Optionally accepts format hint.
func (c *Config) WatchFile(filePath string, formatHint ...string) error {
// Stop any currently running watcher
c.StopAutoUpdate()
// Set format hint if provided
if len(formatHint) > 0 {
if err := c.SetFileFormat(formatHint[0]); err != nil {
return fmt.Errorf("invalid format hint: %w", err)
}
}
// Load the new file
if err := c.LoadFile(filePath); err != nil {
return fmt.Errorf("failed to load new file for watching: %w", err)
}
// Get previous watcher options if available
c.mutex.RLock()
opts := DefaultWatchOptions()
if c.watcher != nil {
opts = c.watcher.opts
}
c.mutex.RUnlock()
// Start new watcher (AutoUpdateWithOptions will create a new watcher with the new file path)
c.AutoUpdateWithOptions(opts)
return nil
}
// WatchWithOptions returns a channel with custom watch options
// should not restart the watcher if it's already running with the same file
func (c *Config) WatchWithOptions(opts WatchOptions) <-chan string {
c.mutex.RLock()
watcher := c.watcher
filePath := c.configFilePath
c.mutex.RUnlock()
// If no file configured, return closed channel
if filePath == "" {
ch := make(chan string)
close(ch)
return ch
}
// If watcher exists and is watching the current file, just subscribe
if watcher != nil && watcher.filePath == filePath && watcher.watching.Load() {
return watcher.subscribe()
}
// First ensure auto-update is running
c.AutoUpdateWithOptions(opts)
c.mutex.RLock()
watcher = c.watcher
c.mutex.RUnlock()
if watcher == nil {
// No file to watch, return closed channel
ch := make(chan string)
close(ch)
return ch
}
return watcher.subscribe()
}
// IsWatching returns true if auto-update is enabled
func (c *Config) IsWatching() bool {
c.mutex.RLock()
defer c.mutex.RUnlock()
return c.watcher != nil && c.watcher.watching.Load()
}
// WatcherCount returns the number of active watch channels
func (c *Config) WatcherCount() int {
c.mutex.RLock()
defer c.mutex.RUnlock()
if c.watcher == nil {
return 0
}
c.watcher.mu.RLock()
defer c.watcher.mu.RUnlock()
return len(c.watcher.watchers)
}
// watchLoop is the main file watching loop
func (w *watcher) watchLoop(c *Config) {
if !w.watching.CompareAndSwap(false, true) {
return // Already watching
}
defer w.watching.Store(false)
ticker := time.NewTicker(w.opts.PollInterval)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case <-ticker.C:
w.checkAndReload(c)
}
}
}
// checkAndReload checks if file changed and triggers reload
func (w *watcher) checkAndReload(c *Config) {
info, err := os.Stat(w.filePath)
if err != nil {
if os.IsNotExist(err) {
// File was deleted, notify watchers
w.notifyWatchers("file_deleted")
}
return
}
// Check for changes
changed := false
// Compare modification time and size
if !info.ModTime().Equal(w.lastModTime) || info.Size() != w.lastSize {
changed = true
}
// SECURITY: Verify permissions haven't changed suspiciously
if w.opts.VerifyPermissions && w.lastMode != 0 {
if info.Mode() != w.lastMode {
// Permission change detected
if (info.Mode() & 0077) != (w.lastMode & 0077) {
// World/group permissions changed - potential security issue
w.notifyWatchers("permissions_changed")
// Don't reload on permission change for security
return
}
}
}
if changed {
// Update tracked state
w.lastModTime = info.ModTime()
w.lastSize = info.Size()
w.lastMode = info.Mode()
// Debounce rapid changes
w.mu.Lock()
if w.debounceTimer != nil {
w.debounceTimer.Stop()
}
w.debounceTimer = time.AfterFunc(w.opts.Debounce, func() {
w.performReload(c)
})
w.mu.Unlock()
}
}
// performReload reloads the configuration file
func (w *watcher) performReload(c *Config) {
// Prevent concurrent reloads
if !w.reloadInProgress.CompareAndSwap(false, true) {
return
}
defer w.reloadInProgress.Store(false)
// Create a timeout context for reload
ctx, cancel := context.WithTimeout(w.ctx, w.opts.ReloadTimeout)
defer cancel()
// Track what changed
oldValues := c.snapshot()
// Reload file in a goroutine with timeout
done := make(chan error, 1)
go func() {
done <- c.loadFile(w.filePath)
}()
select {
case err := <-done:
if err != nil {
// Reload failed, notify error
w.notifyWatchers(fmt.Sprintf("reload_error:%v", err))
return
}
// Compare and notify changes
newValues := c.snapshot()
for path, newVal := range newValues {
if oldVal, existed := oldValues[path]; !existed || !reflect.DeepEqual(oldVal, newVal) {
w.notifyWatchers(path)
}
}
// Check for deletions
for path := range oldValues {
if _, exists := newValues[path]; !exists {
w.notifyWatchers(path)
}
}
case <-ctx.Done():
// Reload timeout
w.notifyWatchers("reload_timeout")
}
}
// subscribe creates a new watcher channel
func (w *watcher) subscribe() <-chan string {
w.mu.Lock()
defer w.mu.Unlock()
// Check watcher limit
if len(w.watchers) >= w.opts.MaxWatchers {
// Return closed channel to prevent resource exhaustion
ch := make(chan string)
close(ch)
return ch
}
// Create buffered channel to prevent blocking
ch := make(chan string, 10)
id := w.watcherID.Add(1)
w.watchers[id] = ch
// Cleanup goroutine
go func() {
<-w.ctx.Done()
w.mu.Lock()
delete(w.watchers, id)
close(ch)
w.mu.Unlock()
}()
return ch
}
// notifyWatchers sends change notification to all subscribers
func (w *watcher) notifyWatchers(path string) {
w.mu.RLock()
defer w.mu.RUnlock()
for id, ch := range w.watchers {
select {
case ch <- path:
// Sent successfully
default:
// Channel full or closed, skip
// Could implement removal of dead watchers here
_ = id
}
}
}
// stop terminates the watcher
func (w *watcher) stop() {
if w.cancel != nil {
w.cancel()
}
// Stop debounce timer
w.mu.Lock()
if w.debounceTimer != nil {
w.debounceTimer.Stop()
w.debounceTimer = nil
}
w.mu.Unlock()
// Wait for watch loop to exit with timeout
deadline := time.Now().Add(ShutdownTimeout)
for w.watching.Load() && time.Now().Before(deadline) {
time.Sleep(SpinWaitInterval)
}
}
// getConfigFilePath returns the current config file path
func (c *Config) getConfigFilePath() string {
// Access the tracked config file path
return c.configFilePath
}
// snapshot creates a snapshot of current values
func (c *Config) snapshot() map[string]any {
c.mutex.RLock()
defer c.mutex.RUnlock()
snapshot := make(map[string]any, len(c.items))
for path, item := range c.items {
snapshot[path] = item.currentValue
}
return snapshot
}

520
watch_test.go Normal file
View File

@ -0,0 +1,520 @@
// FILE: lixenwraith/config/watch_test.go
package config
import (
"fmt"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test-specific timing constants derived from production values.
// These accelerate test execution while maintaining timing relationships.
const (
// testAcceleration reduces all intervals by this factor for faster tests
testAcceleration = 10
// Accelerated test timings
testPollInterval = DefaultPollInterval / testAcceleration // 100ms (from 1s)
testDebounce = DefaultDebounce / testAcceleration // 50ms (from 500ms)
testReloadTimeout = DefaultReloadTimeout / testAcceleration // 500ms (from 5s)
testShutdownTimeout = ShutdownTimeout // Keep original for safety
testSpinWaitInterval = SpinWaitInterval // Keep original for CPU efficiency
// Test assertion timeouts
testEventuallyTimeout = testReloadTimeout // Aligns with reload timing
testWatchTimeout = 2 * DefaultPollInterval // 2s for change propagation
// Derived test multipliers with clear purpose
testDebounceSettle = debounceSettleMultiplier * testDebounce // 150ms for debounce verification
testPollWindow = 3 * testPollInterval // 300ms change detection window
testStateStabilize = 4 * testDebounce // 200ms for state convergence
)
// TestAutoUpdate tests automatic configuration reloading
func TestAutoUpdate(t *testing.T) {
// Setup
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test.toml")
initialConfig := `
[server]
port = 8080
host = "localhost"
[features]
enabled = true
`
require.NoError(t, os.WriteFile(configPath, []byte(initialConfig), 0644))
// Create config with defaults
type TestConfig struct {
Server struct {
Port int `toml:"port"`
Host string `toml:"host"`
} `toml:"server"`
Features struct {
Enabled bool `toml:"enabled"`
} `toml:"features"`
}
defaults := &TestConfig{}
defaults.Server.Port = 3000
defaults.Server.Host = "0.0.0.0"
// Build config
cfg, err := NewBuilder().
WithDefaults(defaults).
WithFile(configPath).
Build()
require.NoError(t, err)
// Verify initial values
port, exists := cfg.Get("server.port")
assert.True(t, exists)
assert.Equal(t, int64(8080), port)
// Enable auto-update with fast polling
opts := WatchOptions{
PollInterval: testPollInterval,
Debounce: testDebounce,
MaxWatchers: 10,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
// Start watching
changes := cfg.Watch()
// Collect changes
var mu sync.Mutex
changedPaths := make(map[string]bool)
go func() {
for path := range changes {
mu.Lock()
changedPaths[path] = true
mu.Unlock()
}
}()
// Update config file
updatedConfig := `
[server]
port = 9090
host = "0.0.0.0"
[features]
enabled = false
`
require.NoError(t, os.WriteFile(configPath, []byte(updatedConfig), 0644))
// Wait for changes to be detected
time.Sleep(testPollWindow)
// Verify new values
port, _ = cfg.Get("server.port")
assert.Equal(t, int64(9090), port)
host, _ := cfg.Get("server.host")
assert.Equal(t, "0.0.0.0", host)
enabled, _ := cfg.Get("features.enabled")
assert.Equal(t, false, enabled)
// Check that changes were notified
mu.Lock()
defer mu.Unlock()
expectedChanges := []string{"server.port", "server.host", "features.enabled"}
for _, path := range expectedChanges {
assert.True(t, changedPaths[path], "Expected change notification for %s", path)
}
}
// TestWatchFileDeleted tests behavior when config file is deleted
func TestWatchFileDeleted(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test.toml")
// Create initial config
require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644))
cfg := New()
cfg.Register("test", "default")
require.NoError(t, cfg.LoadFile(configPath))
// Enable watching
opts := WatchOptions{
PollInterval: testPollInterval,
Debounce: testDebounce,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
changes := cfg.Watch()
// Delete file
require.NoError(t, os.Remove(configPath))
// Wait for deletion detection
select {
case path := <-changes:
assert.Equal(t, "file_deleted", path)
case <-time.After(testEventuallyTimeout):
t.Error("Timeout waiting for deletion notification")
}
}
// TestWatchPermissionChange tests permission change detection
func TestWatchPermissionChange(t *testing.T) {
// Skip on Windows where permission model is different
if runtime.GOOS == "windows" {
t.Skip("Skipping permission test on Windows")
}
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test.toml")
// Create config with specific permissions
require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644))
cfg := New()
cfg.Register("test", "default")
require.NoError(t, cfg.LoadFile(configPath))
// Enable watching with permission verification
opts := WatchOptions{
PollInterval: testPollInterval,
Debounce: testDebounce,
VerifyPermissions: true,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
changes := cfg.Watch()
// Change permissions to world-writable (security risk)
require.NoError(t, os.Chmod(configPath, 0666))
// Wait for permission change detection
select {
case path := <-changes:
assert.Equal(t, "permissions_changed", path)
case <-time.After(testEventuallyTimeout):
t.Error("Timeout waiting for permission change notification")
}
}
// TestMaxWatchers tests watcher limit enforcement
func TestMaxWatchers(t *testing.T) {
cfg := New()
cfg.Register("test", "value")
// Create config file
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test.toml")
require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644))
require.NoError(t, cfg.LoadFile(configPath))
// Enable watching with low max watchers
opts := WatchOptions{
PollInterval: testPollInterval,
MaxWatchers: 3,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
// Create maximum allowed watchers
channels := make([]<-chan string, 0, 4)
for i := 0; i < 4; i++ {
ch := cfg.Watch()
channels = append(channels, ch)
// Check if channel is open
if i < 3 {
// First 3 should be open
select {
case _, ok := <-ch:
assert.True(t, ok || i < 3, "Channel %d should be open", i)
default:
// Channel is open and empty, expected
}
} else {
// 4th should be closed immediately
select {
case _, ok := <-ch:
assert.False(t, ok, "Channel 3 should be closed (max watchers exceeded)")
case <-time.After(testEventuallyTimeout):
t.Error("Channel 3 should be closed immediately")
}
}
}
// Verify watcher count
assert.Equal(t, 3, cfg.WatcherCount())
}
// TestRapidDebounce tests that rapid changes are debounced
func TestRapidDebounce(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test.toml")
// Create initial config
require.NoError(t, os.WriteFile(configPath, []byte(`value = 1`), 0644))
cfg := New()
cfg.Register("value", 0)
require.NoError(t, cfg.LoadFile(configPath))
// Enable watching with longer debounce
opts := WatchOptions{
PollInterval: testDebounce,
Debounce: testStateStabilize,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
changes := cfg.Watch()
var changeCount int
var mu sync.Mutex
done := make(chan bool)
go func() {
for {
select {
case <-changes:
mu.Lock()
changeCount++
mu.Unlock()
case <-done:
return
}
}
}()
// Make rapid changes
for i := 2; i <= 5; i++ {
content := fmt.Sprintf(`value = %d`, i)
require.NoError(t, os.WriteFile(configPath, []byte(content), 0644))
time.Sleep(testDebounce) // Less than debounce period
}
// Wait for debounce to complete
time.Sleep(2 * testStateStabilize)
done <- true
// Should only see one change due to debounce
mu.Lock()
defer mu.Unlock()
assert.Equal(t, 1, changeCount, "Expected 1 change due to debounce, got %d", changeCount)
// Verify final value
val, _ := cfg.Get("value")
assert.Equal(t, int64(5), val)
}
// TestWatchWithoutFile tests watching behavior when no file is configured
func TestWatchWithoutFile(t *testing.T) {
cfg := New()
cfg.Register("test", "value")
// No file loaded, watch should return closed channel
ch := cfg.Watch()
select {
case _, ok := <-ch:
assert.False(t, ok, "Channel should be closed when no file to watch")
case <-time.After(10 * time.Millisecond):
t.Error("Channel should be closed immediately")
}
assert.False(t, cfg.IsWatching())
}
// TestConcurrentWatchOperations tests thread safety of watch operations
func TestConcurrentWatchOperations(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test.toml")
require.NoError(t, os.WriteFile(configPath, []byte(`value = 1`), 0644))
cfg := New()
cfg.Register("value", 0)
require.NoError(t, cfg.LoadFile(configPath))
opts := WatchOptions{
PollInterval: testDebounce,
MaxWatchers: 50,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
var wg sync.WaitGroup
errors := make(chan error, 100)
// Start multiple watchers concurrently
for i := 0; i < 20; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
ch := cfg.Watch()
if ch == nil {
errors <- fmt.Errorf("watcher %d: got nil channel", id)
return
}
// Try to receive
select {
case <-ch:
// OK, got a change
case <-time.After(2 * SpinWaitInterval):
// OK, no changes yet
}
}(i)
}
// Concurrent config updates
for i := 0; i < 5; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
content := fmt.Sprintf(`value = %d`, id+10)
if err := os.WriteFile(configPath, []byte(content), 0644); err != nil {
errors <- fmt.Errorf("writer %d: %v", id, err)
}
}(i)
}
// Check IsWatching concurrently
for i := 0; i < 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
isWatching := false
for j := 0; j < 5; j++ { // Poll a few times, double-dip wait for goroutine to start
if cfg.IsWatching() {
isWatching = true
break
}
time.Sleep(2 * SpinWaitInterval)
}
if !isWatching {
errors <- fmt.Errorf("checker %d: IsWatching returned false", id)
}
}(i)
}
wg.Wait()
close(errors)
// Check for errors
var errs []error
for err := range errors {
errs = append(errs, err)
}
assert.Empty(t, errs, "Concurrent operations should not produce errors")
}
// TestReloadTimeout tests reload timeout handling
func TestReloadTimeout(t *testing.T) {
// This test would require mocking file operations to simulate a slow read
// For now, we'll test that timeout option is respected in configuration
cfg := New()
cfg.Register("test", "value")
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test.toml")
require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644))
require.NoError(t, cfg.LoadFile(configPath))
// Very short timeout
opts := WatchOptions{
PollInterval: testPollInterval,
ReloadTimeout: 1 * time.Nanosecond,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
waitForWatchingState(t, cfg, true)
}
// TestStopAutoUpdate tests clean shutdown of watcher
func TestStopAutoUpdate(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test.toml")
require.NoError(t, os.WriteFile(configPath, []byte(`test = "value"`), 0644))
cfg := New()
cfg.Register("test", "value")
require.NoError(t, cfg.LoadFile(configPath))
// Start watching
cfg.AutoUpdate()
waitForWatchingState(t, cfg, true, "Watcher should be active after first start")
ch := cfg.Watch()
// Stop watching
cfg.StopAutoUpdate()
// Verify stopped
waitForWatchingState(t, cfg, false, "Watcher should be inactive after stop")
assert.Equal(t, 0, cfg.WatcherCount())
// Channel should eventually close
select {
case _, ok := <-ch:
assert.False(t, ok, "Channel should be closed after stop")
case <-time.After(ShutdownTimeout):
// OK, channel might not close immediately
}
// Starting again should work
cfg.AutoUpdate()
waitForWatchingState(t, cfg, true, "Watcher should be active after restart")
cfg.StopAutoUpdate()
}
// BenchmarkWatchOverhead benchmarks the overhead of file watching
func BenchmarkWatchOverhead(b *testing.B) {
tmpDir := b.TempDir()
configPath := filepath.Join(tmpDir, "bench.toml")
// Create config with many values
var configContent string
for i := 0; i < 100; i++ {
configContent += fmt.Sprintf("value%d = %d\n", i, i)
}
require.NoError(b, os.WriteFile(configPath, []byte(configContent), 0644))
cfg := New()
for i := 0; i < 100; i++ {
cfg.Register(fmt.Sprintf("value%d", i), 0)
}
require.NoError(b, cfg.LoadFile(configPath))
// Enable watching
opts := WatchOptions{
PollInterval: testPollInterval,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
// Benchmark value retrieval with watching enabled
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = cfg.Get(fmt.Sprintf("value%d", i%100))
}
}