e3.0.0 Added env variable support, improved cli arg, added tests, updated documentation.
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,5 +1,6 @@
|
||||
.idea
|
||||
data
|
||||
dev
|
||||
logs
|
||||
log
|
||||
*.log
|
||||
bin
|
||||
|
||||
305
README.md
305
README.md
@ -1,18 +1,6 @@
|
||||
# Config
|
||||
|
||||
A simple, thread-safe configuration management package for Go applications that supports TOML files, command-line argument overrides, and registered default values.
|
||||
|
||||
## Features
|
||||
|
||||
- **Thread-Safe Operations:** Uses `sync.RWMutex` to protect concurrent access during all configuration operations.
|
||||
- **TOML Configuration:** Uses [BurntSushi/toml](https://github.com/BurntSushi/toml) for loading and saving configuration files.
|
||||
- **Command-Line Overrides:** Allows overriding configuration values using dot notation in CLI arguments (e.g., `--server.port 9090`).
|
||||
- **Path-Based Access:** Register configuration paths with default values for direct, consistent access with clear error messages.
|
||||
- **Struct Registration:** Register an entire struct as configuration defaults, using struct tags to determine paths.
|
||||
- **Atomic File Operations:** Ensures configuration files are written atomically to prevent corruption.
|
||||
- **Path Validation:** Validates configuration path segments against TOML key requirements.
|
||||
- **Type Conversions:** Helper methods for converting configuration values to common Go types with detailed error messages.
|
||||
- **Hierarchical Data Management:** Automatically handles nested structures through dot notation.
|
||||
Thread-safe configuration management for Go with support for TOML files, environment variables, command-line arguments, and defaults with configurable precedence.
|
||||
|
||||
## Installation
|
||||
|
||||
@ -20,196 +8,173 @@ A simple, thread-safe configuration management package for Go applications that
|
||||
go get github.com/LixenWraith/config
|
||||
```
|
||||
|
||||
Dependencies will be automatically fetched:
|
||||
```
|
||||
github.com/BurntSushi/toml
|
||||
github.com/mitchellh/mapstructure
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Usage Pattern
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
// 1. Initialize a new Config instance
|
||||
cfg := config.New()
|
||||
package main
|
||||
|
||||
// 2. Register configuration paths with default values
|
||||
cfg.Register("server.host", "127.0.0.1")
|
||||
cfg.Register("server.port", 8080)
|
||||
import (
|
||||
"log"
|
||||
|
||||
// 3. Load configuration from file with CLI argument overrides
|
||||
err := cfg.Load("app_config.toml", os.Args[1:])
|
||||
if err != nil {
|
||||
if errors.Is(err, config.ErrConfigNotFound) {
|
||||
log.Println("Config file not found, using defaults")
|
||||
} else {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
"github.com/lixenwraith/config"
|
||||
)
|
||||
|
||||
// 4. Access configuration values using the registered paths
|
||||
serverHost, err := cfg.String("server.host")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
serverPort, err := cfg.Int64("server.port")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// 5. Save configuration (creates the file if it doesn't exist)
|
||||
err = cfg.Save("app_config.toml")
|
||||
```
|
||||
|
||||
### Struct-Based Registration
|
||||
|
||||
```go
|
||||
// Define a configuration struct with TOML tags
|
||||
type ServerConfig struct {
|
||||
Host string `toml:"host"`
|
||||
Port int64 `toml:"port"`
|
||||
Timeout int64 `toml:"timeout"`
|
||||
Debug bool `toml:"debug"`
|
||||
}
|
||||
|
||||
// Create default configuration
|
||||
defaults := ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Timeout: 30,
|
||||
Debug: false,
|
||||
}
|
||||
|
||||
// Register the entire struct at once
|
||||
err := cfg.RegisterStruct("server.", defaults)
|
||||
```
|
||||
|
||||
### Accessing Typed Values
|
||||
|
||||
```go
|
||||
// Use type-specific accessor methods
|
||||
port, err := cfg.Int64("server.port")
|
||||
debug, err := cfg.Bool("debug")
|
||||
rate, err := cfg.Float64("rate.limit")
|
||||
name, err := cfg.String("server.name")
|
||||
```
|
||||
|
||||
### Using Scan to Populate Structs
|
||||
|
||||
```go
|
||||
// Define a struct matching your configuration
|
||||
type AppConfig struct {
|
||||
ServerName string `toml:"name"`
|
||||
ServerPort int64 `toml:"port"`
|
||||
Server struct {
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
} `toml:"server"`
|
||||
Database struct {
|
||||
URL string `toml:"url"`
|
||||
MaxConns int `toml:"max_conns"`
|
||||
} `toml:"database"`
|
||||
Debug bool `toml:"debug"`
|
||||
}
|
||||
|
||||
// Create an instance to receive the configuration
|
||||
var appConfig AppConfig
|
||||
func main() {
|
||||
// Define defaults
|
||||
defaults := AppConfig{}
|
||||
defaults.Server.Host = "localhost"
|
||||
defaults.Server.Port = 8080
|
||||
defaults.Database.URL = "postgres://localhost/myapp"
|
||||
defaults.Database.MaxConns = 10
|
||||
|
||||
// Scan the configuration into the struct
|
||||
err := cfg.Scan("server", &appConfig)
|
||||
// Initialize with environment prefix and config file
|
||||
cfg, err := config.Quick(defaults, "MYAPP_", "config.toml")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Access values
|
||||
host, _ := cfg.String("server.host")
|
||||
port, _ := cfg.Int64("server.port")
|
||||
dbURL, _ := cfg.String("database.url")
|
||||
debug, _ := cfg.Bool("debug")
|
||||
|
||||
log.Printf("Server: %s:%d, DB: %s, Debug: %v", host, port, dbURL, debug)
|
||||
}
|
||||
```
|
||||
|
||||
## API
|
||||
**config.toml:**
|
||||
```toml
|
||||
[server]
|
||||
host = "production.example.com"
|
||||
port = 9090
|
||||
|
||||
### `New() *Config`
|
||||
[database]
|
||||
url = "postgres://prod-db/myapp"
|
||||
max_conns = 50
|
||||
|
||||
Creates and returns a new, initialized `*Config` instance ready for use.
|
||||
debug = false
|
||||
```
|
||||
|
||||
### `(*Config) Register(path string, defaultValue any) error`
|
||||
**Usage:**
|
||||
```bash
|
||||
# Override with environment variables
|
||||
export MYAPP_SERVER_PORT=8443
|
||||
export MYAPP_DEBUG=true
|
||||
|
||||
Registers a configuration path with a default value.
|
||||
# Override with CLI arguments
|
||||
./myapp --server.port=9999 --debug
|
||||
```
|
||||
|
||||
- **path**: Dot-separated path corresponding to the TOML structure. Each segment must be a valid TOML key.
|
||||
- **defaultValue**: The value returned if no other value has been set through Load or Set.
|
||||
- **Returns**: Error (nil on success)
|
||||
## Key Features
|
||||
|
||||
### `(*Config) RegisterStruct(prefix string, structWithDefaults interface{}) error`
|
||||
- **Multiple Sources**: Defaults → File → Environment → CLI (configurable order)
|
||||
- **Type Safety**: Automatic conversion with detailed error messages
|
||||
- **Thread-Safe**: Concurrent reads with protected writes
|
||||
- **Builder Pattern**: Fluent interface for advanced configuration
|
||||
- **Source Tracking**: See which source provided each value
|
||||
- **Zero Dependencies**: Only stdlib + minimal parsers
|
||||
|
||||
Registers all fields of a struct as configuration paths, using struct tags to determine the paths.
|
||||
## Common Patterns
|
||||
|
||||
- **prefix**: Prefix to prepend to all generated paths (e.g., "server.").
|
||||
- **structWithDefaults**: Struct containing default values. Fields must have `toml` tags.
|
||||
- **Returns**: Error if registration fails for any field.
|
||||
### Custom Precedence
|
||||
```go
|
||||
cfg, _ := config.NewBuilder().
|
||||
WithDefaults(defaults).
|
||||
WithSources(
|
||||
config.SourceEnv, // Env vars highest priority
|
||||
config.SourceFile,
|
||||
config.SourceCLI,
|
||||
config.SourceDefault,
|
||||
).
|
||||
Build()
|
||||
```
|
||||
|
||||
### `(*Config) GetRegisteredPaths(prefix string) map[string]bool`
|
||||
### Environment Variable Mapping
|
||||
```go
|
||||
// Custom env var names
|
||||
opts := config.LoadOptions{
|
||||
EnvTransform: func(path string) string {
|
||||
switch path {
|
||||
case "server.port": return "PORT"
|
||||
case "database.url": return "DATABASE_URL"
|
||||
default: return ""
|
||||
}
|
||||
},
|
||||
}
|
||||
cfg.LoadWithOptions("config.toml", os.Args[1:], opts)
|
||||
```
|
||||
|
||||
Returns all registered configuration paths that start with the given prefix.
|
||||
### Validation
|
||||
```go
|
||||
// Register and validate required fields
|
||||
cfg.RegisterRequired("api.key", "")
|
||||
cfg.RegisterRequired("database.url", "")
|
||||
|
||||
- **prefix**: Path prefix to filter by (e.g., "server.").
|
||||
- **Returns**: Map where keys are the registered paths that match the prefix.
|
||||
if err := cfg.Validate("api.key", "database.url"); err != nil {
|
||||
log.Fatal("Missing required config: ", err)
|
||||
}
|
||||
```
|
||||
|
||||
### `(*Config) Get(path string) (any, bool)`
|
||||
### Source Inspection
|
||||
```go
|
||||
// See all sources for a value
|
||||
sources := cfg.GetSources("server.port")
|
||||
for source, value := range sources {
|
||||
fmt.Printf("%s: %v\n", source, value)
|
||||
}
|
||||
|
||||
Retrieves a configuration value using the registered path.
|
||||
// Get value from specific source
|
||||
envPort, exists := cfg.GetSource("server.port", config.SourceEnv)
|
||||
```
|
||||
|
||||
- **path**: The dot-separated path string used during registration.
|
||||
- **Returns**: The configuration value and a boolean indicating if the path was registered.
|
||||
- **Value precedence**: CLI Argument > Config File Value > Registered Default Value
|
||||
### Struct Scanning
|
||||
```go
|
||||
var serverConfig struct {
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
}
|
||||
cfg.Scan("server", &serverConfig)
|
||||
```
|
||||
|
||||
### `(*Config) String(path string) (string, error)`
|
||||
### `(*Config) Int64(path string) (int64, error)`
|
||||
### `(*Config) Bool(path string) (bool, error)`
|
||||
### `(*Config) Float64(path string) (float64, error)`
|
||||
### Environment Whitelist
|
||||
```go
|
||||
// Only load specific env vars
|
||||
cfg, _ := config.NewBuilder().
|
||||
WithDefaults(defaults).
|
||||
WithEnvPrefix("MYAPP_").
|
||||
WithEnvWhitelist("api.key", "database.password").
|
||||
Build()
|
||||
```
|
||||
|
||||
Type-specific accessor methods that retrieve and attempt to convert configuration values to the desired type.
|
||||
## API Reference
|
||||
|
||||
- **path**: The dot-separated path string used during registration.
|
||||
- **Returns**: The typed value and an error (nil on success).
|
||||
- **Errors**: Detailed error messages when:
|
||||
- The path is not registered
|
||||
- The value cannot be converted to the requested type
|
||||
- Type conversion fails (with the specific reason)
|
||||
### Core Methods
|
||||
- `Quick(defaults, envPrefix, configFile)` - Quick initialization
|
||||
- `Register(path, defaultValue)` - Register configuration path
|
||||
- `Get/String/Int64/Bool/Float64(path)` - Type-safe accessors
|
||||
- `Set(path, value)` - Update configuration
|
||||
- `Validate(paths...)` - Ensure required values are set
|
||||
|
||||
### `(*Config) Set(path string, value any) error`
|
||||
|
||||
Updates a configuration value using the registered path.
|
||||
|
||||
- **path**: The dot-separated path string used during registration.
|
||||
- **value**: The new value to set.
|
||||
- **Returns**: Error if the path wasn't registered or if setting the value fails.
|
||||
|
||||
### `(*Config) Unregister(path string) error`
|
||||
|
||||
Removes a configuration path and all its children from the configuration.
|
||||
|
||||
- **path**: The dot-separated path string used during registration.
|
||||
- **Effects**:
|
||||
- Removes the specified path
|
||||
- Recursively removes all child paths (e.g., unregistering "server" also removes "server.host", "server.port", etc.)
|
||||
- Completely removes both registration and data
|
||||
- **Returns**: Error if the path wasn't registered.
|
||||
|
||||
### `(*Config) Scan(basePath string, target any) error`
|
||||
|
||||
Decodes a section of the configuration into a struct or map.
|
||||
|
||||
- **basePath**: Dot-separated path to the configuration subtree.
|
||||
- **target**: Pointer to a struct or map where the configuration should be unmarshaled.
|
||||
- **Returns**: Error if unmarshaling fails.
|
||||
|
||||
### `(*Config) Load(filePath string, args []string) error`
|
||||
|
||||
Loads configuration from a TOML file and merges overrides from command-line arguments.
|
||||
|
||||
- **filePath**: Path to the TOML configuration file.
|
||||
- **args**: Command-line arguments (e.g., `os.Args[1:]`).
|
||||
- **Returns**: Error on failure, which can be checked with:
|
||||
- `errors.Is(err, config.ErrConfigNotFound)` to detect missing file
|
||||
- `errors.Is(err, config.ErrCLIParse)` to detect CLI parsing errors
|
||||
|
||||
### `(*Config) Save(filePath string) error`
|
||||
|
||||
Saves the current configuration to the specified TOML file path, performing an atomic write.
|
||||
|
||||
- **filePath**: Path where the TOML configuration file will be written.
|
||||
- **Returns**: Error if marshaling or file operations fail, nil on success.
|
||||
### Advanced Methods
|
||||
- `NewBuilder()` - Create custom configuration
|
||||
- `GetSource(path, source)` - Get value from specific source
|
||||
- `GetSources(path)` - Get all source values
|
||||
- `Scan(basePath, target)` - Unmarshal into struct
|
||||
- `Clone()` - Deep copy configuration
|
||||
- `Debug()` - Show all values and sources
|
||||
|
||||
## License
|
||||
|
||||
|
||||
110
builder.go
Normal file
110
builder.go
Normal file
@ -0,0 +1,110 @@
|
||||
// File: lixenwraith/config/builder.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Builder provides a fluent interface for building configurations
|
||||
type Builder struct {
|
||||
cfg *Config
|
||||
opts LoadOptions
|
||||
defaults interface{}
|
||||
prefix string
|
||||
file string
|
||||
args []string
|
||||
err error
|
||||
}
|
||||
|
||||
// NewBuilder creates a new configuration builder
|
||||
func NewBuilder() *Builder {
|
||||
return &Builder{
|
||||
cfg: New(),
|
||||
opts: DefaultLoadOptions(),
|
||||
args: os.Args[1:],
|
||||
}
|
||||
}
|
||||
|
||||
// WithDefaults sets the struct containing default values
|
||||
func (b *Builder) WithDefaults(defaults interface{}) *Builder {
|
||||
b.defaults = defaults
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPrefix sets the prefix for struct registration
|
||||
func (b *Builder) WithPrefix(prefix string) *Builder {
|
||||
b.prefix = prefix
|
||||
return b
|
||||
}
|
||||
|
||||
// WithEnvPrefix sets the environment variable prefix
|
||||
func (b *Builder) WithEnvPrefix(prefix string) *Builder {
|
||||
b.opts.EnvPrefix = prefix
|
||||
return b
|
||||
}
|
||||
|
||||
// WithFile sets the configuration file path
|
||||
func (b *Builder) WithFile(path string) *Builder {
|
||||
b.file = path
|
||||
return b
|
||||
}
|
||||
|
||||
// WithArgs sets the command-line arguments
|
||||
func (b *Builder) WithArgs(args []string) *Builder {
|
||||
b.args = args
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSources sets the precedence order for configuration sources
|
||||
func (b *Builder) WithSources(sources ...Source) *Builder {
|
||||
b.opts.Sources = sources
|
||||
return b
|
||||
}
|
||||
|
||||
// WithEnvTransform sets a custom environment variable transformer
|
||||
func (b *Builder) WithEnvTransform(fn EnvTransformFunc) *Builder {
|
||||
b.opts.EnvTransform = fn
|
||||
return b
|
||||
}
|
||||
|
||||
// WithEnvWhitelist limits which paths are checked for env vars
|
||||
func (b *Builder) WithEnvWhitelist(paths ...string) *Builder {
|
||||
if b.opts.EnvWhitelist == nil {
|
||||
b.opts.EnvWhitelist = make(map[string]bool)
|
||||
}
|
||||
for _, path := range paths {
|
||||
b.opts.EnvWhitelist[path] = true
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Build creates the Config instance with all specified options
|
||||
func (b *Builder) Build() (*Config, error) {
|
||||
if b.err != nil {
|
||||
return nil, b.err
|
||||
}
|
||||
|
||||
// Register defaults if provided
|
||||
if b.defaults != nil {
|
||||
if err := b.cfg.RegisterStruct(b.prefix, b.defaults); err != nil {
|
||||
return nil, fmt.Errorf("failed to register defaults: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load configuration
|
||||
if err := b.cfg.LoadWithOptions(b.file, b.args, b.opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.cfg, nil
|
||||
}
|
||||
|
||||
// MustBuild is like Build but panics on error
|
||||
func (b *Builder) MustBuild() *Config {
|
||||
cfg, err := b.Build()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("config build failed: %v", err))
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
271
builder_test.go
Normal file
271
builder_test.go
Normal file
@ -0,0 +1,271 @@
|
||||
// File: lixenwraith/config/builder_test.go
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/lixenwraith/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuilder(t *testing.T) {
|
||||
t.Run("Basic Builder", func(t *testing.T) {
|
||||
type AppConfig struct {
|
||||
Name string `toml:"name"`
|
||||
Version string `toml:"version"`
|
||||
Debug bool `toml:"debug"`
|
||||
}
|
||||
|
||||
defaults := AppConfig{
|
||||
Name: "testapp",
|
||||
Version: "1.0.0",
|
||||
Debug: false,
|
||||
}
|
||||
|
||||
cfg, err := config.NewBuilder().
|
||||
WithDefaults(defaults).
|
||||
WithPrefix("app.").
|
||||
Build()
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check registered paths
|
||||
paths := cfg.GetRegisteredPaths("app.")
|
||||
assert.Len(t, paths, 3)
|
||||
|
||||
// Check values
|
||||
name, err := cfg.String("app.name")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testapp", name)
|
||||
})
|
||||
|
||||
t.Run("Builder with All Options", func(t *testing.T) {
|
||||
os.Setenv("BUILDER_SERVER_PORT", "5555")
|
||||
defer os.Unsetenv("BUILDER_SERVER_PORT")
|
||||
|
||||
type Config struct {
|
||||
Server struct {
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
} `toml:"server"`
|
||||
API struct {
|
||||
Key string `toml:"key"`
|
||||
Timeout int `toml:"timeout"`
|
||||
} `toml:"api"`
|
||||
}
|
||||
|
||||
defaults := Config{}
|
||||
defaults.Server.Host = "localhost"
|
||||
defaults.Server.Port = 8080
|
||||
defaults.API.Timeout = 30
|
||||
|
||||
cfg, err := config.NewBuilder().
|
||||
WithDefaults(defaults).
|
||||
WithEnvPrefix("BUILDER_").
|
||||
WithArgs([]string{"--api.key=test-key"}).
|
||||
WithSources(
|
||||
config.SourceCLI,
|
||||
config.SourceEnv,
|
||||
config.SourceDefault,
|
||||
).
|
||||
WithEnvWhitelist("server.port", "api.key").
|
||||
Build()
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// CLI should provide api.key
|
||||
apiKey, err := cfg.String("api.key")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test-key", apiKey)
|
||||
|
||||
// Env should provide server.port (whitelisted)
|
||||
port, err := cfg.Int64("server.port")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(5555), port)
|
||||
|
||||
// Non-whitelisted env should not load
|
||||
os.Setenv("BUILDER_API_TIMEOUT", "99")
|
||||
defer os.Unsetenv("BUILDER_API_TIMEOUT")
|
||||
|
||||
cfg2, err := config.NewBuilder().
|
||||
WithDefaults(defaults).
|
||||
WithEnvPrefix("BUILDER_").
|
||||
WithEnvWhitelist("server.port"). // api.timeout NOT whitelisted
|
||||
Build()
|
||||
require.NoError(t, err)
|
||||
|
||||
timeout, err := cfg2.Int64("api.timeout")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(30), timeout, "non-whitelisted env should not load")
|
||||
})
|
||||
|
||||
t.Run("Builder Custom Transform", func(t *testing.T) {
|
||||
os.Setenv("PORT", "3333")
|
||||
os.Setenv("DB_URL", "postgres://custom")
|
||||
defer func() {
|
||||
os.Unsetenv("PORT")
|
||||
os.Unsetenv("DB_URL")
|
||||
}()
|
||||
|
||||
type Config struct {
|
||||
Server struct {
|
||||
Port int `toml:"port"`
|
||||
} `toml:"server"`
|
||||
Database struct {
|
||||
URL string `toml:"url"`
|
||||
} `toml:"database"`
|
||||
}
|
||||
|
||||
cfg, err := config.NewBuilder().
|
||||
WithDefaults(Config{}).
|
||||
WithEnvTransform(func(path string) string {
|
||||
switch path {
|
||||
case "server.port":
|
||||
return "PORT"
|
||||
case "database.url":
|
||||
return "DB_URL"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}).
|
||||
Build()
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
port, err := cfg.Int64("server.port")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3333), port)
|
||||
|
||||
dbURL, err := cfg.String("database.url")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "postgres://custom", dbURL)
|
||||
})
|
||||
|
||||
t.Run("MustBuild Panic", func(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
config.NewBuilder().
|
||||
WithDefaults("not a struct").
|
||||
MustBuild()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestQuickFunctions(t *testing.T) {
|
||||
t.Run("Quick Success", func(t *testing.T) {
|
||||
type Config struct {
|
||||
App struct {
|
||||
Name string `toml:"name"`
|
||||
} `toml:"app"`
|
||||
}
|
||||
|
||||
defaults := Config{}
|
||||
defaults.App.Name = "quicktest"
|
||||
|
||||
cfg, err := config.Quick(defaults, "QUICK_", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
name, err := cfg.String("app.name")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "quicktest", name)
|
||||
})
|
||||
|
||||
t.Run("QuickCustom", func(t *testing.T) {
|
||||
opts := config.LoadOptions{
|
||||
Sources: []config.Source{
|
||||
config.SourceDefault,
|
||||
config.SourceEnv,
|
||||
},
|
||||
EnvPrefix: "CUSTOM_",
|
||||
}
|
||||
|
||||
cfg, err := config.QuickCustom(nil, opts, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, cfg)
|
||||
})
|
||||
|
||||
t.Run("MustQuick Panic", func(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
config.MustQuick("invalid", "TEST_", "")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestConvenienceFunctions(t *testing.T) {
|
||||
t.Run("Validate", func(t *testing.T) {
|
||||
cfg := config.New()
|
||||
cfg.Register("required1", "")
|
||||
cfg.Register("required2", 0)
|
||||
cfg.Register("optional", "has-default")
|
||||
|
||||
// Initial validation should fail
|
||||
err := cfg.Validate("required1", "required2")
|
||||
assert.Error(t, err, "expected validation to fail for empty values")
|
||||
|
||||
// Set required values
|
||||
cfg.Set("required1", "value1")
|
||||
cfg.Set("required2", 42)
|
||||
|
||||
// Now should pass
|
||||
err = cfg.Validate("required1", "required2")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Validate unregistered path
|
||||
err = cfg.Validate("unregistered")
|
||||
assert.Error(t, err, "expected error for unregistered path")
|
||||
})
|
||||
|
||||
t.Run("Debug Output", func(t *testing.T) {
|
||||
cfg := config.New()
|
||||
cfg.Register("test.value", "default")
|
||||
cfg.SetSource("test.value", config.SourceFile, "from-file")
|
||||
cfg.SetSource("test.value", config.SourceEnv, "from-env")
|
||||
|
||||
debug := cfg.Debug()
|
||||
|
||||
// Should contain key information
|
||||
assert.NotEmpty(t, debug)
|
||||
|
||||
// Should show sources - checking for actual source string values
|
||||
assert.Contains(t, debug, "file")
|
||||
assert.Contains(t, debug, "from-file")
|
||||
assert.Contains(t, debug, "env")
|
||||
assert.Contains(t, debug, "from-env")
|
||||
})
|
||||
|
||||
t.Run("Clone", func(t *testing.T) {
|
||||
cfg := config.New()
|
||||
cfg.Register("original", "value")
|
||||
cfg.Set("original", "modified")
|
||||
|
||||
// Clone configuration
|
||||
clone := cfg.Clone()
|
||||
|
||||
// Clone should have same values
|
||||
val, err := clone.String("original")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "modified", val)
|
||||
|
||||
// Modifying clone should not affect original
|
||||
clone.Set("original", "clone-modified")
|
||||
|
||||
origVal, err := cfg.String("original")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "modified", origVal, "original should not be affected by clone modification")
|
||||
})
|
||||
|
||||
t.Run("GetRegisteredPathsWithDefaults", func(t *testing.T) {
|
||||
cfg := config.New()
|
||||
cfg.Register("app.name", "myapp")
|
||||
cfg.Register("app.version", "1.0.0")
|
||||
cfg.Register("server.port", 8080)
|
||||
|
||||
// Get paths with defaults
|
||||
paths := cfg.GetRegisteredPathsWithDefaults("app.")
|
||||
|
||||
assert.Len(t, paths, 2)
|
||||
assert.Equal(t, "myapp", paths["app.name"])
|
||||
assert.Equal(t, "1.0.0", paths["app.version"])
|
||||
})
|
||||
}
|
||||
310
cmd/main.go
310
cmd/main.go
@ -1,310 +0,0 @@
|
||||
// Test program for the config package
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors" // Import errors package
|
||||
"fmt"
|
||||
"log" // Using standard log for simplicity
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/LixenWraith/config" // Assuming this is the correct import path after potential renaming/moving
|
||||
)
|
||||
|
||||
// LogConfig represents logging configuration parameters
|
||||
type LogConfig struct {
|
||||
// Basic settings
|
||||
Level int64 `toml:"level"`
|
||||
Name string `toml:"name"`
|
||||
Directory string `toml:"directory"`
|
||||
Format string `toml:"format"` // "txt" or "json"
|
||||
Extension string `toml:"extension"`
|
||||
// Formatting
|
||||
ShowTimestamp bool `toml:"show_timestamp"`
|
||||
ShowLevel bool `toml:"show_level"`
|
||||
// Buffer and size limits
|
||||
BufferSize int64 `toml:"buffer_size"` // Channel buffer size
|
||||
MaxSizeMB int64 `toml:"max_size_mb"` // Max size per log file
|
||||
MaxTotalSizeMB int64 `toml:"max_total_size_mb"` // Max total size of all logs in dir
|
||||
MinDiskFreeMB int64 `toml:"min_disk_free_mb"` // Minimum free disk space required
|
||||
// Timers
|
||||
FlushIntervalMs int64 `toml:"flush_interval_ms"` // Interval for flushing file buffer
|
||||
TraceDepth int64 `toml:"trace_depth"` // Default trace depth (0-10)
|
||||
RetentionPeriodHrs float64 `toml:"retention_period_hrs"` // Hours to keep logs (0=disabled)
|
||||
RetentionCheckMins float64 `toml:"retention_check_mins"` // How often to check retention
|
||||
// Disk check settings
|
||||
DiskCheckIntervalMs int64 `toml:"disk_check_interval_ms"` // Base interval for disk checks
|
||||
EnableAdaptiveInterval bool `toml:"enable_adaptive_interval"` // Adjust interval based on log rate
|
||||
MinCheckIntervalMs int64 `toml:"min_check_interval_ms"` // Minimum adaptive interval
|
||||
MaxCheckIntervalMs int64 `toml:"max_check_interval_ms"` // Maximum adaptive interval
|
||||
}
|
||||
|
||||
// Define default configuration values
|
||||
var defaultLogConfig = LogConfig{
|
||||
// Basic settings
|
||||
Level: 1,
|
||||
Name: "default_logger",
|
||||
Directory: "./logs",
|
||||
Format: "txt",
|
||||
Extension: ".log",
|
||||
// Formatting
|
||||
ShowTimestamp: true,
|
||||
ShowLevel: true,
|
||||
// Buffer and size limits
|
||||
BufferSize: 1000,
|
||||
MaxSizeMB: 10,
|
||||
MaxTotalSizeMB: 100,
|
||||
MinDiskFreeMB: 500,
|
||||
// Timers
|
||||
FlushIntervalMs: 1000,
|
||||
TraceDepth: 3,
|
||||
RetentionPeriodHrs: 24.0,
|
||||
RetentionCheckMins: 15.0,
|
||||
// Disk check settings
|
||||
DiskCheckIntervalMs: 60000,
|
||||
EnableAdaptiveInterval: false,
|
||||
MinCheckIntervalMs: 5000,
|
||||
MaxCheckIntervalMs: 300000,
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Create a temporary file path for our test
|
||||
tempDir := os.TempDir()
|
||||
configPath := filepath.Join(tempDir, "logconfig_test_enhanced.toml")
|
||||
|
||||
// Clean up any existing file from previous runs
|
||||
os.Remove(configPath)
|
||||
defer os.Remove(configPath) // Ensure cleanup even on error exit
|
||||
|
||||
fmt.Println("=== Enhanced LogConfig Test Program ===")
|
||||
fmt.Printf("Using temporary config file: %s\n\n", configPath)
|
||||
|
||||
// 1. Initialize the Config instance
|
||||
cfg := config.New()
|
||||
|
||||
// 2. Register default values using RegisterStruct
|
||||
fmt.Println("Registering default values using RegisterStruct...")
|
||||
err := cfg.RegisterStruct("log.", defaultLogConfig) // Note the "log." prefix
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: Error registering defaults: %v\n", err)
|
||||
}
|
||||
fmt.Println("Defaults registered.")
|
||||
|
||||
// 3. Load configuration (file doesn't exist yet)
|
||||
fmt.Println("\nAttempting initial load (expecting file not found)...")
|
||||
err = cfg.Load(configPath, nil) // No CLI args yet
|
||||
if err != nil {
|
||||
// Check specifically for ErrConfigNotFound
|
||||
if errors.Is(err, config.ErrConfigNotFound) {
|
||||
fmt.Println("SUCCESS: Correctly detected config file not found.")
|
||||
} else {
|
||||
// Any other error during initial load is unexpected here
|
||||
log.Fatalf("FATAL: Unexpected error loading initial config: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
log.Fatalf("FATAL: Expected an error (ErrConfigNotFound) during initial load, but got nil")
|
||||
}
|
||||
|
||||
// 4. Unmarshal defaults into LogConfig struct
|
||||
var currentConfig LogConfig
|
||||
fmt.Println("\nUnmarshaling current config (should be defaults)...")
|
||||
err = cfg.Scan("log", ¤tConfig)
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: Error unmarshaling default config: %v\n", err)
|
||||
}
|
||||
|
||||
// Print default values
|
||||
fmt.Println("\n=== Current Configuration (Defaults) ===")
|
||||
printLogConfig(currentConfig)
|
||||
|
||||
// 5. Modify some values using Set
|
||||
fmt.Println("\n=== Modifying Configuration Values via Set ===")
|
||||
fmt.Println("Changing:")
|
||||
fmt.Println(" - log.name: default_logger → saved_logger")
|
||||
fmt.Println(" - log.max_size_mb: 10 → 50")
|
||||
fmt.Println(" - log.retention_period_hrs: 24.0 → 48.0") // Different from CLI override later
|
||||
|
||||
cfg.Set("log.name", "saved_logger") // This will be saved to file
|
||||
cfg.Set("log.max_size_mb", int64(50))
|
||||
cfg.Set("log.retention_period_hrs", 48.0)
|
||||
|
||||
// 6. Save the configuration
|
||||
fmt.Println("\nSaving configuration to file...")
|
||||
err = cfg.Save(configPath)
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: Error saving config: %v\n", err)
|
||||
}
|
||||
fmt.Printf("Saved configuration to: %s\n", configPath)
|
||||
|
||||
// Optional: Read and print file contents
|
||||
// fileBytes, _ := os.ReadFile(configPath)
|
||||
// fmt.Println("\n=== Saved TOML File Contents ===")
|
||||
// fmt.Println(string(fileBytes))
|
||||
|
||||
// 7. Define some command-line arguments for override testing
|
||||
fmt.Println("\n=== Preparing Command-Line Overrides ===")
|
||||
// Simulate os.Args[1:]
|
||||
cliArgs := []string{
|
||||
"--log.level", "3", // Override default 1
|
||||
"--log.name", "cli_logger", // Override value set before save ("saved_logger")
|
||||
"--log.show_timestamp=false", // Override default true
|
||||
"--log.retention_period_hrs", "72.5", // Override value set before save (48.0)
|
||||
"--other.value", "test", // An unregistered key (should be ignored by Load logic)
|
||||
"--invalid-key", // Invalid key format (test error handling if desired)
|
||||
}
|
||||
fmt.Printf("Simulated CLI Args: %v\n", cliArgs)
|
||||
|
||||
// 8. Load again, now with file and CLI overrides
|
||||
// Create a *new* config instance to simulate a fresh application start
|
||||
// that loads existing file + CLI args over defaults.
|
||||
fmt.Println("\nCreating NEW config instance and loading with file and CLI args...")
|
||||
cfg2 := config.New()
|
||||
fmt.Println("Registering defaults for new instance...")
|
||||
err = cfg2.RegisterStruct("log.", defaultLogConfig)
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: Error registering defaults for cfg2: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Println("Loading config with file and CLI...")
|
||||
err = cfg2.Load(configPath, cliArgs)
|
||||
if err != nil {
|
||||
// Note: If "--invalid-key" is included above, Load should return ErrCLIParse.
|
||||
// Handle or remove the invalid key for a successful load test.
|
||||
// Example check:
|
||||
if errors.Is(err, config.ErrCLIParse) {
|
||||
fmt.Printf("INFO: Expected CLI parsing error detected: %v\n", err)
|
||||
// Decide how to proceed - maybe exit or remove the offending arg and retry
|
||||
// For this example, we'll filter the bad arg and try again
|
||||
var validArgs []string
|
||||
for _, arg := range cliArgs {
|
||||
if !strings.HasPrefix(arg, "--invalid") {
|
||||
validArgs = append(validArgs, arg)
|
||||
}
|
||||
}
|
||||
fmt.Println("Retrying load with filtered CLI args...")
|
||||
err = cfg2.Load(configPath, validArgs)
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: Error loading config even after filtering CLI args: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
log.Fatalf("FATAL: Unexpected error loading config with file and CLI: %v\n", err)
|
||||
}
|
||||
}
|
||||
fmt.Println("Load successful.")
|
||||
|
||||
// 9. Unmarshal the final configuration state
|
||||
var finalConfig LogConfig
|
||||
fmt.Println("\nUnmarshaling final config state...")
|
||||
err = cfg2.Scan("log", &finalConfig)
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: Error unmarshaling final config: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n=== Final Configuration (Defaults + File + CLI) ===")
|
||||
printLogConfig(finalConfig)
|
||||
|
||||
// 10. Verify final values (Defaults < File < CLI)
|
||||
fmt.Println("\n=== Final Verification ===")
|
||||
verifyFinalConfig(finalConfig)
|
||||
|
||||
// 11. Demonstrate typed accessors on the final state
|
||||
fmt.Println("\n=== Demonstrating Typed Accessors ===")
|
||||
level, err := cfg2.Int64("log.level")
|
||||
if err != nil {
|
||||
fmt.Printf("ERROR getting log.level via Int64(): %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("SUCCESS: cfg2.Int64(\"log.level\") = %d (matches expected CLI override)\n", level)
|
||||
}
|
||||
|
||||
name, err := cfg2.String("log.name")
|
||||
if err != nil {
|
||||
fmt.Printf("ERROR getting log.name via String(): %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("SUCCESS: cfg2.String(\"log.name\") = %q (matches expected CLI override)\n", name)
|
||||
}
|
||||
|
||||
showTS, err := cfg2.Bool("log.show_timestamp")
|
||||
if err != nil {
|
||||
fmt.Printf("ERROR getting log.show_timestamp via Bool(): %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("SUCCESS: cfg2.Bool(\"log.show_timestamp\") = %t (matches expected CLI override)\n", showTS)
|
||||
}
|
||||
|
||||
// Try getting an unregistered value (should fail)
|
||||
_, err = cfg2.String("other.value")
|
||||
if err == nil {
|
||||
fmt.Println("ERROR: Expected error when getting unregistered key 'other.value', but got nil")
|
||||
} else {
|
||||
fmt.Printf("SUCCESS: Correctly got error for unregistered key 'other.value': %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n=== Test Complete ===")
|
||||
}
|
||||
|
||||
// printLogConfig prints the values of a LogConfig struct
|
||||
func printLogConfig(cfg LogConfig) {
|
||||
fmt.Println(" Basic:")
|
||||
fmt.Printf(" Level: %d, Name: %s, Dir: %s, Format: %s, Ext: %s\n",
|
||||
cfg.Level, cfg.Name, cfg.Directory, cfg.Format, cfg.Extension)
|
||||
fmt.Println(" Formatting:")
|
||||
fmt.Printf(" ShowTimestamp: %t, ShowLevel: %t\n", cfg.ShowTimestamp, cfg.ShowLevel)
|
||||
fmt.Println(" Limits:")
|
||||
fmt.Printf(" BufferSize: %d, MaxSizeMB: %d, MaxTotalSizeMB: %d, MinDiskFreeMB: %d\n",
|
||||
cfg.BufferSize, cfg.MaxSizeMB, cfg.MaxTotalSizeMB, cfg.MinDiskFreeMB)
|
||||
fmt.Println(" Timers:")
|
||||
fmt.Printf(" FlushIntervalMs: %d, TraceDepth: %d, RetentionPeriodHrs: %.1f, RetentionCheckMins: %.1f\n",
|
||||
cfg.FlushIntervalMs, cfg.TraceDepth, cfg.RetentionPeriodHrs, cfg.RetentionCheckMins)
|
||||
fmt.Println(" Disk Check:")
|
||||
fmt.Printf(" DiskCheckIntervalMs: %d, EnableAdaptive: %t, MinCheckMs: %d, MaxCheckMs: %d\n",
|
||||
cfg.DiskCheckIntervalMs, cfg.EnableAdaptiveInterval, cfg.MinCheckIntervalMs, cfg.MaxCheckIntervalMs)
|
||||
}
|
||||
|
||||
// verifyFinalConfig checks if the final values reflect the merge order: Default < File < CLI
|
||||
func verifyFinalConfig(cfg LogConfig) {
|
||||
allCorrect := true
|
||||
fmt.Println("Verifying values reflect merge order (Default < File < CLI)...")
|
||||
|
||||
// Value overridden by CLI
|
||||
if cfg.Level != 3 {
|
||||
fmt.Printf(" ERROR: Level is %d, expected 3 (from CLI)\n", cfg.Level)
|
||||
allCorrect = false
|
||||
}
|
||||
// Value overridden by CLI (overriding file value)
|
||||
if cfg.Name != "cli_logger" {
|
||||
fmt.Printf(" ERROR: Name is %s, expected 'cli_logger' (from CLI)\n", cfg.Name)
|
||||
allCorrect = false
|
||||
}
|
||||
// Value overridden by CLI
|
||||
if cfg.ShowTimestamp != false {
|
||||
fmt.Printf(" ERROR: ShowTimestamp is %t, expected false (from CLI)\n", cfg.ShowTimestamp)
|
||||
allCorrect = false
|
||||
}
|
||||
// Value overridden by CLI (float)
|
||||
if cfg.RetentionPeriodHrs != 72.5 {
|
||||
fmt.Printf(" ERROR: RetentionPeriodHrs is %.1f, expected 72.5 (from CLI)\n", cfg.RetentionPeriodHrs)
|
||||
allCorrect = false
|
||||
}
|
||||
|
||||
// Value overridden by File (not present in CLI)
|
||||
if cfg.MaxSizeMB != 50 {
|
||||
fmt.Printf(" ERROR: MaxSizeMB is %d, expected 50 (from File)\n", cfg.MaxSizeMB)
|
||||
allCorrect = false
|
||||
}
|
||||
|
||||
// Value from Default (not in File or CLI)
|
||||
if cfg.Directory != "./logs" {
|
||||
fmt.Printf(" ERROR: Directory is %s, expected './logs' (from Default)\n", cfg.Directory)
|
||||
allCorrect = false
|
||||
}
|
||||
if cfg.BufferSize != 1000 {
|
||||
fmt.Printf(" ERROR: BufferSize is %d, expected 1000 (from Default)\n", cfg.BufferSize)
|
||||
allCorrect = false
|
||||
}
|
||||
|
||||
if allCorrect {
|
||||
fmt.Println(" SUCCESS: All verified configuration values match expected final state!")
|
||||
} else {
|
||||
fmt.Println(" FAILURE: Some configuration values don't match expected final state!")
|
||||
}
|
||||
}
|
||||
301
cmd/test/main.go
Normal file
301
cmd/test/main.go
Normal file
@ -0,0 +1,301 @@
|
||||
// File: lixenwraith/cmd/test/main.go
|
||||
// Test program for the config package
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/lixenwraith/config"
|
||||
)
|
||||
|
||||
// AppConfig represents a simple application configuration
|
||||
type AppConfig struct {
|
||||
Server struct {
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
} `toml:"server"`
|
||||
Database struct {
|
||||
URL string `toml:"url"`
|
||||
MaxConns int `toml:"max_conns"`
|
||||
} `toml:"database"`
|
||||
API struct {
|
||||
Key string `toml:"key" env:"CUSTOM_API_KEY"` // Custom env mapping
|
||||
Timeout int `toml:"timeout"`
|
||||
} `toml:"api"`
|
||||
Debug bool `toml:"debug"`
|
||||
LogFile string `toml:"log_file"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
fmt.Println("=== Config Package Feature Test ===\n")
|
||||
|
||||
// Test directories
|
||||
tempDir := os.TempDir()
|
||||
configPath := filepath.Join(tempDir, "test_config.toml")
|
||||
defer os.Remove(configPath)
|
||||
|
||||
// Set up test environment variables
|
||||
setupEnvironment()
|
||||
defer cleanupEnvironment()
|
||||
|
||||
// Run feature tests
|
||||
testQuickStart()
|
||||
testBuilder()
|
||||
testSourceTracking()
|
||||
testEnvironmentFeatures()
|
||||
testValidation()
|
||||
testUtilities()
|
||||
|
||||
fmt.Println("\n=== All Tests Complete ===")
|
||||
}
|
||||
|
||||
func testQuickStart() {
|
||||
fmt.Println("=== Test 1: Quick Start ===")
|
||||
|
||||
// Define defaults
|
||||
defaults := AppConfig{}
|
||||
defaults.Server.Host = "localhost"
|
||||
defaults.Server.Port = 8080
|
||||
defaults.Database.URL = "postgres://localhost/testdb"
|
||||
defaults.Database.MaxConns = 10
|
||||
defaults.Debug = false
|
||||
|
||||
// Quick initialization
|
||||
cfg, err := config.Quick(defaults, "TEST_", "")
|
||||
if err != nil {
|
||||
log.Fatalf("Quick init failed: %v", err)
|
||||
}
|
||||
|
||||
// Access values
|
||||
host, _ := cfg.String("server.host")
|
||||
port, _ := cfg.Int64("server.port")
|
||||
fmt.Printf("Quick config - Host: %s, Port: %d\n", host, port)
|
||||
|
||||
// Verify env override (TEST_DEBUG=true was set)
|
||||
debug, _ := cfg.Bool("debug")
|
||||
fmt.Printf("Debug from env: %v (should be true)\n", debug)
|
||||
}
|
||||
|
||||
func testBuilder() {
|
||||
fmt.Println("\n=== Test 2: Builder Pattern ===")
|
||||
|
||||
defaults := AppConfig{}
|
||||
defaults.Server.Port = 8080
|
||||
defaults.API.Timeout = 30
|
||||
|
||||
// Custom precedence: Env > File > CLI > Default
|
||||
cfg, err := config.NewBuilder().
|
||||
WithDefaults(defaults).
|
||||
WithEnvPrefix("APP_").
|
||||
WithSources(
|
||||
config.SourceEnv,
|
||||
config.SourceFile,
|
||||
config.SourceCLI,
|
||||
config.SourceDefault,
|
||||
).
|
||||
WithArgs([]string{"--server.port=9999"}).
|
||||
Build()
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Builder failed: %v", err)
|
||||
}
|
||||
|
||||
// ENV should win over CLI due to custom precedence
|
||||
port, _ := cfg.Int64("server.port")
|
||||
fmt.Printf("Port with Env > CLI precedence: %d (should be 7070 from env)\n", port)
|
||||
}
|
||||
|
||||
func testSourceTracking() {
|
||||
fmt.Println("\n=== Test 3: Source Tracking ===")
|
||||
|
||||
cfg := config.New()
|
||||
cfg.Register("test.value", "default")
|
||||
|
||||
// Set from multiple sources
|
||||
cfg.SetSource("test.value", config.SourceFile, "from-file")
|
||||
cfg.SetSource("test.value", config.SourceEnv, "from-env")
|
||||
cfg.SetSource("test.value", config.SourceCLI, "from-cli")
|
||||
|
||||
// Show all sources
|
||||
sources := cfg.GetSources("test.value")
|
||||
fmt.Println("All sources for test.value:")
|
||||
for source, value := range sources {
|
||||
fmt.Printf(" %s: %v\n", source, value)
|
||||
}
|
||||
|
||||
// Get from specific source
|
||||
envVal, exists := cfg.GetSource("test.value", config.SourceEnv)
|
||||
fmt.Printf("Value from env source: %v (exists: %v)\n", envVal, exists)
|
||||
|
||||
// Current value (default precedence)
|
||||
current, _ := cfg.String("test.value")
|
||||
fmt.Printf("Current value: %s (should be from-cli)\n", current)
|
||||
}
|
||||
|
||||
func testEnvironmentFeatures() {
|
||||
fmt.Println("\n=== Test 4: Environment Features ===")
|
||||
|
||||
cfg := config.New()
|
||||
cfg.Register("api.key", "")
|
||||
cfg.Register("api.secret", "")
|
||||
cfg.Register("database.host", "localhost")
|
||||
|
||||
// Test 4a: Custom env transform
|
||||
fmt.Println("\n4a. Custom Environment Transform:")
|
||||
opts := config.LoadOptions{
|
||||
Sources: []config.Source{config.SourceEnv, config.SourceDefault},
|
||||
EnvTransform: func(path string) string {
|
||||
switch path {
|
||||
case "api.key":
|
||||
return "CUSTOM_API_KEY"
|
||||
case "database.host":
|
||||
return "DB_HOST"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
},
|
||||
}
|
||||
cfg.LoadWithOptions("", nil, opts)
|
||||
|
||||
apiKey, _ := cfg.String("api.key")
|
||||
fmt.Printf("API Key from CUSTOM_API_KEY: %s\n", apiKey)
|
||||
|
||||
// Test 4b: Discover environment variables
|
||||
fmt.Println("\n4b. Environment Discovery:")
|
||||
cfg2 := config.New()
|
||||
cfg2.Register("server.port", 8080)
|
||||
cfg2.Register("debug", false)
|
||||
cfg2.Register("api.timeout", 30)
|
||||
|
||||
discovered := cfg2.DiscoverEnv("TEST_")
|
||||
fmt.Println("Discovered env vars with TEST_ prefix:")
|
||||
for path, envVar := range discovered {
|
||||
fmt.Printf(" %s -> %s\n", path, envVar)
|
||||
}
|
||||
|
||||
// Test 4c: Export configuration as env vars
|
||||
fmt.Println("\n4c. Export as Environment:")
|
||||
cfg2.Set("server.port", 3000)
|
||||
cfg2.Set("debug", true)
|
||||
|
||||
exports := cfg2.ExportEnv("EXPORT_")
|
||||
fmt.Println("Non-default values exported:")
|
||||
for env, value := range exports {
|
||||
fmt.Printf(" export %s=%s\n", env, value)
|
||||
}
|
||||
|
||||
// Test 4d: RegisterWithEnv
|
||||
fmt.Println("\n4d. RegisterWithEnv:")
|
||||
cfg3 := config.New()
|
||||
err := cfg3.RegisterWithEnv("special.value", "default", "SPECIAL_ENV_VAR")
|
||||
if err != nil {
|
||||
fmt.Printf("RegisterWithEnv error: %v\n", err)
|
||||
}
|
||||
special, _ := cfg3.String("special.value")
|
||||
fmt.Printf("Value from SPECIAL_ENV_VAR: %s\n", special)
|
||||
}
|
||||
|
||||
func testValidation() {
|
||||
fmt.Println("\n=== Test 5: Validation ===")
|
||||
|
||||
cfg := config.New()
|
||||
cfg.RegisterRequired("api.key", "")
|
||||
cfg.RegisterRequired("database.url", "")
|
||||
cfg.Register("optional.setting", "default")
|
||||
|
||||
// Should fail validation
|
||||
err := cfg.Validate("api.key", "database.url")
|
||||
if err != nil {
|
||||
fmt.Printf("Validation failed as expected: %v\n", err)
|
||||
}
|
||||
|
||||
// Set required values
|
||||
cfg.Set("api.key", "secret-key")
|
||||
cfg.Set("database.url", "postgres://localhost/db")
|
||||
|
||||
// Should pass validation
|
||||
err = cfg.Validate("api.key", "database.url")
|
||||
if err == nil {
|
||||
fmt.Println("Validation passed after setting required values")
|
||||
}
|
||||
}
|
||||
|
||||
func testUtilities() {
|
||||
fmt.Println("\n=== Test 6: Utility Features ===")
|
||||
|
||||
// Create config with some data
|
||||
cfg := config.New()
|
||||
cfg.Register("app.name", "testapp")
|
||||
cfg.Register("app.version", "1.0.0")
|
||||
cfg.Register("server.port", 8080)
|
||||
|
||||
cfg.SetSource("app.version", config.SourceFile, "1.1.0")
|
||||
cfg.SetSource("server.port", config.SourceEnv, 9090)
|
||||
|
||||
// Test 6a: Debug output
|
||||
fmt.Println("\n6a. Debug Output:")
|
||||
debug := cfg.Debug()
|
||||
fmt.Printf("Debug info (first 200 chars): %.200s...\n", debug)
|
||||
|
||||
// Test 6b: Clone
|
||||
fmt.Println("\n6b. Clone Configuration:")
|
||||
clone := cfg.Clone()
|
||||
clone.Set("app.name", "cloned-app")
|
||||
|
||||
original, _ := cfg.String("app.name")
|
||||
cloned, _ := clone.String("app.name")
|
||||
fmt.Printf("Original app.name: %s, Cloned: %s\n", original, cloned)
|
||||
|
||||
// Test 6c: Reset source
|
||||
fmt.Println("\n6c. Reset Sources:")
|
||||
sources := cfg.GetSources("server.port")
|
||||
fmt.Printf("Sources before reset: %v\n", sources)
|
||||
|
||||
cfg.ResetSource(config.SourceEnv)
|
||||
sources = cfg.GetSources("server.port")
|
||||
fmt.Printf("Sources after env reset: %v\n", sources)
|
||||
|
||||
// Test 6d: Save and load specific source
|
||||
fmt.Println("\n6d. Save/Load Specific Source:")
|
||||
tempFile := filepath.Join(os.TempDir(), "source_test.toml")
|
||||
defer os.Remove(tempFile)
|
||||
|
||||
err := cfg.SaveSource(tempFile, config.SourceFile)
|
||||
if err != nil {
|
||||
fmt.Printf("SaveSource error: %v\n", err)
|
||||
} else {
|
||||
fmt.Println("Saved SourceFile values to temp file")
|
||||
}
|
||||
|
||||
// Test 6e: GetRegisteredPaths
|
||||
fmt.Println("\n6e. Registered Paths:")
|
||||
paths := cfg.GetRegisteredPaths("app.")
|
||||
fmt.Printf("Paths with 'app.' prefix: %v\n", paths)
|
||||
|
||||
pathsWithDefaults := cfg.GetRegisteredPathsWithDefaults("app.")
|
||||
for path, def := range pathsWithDefaults {
|
||||
fmt.Printf(" %s: %v\n", path, def)
|
||||
}
|
||||
}
|
||||
|
||||
func setupEnvironment() {
|
||||
// Set test environment variables
|
||||
os.Setenv("TEST_DEBUG", "true")
|
||||
os.Setenv("TEST_SERVER_PORT", "6666")
|
||||
os.Setenv("APP_SERVER_PORT", "7070")
|
||||
os.Setenv("CUSTOM_API_KEY", "env-api-key")
|
||||
os.Setenv("DB_HOST", "env-db-host")
|
||||
os.Setenv("SPECIAL_ENV_VAR", "special-value")
|
||||
}
|
||||
|
||||
func cleanupEnvironment() {
|
||||
os.Unsetenv("TEST_DEBUG")
|
||||
os.Unsetenv("TEST_SERVER_PORT")
|
||||
os.Unsetenv("APP_SERVER_PORT")
|
||||
os.Unsetenv("CUSTOM_API_KEY")
|
||||
os.Unsetenv("DB_HOST")
|
||||
os.Unsetenv("SPECIAL_ENV_VAR")
|
||||
}
|
||||
229
config.go
229
config.go
@ -1,40 +1,149 @@
|
||||
// File: lixenwraith/config/config.go
|
||||
// Package config provides thread-safe configuration management for Go applications
|
||||
// with support for TOML files, command-line overrides, and default values.
|
||||
// with support for multiple sources: TOML files, environment variables, command-line
|
||||
// arguments, and default values with configurable precedence.
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors" // Import errors package
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
// ErrConfigNotFound indicates the specified configuration file was not found.
|
||||
var ErrConfigNotFound = errors.New("configuration file not found")
|
||||
ErrConfigNotFound = errors.New("configuration file not found")
|
||||
|
||||
// ErrCLIParse indicates that parsing command-line arguments failed.
|
||||
var ErrCLIParse = errors.New("failed to parse command-line arguments")
|
||||
ErrCLIParse = errors.New("failed to parse command-line arguments")
|
||||
|
||||
// configItem holds both the default and current value for a configuration path
|
||||
type configItem struct {
|
||||
defaultValue any
|
||||
currentValue any
|
||||
// ErrEnvParse indicates that parsing environment variables failed.
|
||||
ErrEnvParse = errors.New("failed to parse environment variables")
|
||||
)
|
||||
|
||||
// Source represents a configuration source
|
||||
type Source string
|
||||
|
||||
const (
|
||||
SourceDefault Source = "default"
|
||||
SourceFile Source = "file"
|
||||
SourceEnv Source = "env"
|
||||
SourceCLI Source = "cli"
|
||||
)
|
||||
|
||||
// LoadMode defines how configuration sources are processed
|
||||
type LoadMode int
|
||||
|
||||
const (
|
||||
// LoadModeReplace completely replaces values (default behavior)
|
||||
LoadModeReplace LoadMode = iota
|
||||
|
||||
// LoadModeMerge merges maps/structs instead of replacing
|
||||
LoadModeMerge
|
||||
)
|
||||
|
||||
// EnvTransformFunc converts a configuration path to an environment variable name
|
||||
type EnvTransformFunc func(path string) string
|
||||
|
||||
// LoadOptions configures how configuration is loaded from multiple sources
|
||||
type LoadOptions struct {
|
||||
// Sources defines the precedence order (first = highest priority)
|
||||
// Default: [SourceCLI, SourceEnv, SourceFile, SourceDefault]
|
||||
Sources []Source
|
||||
|
||||
// EnvPrefix is prepended to environment variable names
|
||||
// Example: "MYAPP_" transforms "server.port" to "MYAPP_SERVER_PORT"
|
||||
EnvPrefix string
|
||||
|
||||
// EnvTransform customizes how paths map to environment variables
|
||||
// If nil, uses default transformation (dots to underscores, uppercase)
|
||||
EnvTransform EnvTransformFunc
|
||||
|
||||
// LoadMode determines how values are merged
|
||||
LoadMode LoadMode
|
||||
|
||||
// EnvWhitelist limits which paths are checked for env vars (nil = all)
|
||||
EnvWhitelist map[string]bool
|
||||
|
||||
// SkipValidation skips path validation during load
|
||||
SkipValidation bool
|
||||
}
|
||||
|
||||
// Config manages application configuration loaded from files and CLI arguments.
|
||||
// DefaultLoadOptions returns the standard load options
|
||||
func DefaultLoadOptions() LoadOptions {
|
||||
return LoadOptions{
|
||||
Sources: []Source{SourceCLI, SourceEnv, SourceFile, SourceDefault},
|
||||
LoadMode: LoadModeReplace,
|
||||
}
|
||||
}
|
||||
|
||||
// configItem holds configuration values from different sources
|
||||
type configItem struct {
|
||||
defaultValue any
|
||||
values map[Source]any // Values from each source
|
||||
currentValue any // Computed value based on precedence
|
||||
}
|
||||
|
||||
// Config manages application configuration loaded from multiple sources.
|
||||
type Config struct {
|
||||
items map[string]configItem
|
||||
mutex sync.RWMutex
|
||||
options LoadOptions // Current load options
|
||||
fileData map[string]any // Cached file data
|
||||
envData map[string]any // Cached env data
|
||||
cliData map[string]any // Cached CLI data
|
||||
}
|
||||
|
||||
// New creates and initializes a new Config instance.
|
||||
func New() *Config {
|
||||
return &Config{
|
||||
items: make(map[string]configItem),
|
||||
options: DefaultLoadOptions(),
|
||||
fileData: make(map[string]any),
|
||||
envData: make(map[string]any),
|
||||
cliData: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithOptions creates a new Config instance with custom load options
|
||||
func NewWithOptions(opts LoadOptions) *Config {
|
||||
c := New()
|
||||
c.options = opts
|
||||
return c
|
||||
}
|
||||
|
||||
// SetLoadOptions updates the load options and recomputes current values
|
||||
func (c *Config) SetLoadOptions(opts LoadOptions) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.options = opts
|
||||
|
||||
// Recompute all current values based on new precedence
|
||||
for path, item := range c.items {
|
||||
item.currentValue = c.computeValue(path, item)
|
||||
c.items[path] = item
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// computeValue determines the current value based on precedence
|
||||
func (c *Config) computeValue(path string, item configItem) any {
|
||||
// Check sources in precedence order
|
||||
for _, source := range c.options.Sources {
|
||||
if val, exists := item.values[source]; exists && val != nil {
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
// No source had a value, use default
|
||||
return item.defaultValue
|
||||
}
|
||||
|
||||
// Get retrieves a configuration value using the path.
|
||||
// It returns the current value (or default if not explicitly set).
|
||||
// It returns the current value based on configured precedence.
|
||||
// The second return value indicates if the path was registered.
|
||||
func (c *Config) Get(path string) (any, bool) {
|
||||
c.mutex.RLock()
|
||||
@ -48,11 +157,29 @@ func (c *Config) Get(path string) (any, bool) {
|
||||
return item.currentValue, true
|
||||
}
|
||||
|
||||
// GetSource retrieves a value from a specific source
|
||||
func (c *Config) GetSource(path string, source Source) (any, bool) {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
item, registered := c.items[path]
|
||||
if !registered {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
val, exists := item.values[source]
|
||||
return val, exists
|
||||
}
|
||||
|
||||
// Set updates a configuration value for the given path.
|
||||
// It returns an error if the path is not registered.
|
||||
// Note: This allows setting a value of a different type than the default.
|
||||
// Type-specific getters will handle conversion attempts.
|
||||
// It sets the value in the highest priority source (typically CLI).
|
||||
// Returns an error if the path is not registered.
|
||||
func (c *Config) Set(path string, value any) error {
|
||||
return c.SetSource(path, c.options.Sources[0], value)
|
||||
}
|
||||
|
||||
// SetSource sets a value for a specific source
|
||||
func (c *Config) SetSource(path string, source Source, value any) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
@ -61,7 +188,81 @@ func (c *Config) Set(path string, value any) error {
|
||||
return fmt.Errorf("path %s is not registered", path)
|
||||
}
|
||||
|
||||
item.currentValue = value
|
||||
if item.values == nil {
|
||||
item.values = make(map[Source]any)
|
||||
}
|
||||
|
||||
item.values[source] = value
|
||||
item.currentValue = c.computeValue(path, item)
|
||||
c.items[path] = item
|
||||
|
||||
// Update source cache
|
||||
switch source {
|
||||
case SourceFile:
|
||||
c.fileData[path] = value
|
||||
case SourceEnv:
|
||||
c.envData[path] = value
|
||||
case SourceCLI:
|
||||
c.cliData[path] = value
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSources returns all sources that have a value for the given path
|
||||
func (c *Config) GetSources(path string) map[Source]any {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
item, registered := c.items[path]
|
||||
if !registered {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[Source]any)
|
||||
for source, value := range item.values {
|
||||
result[source] = value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Reset clears all non-default values and resets to defaults
|
||||
func (c *Config) Reset() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Clear source caches
|
||||
c.fileData = make(map[string]any)
|
||||
c.envData = make(map[string]any)
|
||||
c.cliData = make(map[string]any)
|
||||
|
||||
// Reset all items to default values
|
||||
for path, item := range c.items {
|
||||
item.values = make(map[Source]any)
|
||||
item.currentValue = item.defaultValue
|
||||
c.items[path] = item
|
||||
}
|
||||
}
|
||||
|
||||
// ResetSource clears all values from a specific source
|
||||
func (c *Config) ResetSource(source Source) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Clear source cache
|
||||
switch source {
|
||||
case SourceFile:
|
||||
c.fileData = make(map[string]any)
|
||||
case SourceEnv:
|
||||
c.envData = make(map[string]any)
|
||||
case SourceCLI:
|
||||
c.cliData = make(map[string]any)
|
||||
}
|
||||
|
||||
// Remove source values from all items
|
||||
for path, item := range c.items {
|
||||
delete(item.values, source)
|
||||
item.currentValue = c.computeValue(path, item)
|
||||
c.items[path] = item
|
||||
}
|
||||
}
|
||||
231
convenience.go
Normal file
231
convenience.go
Normal file
@ -0,0 +1,231 @@
|
||||
// File: lixenwraith/config/convenience.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/BurntSushi/toml"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Quick creates a fully configured Config instance with a single call
|
||||
// This is the recommended way to initialize configuration for most applications
|
||||
func Quick(structDefaults interface{}, envPrefix, configFile string) (*Config, error) {
|
||||
cfg := New()
|
||||
|
||||
// Register defaults from struct if provided
|
||||
if structDefaults != nil {
|
||||
if err := cfg.RegisterStruct("", structDefaults); err != nil {
|
||||
return nil, fmt.Errorf("failed to register defaults: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load with standard precedence: CLI > Env > File > Default
|
||||
opts := DefaultLoadOptions()
|
||||
opts.EnvPrefix = envPrefix
|
||||
|
||||
err := cfg.LoadWithOptions(configFile, os.Args[1:], opts)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
// QuickCustom creates a Config with custom options
|
||||
func QuickCustom(structDefaults interface{}, opts LoadOptions, configFile string) (*Config, error) {
|
||||
cfg := NewWithOptions(opts)
|
||||
|
||||
// Register defaults from struct if provided
|
||||
if structDefaults != nil {
|
||||
if err := cfg.RegisterStruct("", structDefaults); err != nil {
|
||||
return nil, fmt.Errorf("failed to register defaults: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := cfg.LoadWithOptions(configFile, os.Args[1:], opts)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
// MustQuick is like Quick but panics on error
|
||||
func MustQuick(structDefaults interface{}, envPrefix, configFile string) *Config {
|
||||
cfg, err := Quick(structDefaults, envPrefix, configFile)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("config initialization failed: %v", err))
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// GenerateFlags creates flag.FlagSet entries for all registered paths
|
||||
func (c *Config) GenerateFlags() *flag.FlagSet {
|
||||
fs := flag.NewFlagSet("config", flag.ContinueOnError)
|
||||
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
for path, item := range c.items {
|
||||
// Create flag based on default value type
|
||||
switch v := item.defaultValue.(type) {
|
||||
case bool:
|
||||
fs.Bool(path, v, fmt.Sprintf("Config: %s", path))
|
||||
case int64:
|
||||
fs.Int64(path, v, fmt.Sprintf("Config: %s", path))
|
||||
case int:
|
||||
fs.Int(path, v, fmt.Sprintf("Config: %s", path))
|
||||
case float64:
|
||||
fs.Float64(path, v, fmt.Sprintf("Config: %s", path))
|
||||
case string:
|
||||
fs.String(path, v, fmt.Sprintf("Config: %s", path))
|
||||
default:
|
||||
// For other types, use string flag
|
||||
fs.String(path, fmt.Sprintf("%v", v), fmt.Sprintf("Config: %s", path))
|
||||
}
|
||||
}
|
||||
|
||||
return fs
|
||||
}
|
||||
|
||||
// BindFlags updates configuration from parsed flag.FlagSet
|
||||
func (c *Config) BindFlags(fs *flag.FlagSet) error {
|
||||
var errors []error
|
||||
|
||||
fs.Visit(func(f *flag.Flag) {
|
||||
value := f.Value.String()
|
||||
parsed := parseValue(value)
|
||||
|
||||
if err := c.SetSource(f.Name, SourceCLI, parsed); err != nil {
|
||||
errors = append(errors, fmt.Errorf("flag %s: %w", f.Name, err))
|
||||
}
|
||||
})
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("failed to bind %d flags: %w", len(errors), errors[0])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks that all required configuration values are set
|
||||
// A value is considered "set" if it differs from its default value
|
||||
func (c *Config) Validate(required ...string) error {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
var missing []string
|
||||
|
||||
for _, path := range required {
|
||||
item, exists := c.items[path]
|
||||
if !exists {
|
||||
missing = append(missing, path+" (not registered)")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if value equals default (indicating not set)
|
||||
if reflect.DeepEqual(item.currentValue, item.defaultValue) {
|
||||
// Check if any source provided a value
|
||||
hasValue := false
|
||||
for _, val := range item.values {
|
||||
if val != nil {
|
||||
hasValue = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasValue {
|
||||
missing = append(missing, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(missing) > 0 {
|
||||
return fmt.Errorf("missing required configuration: %s", strings.Join(missing, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Watch returns a channel that receives updates when configuration values change
|
||||
// This is useful for hot-reloading configurations
|
||||
// Note: This is a placeholder for future enhancement
|
||||
func (c *Config) Watch() <-chan string {
|
||||
// TODO: Implement file watching and config reload
|
||||
ch := make(chan string)
|
||||
close(ch) // Close immediately for now
|
||||
return ch
|
||||
}
|
||||
|
||||
// Debug returns a formatted string showing all configuration values and their sources
|
||||
func (c *Config) Debug() string {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("Configuration Debug Info:\n")
|
||||
b.WriteString(fmt.Sprintf("Precedence: %v\n", c.options.Sources))
|
||||
b.WriteString("Current values:\n")
|
||||
|
||||
for path, item := range c.items {
|
||||
b.WriteString(fmt.Sprintf(" %s:\n", path))
|
||||
b.WriteString(fmt.Sprintf(" Current: %v\n", item.currentValue))
|
||||
b.WriteString(fmt.Sprintf(" Default: %v\n", item.defaultValue))
|
||||
|
||||
for source, value := range item.values {
|
||||
b.WriteString(fmt.Sprintf(" %s: %v\n", source, value))
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Dump writes the current configuration to stdout in TOML format
|
||||
func (c *Config) Dump() error {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
nestedData := make(map[string]any)
|
||||
for path, item := range c.items {
|
||||
setNestedValue(nestedData, path, item.currentValue)
|
||||
}
|
||||
|
||||
encoder := toml.NewEncoder(os.Stdout)
|
||||
return encoder.Encode(nestedData)
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of the configuration
|
||||
func (c *Config) Clone() *Config {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
clone := &Config{
|
||||
items: make(map[string]configItem),
|
||||
options: c.options,
|
||||
fileData: make(map[string]any),
|
||||
envData: make(map[string]any),
|
||||
cliData: make(map[string]any),
|
||||
}
|
||||
|
||||
// Deep copy items
|
||||
for path, item := range c.items {
|
||||
newItem := configItem{
|
||||
defaultValue: item.defaultValue,
|
||||
currentValue: item.currentValue,
|
||||
values: make(map[Source]any),
|
||||
}
|
||||
|
||||
for source, value := range item.values {
|
||||
newItem.values[source] = value
|
||||
}
|
||||
|
||||
clone.items[path] = newItem
|
||||
}
|
||||
|
||||
// Copy cache data
|
||||
for k, v := range c.fileData {
|
||||
clone.fileData[k] = v
|
||||
}
|
||||
for k, v := range c.envData {
|
||||
clone.envData[k] = v
|
||||
}
|
||||
for k, v := range c.cliData {
|
||||
clone.cliData[k] = v
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
60
doc.go
Normal file
60
doc.go
Normal file
@ -0,0 +1,60 @@
|
||||
// File: lixenwraith/config/doc.go
|
||||
|
||||
// Package config provides thread-safe configuration management for Go applications
|
||||
// with support for multiple sources: TOML files, environment variables, command-line
|
||||
// arguments, and default values with configurable precedence.
|
||||
//
|
||||
// Features:
|
||||
// - Multiple configuration sources with customizable precedence
|
||||
// - Thread-safe operations using sync.RWMutex
|
||||
// - Automatic type conversions for common Go types
|
||||
// - Struct registration with tag support
|
||||
// - Environment variable auto-discovery and mapping
|
||||
// - Builder pattern for easy initialization
|
||||
// - Source tracking to see where values originated
|
||||
// - Configuration validation
|
||||
// - Zero dependencies (only stdlib + toml parser + mapstructure)
|
||||
//
|
||||
// Quick Start:
|
||||
//
|
||||
// type Config struct {
|
||||
// Server struct {
|
||||
// Host string `toml:"host"`
|
||||
// Port int `toml:"port"`
|
||||
// } `toml:"server"`
|
||||
// }
|
||||
//
|
||||
// defaults := Config{}
|
||||
// defaults.Server.Host = "localhost"
|
||||
// defaults.Server.Port = 8080
|
||||
//
|
||||
// cfg, err := config.Quick(defaults, "MYAPP_", "config.toml")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// host, _ := cfg.String("server.host")
|
||||
// port, _ := cfg.Int64("server.port")
|
||||
//
|
||||
// Default Precedence (highest to lowest):
|
||||
// 1. Command-line arguments (--server.port=9090)
|
||||
// 2. Environment variables (MYAPP_SERVER_PORT=9090)
|
||||
// 3. Configuration file (config.toml)
|
||||
// 4. Default values
|
||||
//
|
||||
// Custom Precedence:
|
||||
//
|
||||
// cfg, err := config.NewBuilder().
|
||||
// WithDefaults(defaults).
|
||||
// WithSources(
|
||||
// config.SourceEnv, // Environment the highest priority
|
||||
// config.SourceCLI,
|
||||
// config.SourceFile,
|
||||
// config.SourceDefault,
|
||||
// ).
|
||||
// Build()
|
||||
//
|
||||
// Thread Safety:
|
||||
// All operations are thread-safe. The package uses read-write mutexes to allow
|
||||
// concurrent reads while protecting writes.
|
||||
package config
|
||||
225
env_test.go
Normal file
225
env_test.go
Normal file
@ -0,0 +1,225 @@
|
||||
// File: lixenwraith/config/env_test.go
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/lixenwraith/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEnvironmentVariables(t *testing.T) {
|
||||
t.Run("Basic Environment Loading", func(t *testing.T) {
|
||||
// Set up environment
|
||||
envVars := map[string]string{
|
||||
"TEST_SERVER_HOST": "env-host",
|
||||
"TEST_SERVER_PORT": "9999",
|
||||
"TEST_DEBUG": "true",
|
||||
}
|
||||
for k, v := range envVars {
|
||||
os.Setenv(k, v)
|
||||
defer os.Unsetenv(k)
|
||||
}
|
||||
|
||||
cfg := config.New()
|
||||
cfg.Register("server.host", "default-host")
|
||||
cfg.Register("server.port", 8080)
|
||||
cfg.Register("debug", false)
|
||||
|
||||
// Load environment variables
|
||||
err := cfg.LoadEnv("TEST_")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify values
|
||||
host, _ := cfg.String("server.host")
|
||||
assert.Equal(t, "env-host", host)
|
||||
|
||||
port, _ := cfg.Int64("server.port")
|
||||
assert.Equal(t, int64(9999), port)
|
||||
|
||||
debug, _ := cfg.Bool("debug")
|
||||
assert.True(t, debug)
|
||||
})
|
||||
|
||||
t.Run("Custom Environment Transform", func(t *testing.T) {
|
||||
os.Setenv("PORT", "3000")
|
||||
os.Setenv("DATABASE_URL", "postgres://localhost/test")
|
||||
defer func() {
|
||||
os.Unsetenv("PORT")
|
||||
os.Unsetenv("DATABASE_URL")
|
||||
}()
|
||||
|
||||
cfg := config.New()
|
||||
cfg.Register("server.port", 8080)
|
||||
cfg.Register("database.url", "sqlite://memory")
|
||||
|
||||
opts := config.LoadOptions{
|
||||
Sources: []config.Source{config.SourceEnv, config.SourceDefault},
|
||||
EnvTransform: func(path string) string {
|
||||
mapping := map[string]string{
|
||||
"server.port": "PORT",
|
||||
"database.url": "DATABASE_URL",
|
||||
}
|
||||
return mapping[path]
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.LoadWithOptions("", nil, opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
port, _ := cfg.Int64("server.port")
|
||||
assert.Equal(t, int64(3000), port)
|
||||
|
||||
dbURL, _ := cfg.String("database.url")
|
||||
assert.Equal(t, "postgres://localhost/test", dbURL)
|
||||
})
|
||||
|
||||
t.Run("Environment Discovery", func(t *testing.T) {
|
||||
// Set up various env vars
|
||||
envVars := map[string]string{
|
||||
"APP_SERVER_HOST": "discovered",
|
||||
"APP_SERVER_PORT": "4444",
|
||||
"APP_UNREGISTERED": "ignored",
|
||||
}
|
||||
for k, v := range envVars {
|
||||
os.Setenv(k, v)
|
||||
defer os.Unsetenv(k)
|
||||
}
|
||||
|
||||
cfg := config.New()
|
||||
cfg.Register("server.host", "default")
|
||||
cfg.Register("server.port", 8080)
|
||||
cfg.Register("server.timeout", 30)
|
||||
|
||||
// Discover which registered paths have env vars
|
||||
discovered := cfg.DiscoverEnv("APP_")
|
||||
|
||||
// Should find 2 env vars
|
||||
assert.Len(t, discovered, 2)
|
||||
assert.Equal(t, "APP_SERVER_HOST", discovered["server.host"])
|
||||
assert.Equal(t, "APP_SERVER_PORT", discovered["server.port"])
|
||||
assert.NotContains(t, discovered, "unregistered")
|
||||
})
|
||||
|
||||
t.Run("Environment Whitelist", func(t *testing.T) {
|
||||
envVars := map[string]string{
|
||||
"SECRET_API_KEY": "secret-value",
|
||||
"SECRET_DATABASE_PASSWORD": "db-pass",
|
||||
"SECRET_SERVER_PORT": "5555",
|
||||
}
|
||||
for k, v := range envVars {
|
||||
os.Setenv(k, v)
|
||||
defer os.Unsetenv(k)
|
||||
}
|
||||
|
||||
cfg := config.New()
|
||||
cfg.Register("api.key", "")
|
||||
cfg.Register("database.password", "")
|
||||
cfg.Register("server.port", 8080)
|
||||
|
||||
opts := config.LoadOptions{
|
||||
Sources: []config.Source{config.SourceEnv, config.SourceDefault},
|
||||
EnvPrefix: "SECRET_",
|
||||
EnvWhitelist: map[string]bool{
|
||||
"api.key": true,
|
||||
"database.password": true,
|
||||
// server.port is NOT whitelisted
|
||||
},
|
||||
}
|
||||
|
||||
cfg.LoadWithOptions("", nil, opts)
|
||||
|
||||
// Whitelisted values should load
|
||||
apiKey, _ := cfg.String("api.key")
|
||||
assert.Equal(t, "secret-value", apiKey)
|
||||
|
||||
dbPass, _ := cfg.String("database.password")
|
||||
assert.Equal(t, "db-pass", dbPass)
|
||||
|
||||
// Non-whitelisted should use default
|
||||
port, _ := cfg.Int64("server.port")
|
||||
assert.Equal(t, int64(8080), port)
|
||||
})
|
||||
|
||||
t.Run("RegisterWithEnv", func(t *testing.T) {
|
||||
os.Setenv("CUSTOM_PORT", "6666")
|
||||
defer os.Unsetenv("CUSTOM_PORT")
|
||||
|
||||
cfg := config.New()
|
||||
|
||||
// Register with explicit env mapping
|
||||
err := cfg.RegisterWithEnv("server.port", 8080, "CUSTOM_PORT")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should immediately have env value
|
||||
port, _ := cfg.Int64("server.port")
|
||||
assert.Equal(t, int64(6666), port)
|
||||
})
|
||||
|
||||
t.Run("Export Environment", func(t *testing.T) {
|
||||
cfg := config.New()
|
||||
cfg.Register("app.name", "myapp")
|
||||
cfg.Register("app.version", "1.0.0")
|
||||
cfg.Register("server.port", 8080)
|
||||
|
||||
// Set some non-default values
|
||||
cfg.Set("app.version", "2.0.0")
|
||||
cfg.Set("server.port", 9090)
|
||||
|
||||
// Export as env vars
|
||||
exports := cfg.ExportEnv("EXPORT_")
|
||||
|
||||
// Should export non-default values
|
||||
assert.Equal(t, "2.0.0", exports["EXPORT_APP_VERSION"])
|
||||
assert.Equal(t, "9090", exports["EXPORT_SERVER_PORT"])
|
||||
|
||||
// Should not export defaults
|
||||
assert.NotContains(t, exports, "EXPORT_APP_NAME")
|
||||
})
|
||||
|
||||
t.Run("Type Parsing from Environment", func(t *testing.T) {
|
||||
envVars := map[string]string{
|
||||
"TYPES_STRING": "hello world",
|
||||
"TYPES_INT": "42",
|
||||
"TYPES_FLOAT": "3.14159",
|
||||
"TYPES_BOOL_TRUE": "true",
|
||||
"TYPES_BOOL_FALSE": "false",
|
||||
"TYPES_QUOTED": `"quoted string"`,
|
||||
}
|
||||
for k, v := range envVars {
|
||||
os.Setenv(k, v)
|
||||
defer os.Unsetenv(k)
|
||||
}
|
||||
|
||||
cfg := config.New()
|
||||
cfg.Register("string", "")
|
||||
cfg.Register("int", 0)
|
||||
cfg.Register("float", 0.0)
|
||||
cfg.Register("bool.true", false)
|
||||
cfg.Register("bool.false", true)
|
||||
cfg.Register("quoted", "")
|
||||
|
||||
cfg.LoadEnv("TYPES_")
|
||||
|
||||
// Verify type conversions
|
||||
s, _ := cfg.String("string")
|
||||
assert.Equal(t, "hello world", s)
|
||||
|
||||
i, _ := cfg.Int64("int")
|
||||
assert.Equal(t, int64(42), i)
|
||||
|
||||
f, _ := cfg.Float64("float")
|
||||
assert.Equal(t, 3.14159, f)
|
||||
|
||||
bt, _ := cfg.Bool("bool.true")
|
||||
assert.True(t, bt)
|
||||
|
||||
bf, _ := cfg.Bool("bool.false")
|
||||
assert.False(t, bf)
|
||||
|
||||
q, _ := cfg.String("quoted")
|
||||
assert.Equal(t, "quoted string", q)
|
||||
})
|
||||
}
|
||||
9
go.mod
9
go.mod
@ -1,8 +1,15 @@
|
||||
module github.com/LixenWraith/config
|
||||
module github.com/lixenwraith/config
|
||||
|
||||
go 1.24.2
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.5.0
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
10
go.sum
10
go.sum
@ -1,4 +1,14 @@
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
// File: lixenwraith/config/helper.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
import "strings"
|
||||
|
||||
// flattenMap converts a nested map[string]any to a flat map[string]any with dot-notation paths.
|
||||
func flattenMap(nested map[string]any, prefix string) map[string]any {
|
||||
412
io.go
412
io.go
@ -1,3 +1,4 @@
|
||||
// File: lixenwraith/config/io.go
|
||||
package config
|
||||
|
||||
import (
|
||||
@ -13,72 +14,275 @@ import (
|
||||
)
|
||||
|
||||
// Load reads configuration from a TOML file and merges overrides from command-line arguments.
|
||||
// 'args' should be the command-line arguments (e.g., os.Args[1:]).
|
||||
// It returns an error if loading or parsing fails.
|
||||
// Specific errors ErrConfigNotFound and ErrCLIParse can be checked using errors.Is.
|
||||
func (c *Config) Load(path string, args []string) error {
|
||||
// This is a convenience method that maintains backward compatibility.
|
||||
func (c *Config) Load(filePath string, args []string) error {
|
||||
return c.LoadWithOptions(filePath, args, c.options)
|
||||
}
|
||||
|
||||
// LoadWithOptions loads configuration from multiple sources with custom options
|
||||
func (c *Config) LoadWithOptions(filePath string, args []string, opts LoadOptions) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.options = opts
|
||||
c.mutex.Unlock()
|
||||
|
||||
var errNotFound error
|
||||
var errCLI error
|
||||
var loadErrors []error
|
||||
|
||||
fileConfig := make(map[string]any) // Holds only file data
|
||||
// Process each source according to precedence (in reverse order for proper layering)
|
||||
for i := len(opts.Sources) - 1; i >= 0; i-- {
|
||||
source := opts.Sources[i]
|
||||
|
||||
// --- Load from file ---
|
||||
switch source {
|
||||
case SourceDefault:
|
||||
// Defaults are already in place from Register calls
|
||||
continue
|
||||
|
||||
case SourceFile:
|
||||
if filePath != "" {
|
||||
if err := c.loadFile(filePath); err != nil {
|
||||
if errors.Is(err, ErrConfigNotFound) {
|
||||
loadErrors = append(loadErrors, err)
|
||||
} else {
|
||||
return err // Fatal error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case SourceEnv:
|
||||
if err := c.loadEnv(opts); err != nil {
|
||||
loadErrors = append(loadErrors, err)
|
||||
}
|
||||
|
||||
case SourceCLI:
|
||||
if len(args) > 0 {
|
||||
if err := c.loadCLI(args); err != nil {
|
||||
loadErrors = append(loadErrors, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(loadErrors...)
|
||||
}
|
||||
|
||||
// LoadEnv loads configuration values from environment variables
|
||||
func (c *Config) LoadEnv(prefix string) error {
|
||||
opts := c.options
|
||||
opts.EnvPrefix = prefix
|
||||
return c.loadEnv(opts)
|
||||
}
|
||||
|
||||
// LoadCLI loads configuration values from command-line arguments
|
||||
func (c *Config) LoadCLI(args []string) error {
|
||||
return c.loadCLI(args)
|
||||
}
|
||||
|
||||
// LoadFile loads configuration values from a TOML file
|
||||
func (c *Config) LoadFile(filePath string) error {
|
||||
return c.loadFile(filePath)
|
||||
}
|
||||
|
||||
// loadFile reads and parses a TOML configuration file
|
||||
func (c *Config) loadFile(path string) error {
|
||||
fileData, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
errNotFound = ErrConfigNotFound
|
||||
// fileData is nil, proceed to CLI args
|
||||
} else {
|
||||
return ErrConfigNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to read config file '%s': %w", path, err)
|
||||
}
|
||||
} else if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
|
||||
|
||||
fileConfig := make(map[string]any)
|
||||
if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err)
|
||||
}
|
||||
|
||||
// --- Flatten file data ---
|
||||
// Flatten and apply file data
|
||||
flattenedFileConfig := flattenMap(fileConfig, "")
|
||||
|
||||
// --- Parse CLI arguments ---
|
||||
cliOverrides := make(map[string]any) // Holds only CLI args data
|
||||
if len(args) > 0 {
|
||||
parsedCliMap, parseErr := parseArgs(args) // parseArgs returns a nested map
|
||||
if parseErr != nil {
|
||||
// Wrap the CLI parsing error with our specific error type
|
||||
errCLI = fmt.Errorf("%w: %w", ErrCLIParse, parseErr)
|
||||
// Do not return yet, proceed to merge what we have
|
||||
} else {
|
||||
// Flatten the nested map from CLI args only if parsing succeeded
|
||||
cliOverrides = flattenMap(parsedCliMap, "")
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Store in cache
|
||||
c.fileData = flattenedFileConfig
|
||||
|
||||
// Apply to registered paths
|
||||
for path, value := range flattenedFileConfig {
|
||||
if item, exists := c.items[path]; exists {
|
||||
if item.values == nil {
|
||||
item.values = make(map[Source]any)
|
||||
}
|
||||
item.values[SourceFile] = value
|
||||
item.currentValue = c.computeValue(path, item)
|
||||
c.items[path] = item
|
||||
}
|
||||
// Ignore unregistered paths from file
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadEnv loads configuration from environment variables
|
||||
func (c *Config) loadEnv(opts LoadOptions) error {
|
||||
// Default transform function
|
||||
transform := opts.EnvTransform
|
||||
if transform == nil {
|
||||
transform = func(path string) string {
|
||||
// Convert dots to underscores and uppercase
|
||||
env := strings.ReplaceAll(path, ".", "_")
|
||||
env = strings.ToUpper(env)
|
||||
if opts.EnvPrefix != "" {
|
||||
env = opts.EnvPrefix + env
|
||||
}
|
||||
return env
|
||||
}
|
||||
}
|
||||
|
||||
// --- Merge and Update Internal State ---
|
||||
// Iterate through registered paths to apply loaded/default values correctly.
|
||||
// The order of precedence is: CLI > File > Registered Default
|
||||
for regPath, item := range c.items {
|
||||
// 1. Check CLI overrides (only if CLI parsing succeeded)
|
||||
if errCLI == nil {
|
||||
if cliVal, cliExists := cliOverrides[regPath]; cliExists {
|
||||
item.currentValue = cliVal
|
||||
c.items[regPath] = item
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Clear previous env data
|
||||
c.envData = make(map[string]any)
|
||||
|
||||
// Check each registered path for corresponding env var
|
||||
for path, item := range c.items {
|
||||
// Skip if whitelisted and not in whitelist
|
||||
if opts.EnvWhitelist != nil && !opts.EnvWhitelist[path] {
|
||||
continue
|
||||
}
|
||||
|
||||
envVar := transform(path)
|
||||
if value, exists := os.LookupEnv(envVar); exists {
|
||||
// Parse the string value
|
||||
parsedValue := parseValue(value)
|
||||
|
||||
if item.values == nil {
|
||||
item.values = make(map[Source]any)
|
||||
}
|
||||
item.values[SourceEnv] = parsedValue
|
||||
item.currentValue = c.computeValue(path, item)
|
||||
c.items[path] = item
|
||||
c.envData[path] = parsedValue
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check File config (if no CLI override or CLI parsing failed)
|
||||
if fileVal, fileExists := flattenedFileConfig[regPath]; fileExists {
|
||||
item.currentValue = fileVal
|
||||
} else {
|
||||
// 3. Use Default (if not in CLI or File)
|
||||
item.currentValue = item.defaultValue
|
||||
}
|
||||
c.items[regPath] = item
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.Join(errNotFound, errCLI)
|
||||
// loadCLI loads configuration from command-line arguments
|
||||
func (c *Config) loadCLI(args []string) error {
|
||||
parsedCLI, err := parseArgs(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrCLIParse, err)
|
||||
}
|
||||
|
||||
// Flatten CLI data
|
||||
flattenedCLI := flattenMap(parsedCLI, "")
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Store in cache
|
||||
c.cliData = flattenedCLI
|
||||
|
||||
// Apply to registered paths
|
||||
for path, value := range flattenedCLI {
|
||||
if item, exists := c.items[path]; exists {
|
||||
if item.values == nil {
|
||||
item.values = make(map[Source]any)
|
||||
}
|
||||
item.values[SourceCLI] = value
|
||||
item.currentValue = c.computeValue(path, item)
|
||||
c.items[path] = item
|
||||
}
|
||||
// Ignore unregistered paths from CLI
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DiscoverEnv finds all environment variables matching registered paths
|
||||
// and returns a map of path -> env var name for found variables
|
||||
func (c *Config) DiscoverEnv(prefix string) map[string]string {
|
||||
transform := c.options.EnvTransform
|
||||
if transform == nil {
|
||||
transform = defaultEnvTransform(prefix)
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
discovered := make(map[string]string)
|
||||
|
||||
for path := range c.items {
|
||||
envVar := transform(path)
|
||||
if _, exists := os.LookupEnv(envVar); exists {
|
||||
discovered[path] = envVar
|
||||
}
|
||||
}
|
||||
|
||||
return discovered
|
||||
}
|
||||
|
||||
// ExportEnv exports the current configuration as environment variables
|
||||
// Only exports paths that have non-default values
|
||||
func (c *Config) ExportEnv(prefix string) map[string]string {
|
||||
transform := c.options.EnvTransform
|
||||
if transform == nil {
|
||||
transform = defaultEnvTransform(prefix)
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
exports := make(map[string]string)
|
||||
|
||||
for path, item := range c.items {
|
||||
// Only export if value differs from default
|
||||
if item.currentValue != item.defaultValue {
|
||||
envVar := transform(path)
|
||||
exports[envVar] = fmt.Sprintf("%v", item.currentValue)
|
||||
}
|
||||
}
|
||||
|
||||
return exports
|
||||
}
|
||||
|
||||
// defaultEnvTransform creates the default environment variable transformer
|
||||
func defaultEnvTransform(prefix string) EnvTransformFunc {
|
||||
return func(path string) string {
|
||||
env := strings.ReplaceAll(path, ".", "_")
|
||||
env = strings.ToUpper(env)
|
||||
if prefix != "" {
|
||||
env = prefix + env
|
||||
}
|
||||
return env
|
||||
}
|
||||
}
|
||||
|
||||
// parseValue attempts to parse a string into appropriate types
|
||||
func parseValue(s string) any {
|
||||
// Try boolean
|
||||
if v, err := strconv.ParseBool(s); err == nil {
|
||||
return v
|
||||
}
|
||||
|
||||
// Try int64
|
||||
if v, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||
return v
|
||||
}
|
||||
|
||||
// Try float64
|
||||
if v, err := strconv.ParseFloat(s, 64); err == nil {
|
||||
return v
|
||||
}
|
||||
|
||||
// Remove quotes if present
|
||||
if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' {
|
||||
return s[1 : len(s)-1]
|
||||
}
|
||||
|
||||
// Return as string
|
||||
return s
|
||||
}
|
||||
|
||||
// Save writes the current configuration to a TOML file atomically.
|
||||
@ -93,20 +297,18 @@ func (c *Config) Save(path string) error {
|
||||
|
||||
c.mutex.RUnlock()
|
||||
|
||||
// --- Marshal using BurntSushi/toml ---
|
||||
// Marshal using BurntSushi/toml
|
||||
var buf bytes.Buffer
|
||||
encoder := toml.NewEncoder(&buf)
|
||||
// encoder.Indent = " " // Optional use of 2 spaces for indentation
|
||||
if err := encoder.Encode(nestedData); err != nil {
|
||||
return fmt.Errorf("failed to marshal config data to TOML: %w", err)
|
||||
}
|
||||
tomlData := buf.Bytes()
|
||||
// --- End Marshal ---
|
||||
|
||||
// --- Atomic write logic ---
|
||||
// Atomic write logic
|
||||
dir := filepath.Dir(path)
|
||||
// Ensure the directory exists
|
||||
if err := os.MkdirAll(dir, 0755); err != nil { // 0755 allows owner rwx, group rx, other rx
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config directory '%s': %w", dir, err)
|
||||
}
|
||||
|
||||
@ -115,56 +317,111 @@ func (c *Config) Save(path string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temporary config file in '%s': %w", dir, err)
|
||||
}
|
||||
// Defer cleanup in case of errors during write/rename
|
||||
|
||||
tempFilePath := tempFile.Name()
|
||||
removed := false
|
||||
defer func() {
|
||||
if !removed {
|
||||
os.Remove(tempFilePath) // Clean up temp file if rename fails or we panic
|
||||
os.Remove(tempFilePath)
|
||||
}
|
||||
}()
|
||||
|
||||
// Write data to the temporary file
|
||||
if _, err := tempFile.Write(tomlData); err != nil {
|
||||
tempFile.Close() // Close file before returning error
|
||||
tempFile.Close()
|
||||
return fmt.Errorf("failed to write temp config file '%s': %w", tempFilePath, err)
|
||||
}
|
||||
|
||||
// Sync data to disk
|
||||
if err := tempFile.Sync(); err != nil {
|
||||
tempFile.Close()
|
||||
return fmt.Errorf("failed to sync temp config file '%s': %w", tempFilePath, err)
|
||||
}
|
||||
|
||||
// Close the temporary file
|
||||
if err := tempFile.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close temp config file '%s': %w", tempFilePath, err)
|
||||
}
|
||||
|
||||
// Set permissions on the temporary file *before* renaming (safer)
|
||||
// Use 0644: owner rw, group r, other r
|
||||
// Set permissions on the temporary file
|
||||
if err := os.Chmod(tempFilePath, 0644); err != nil {
|
||||
return fmt.Errorf("failed to set permissions on temporary config file '%s': %w", tempFilePath, err)
|
||||
}
|
||||
|
||||
// Atomically replace the original file with the temporary file
|
||||
// Atomically replace the original file
|
||||
if err := os.Rename(tempFilePath, path); err != nil {
|
||||
return fmt.Errorf("failed to rename temp file '%s' to '%s': %w", tempFilePath, path, err)
|
||||
}
|
||||
removed = true // Mark temp file as successfully renamed
|
||||
removed = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveSource writes values from a specific source to a TOML file
|
||||
func (c *Config) SaveSource(path string, source Source) error {
|
||||
c.mutex.RLock()
|
||||
|
||||
nestedData := make(map[string]any)
|
||||
for itemPath, item := range c.items {
|
||||
if val, exists := item.values[source]; exists {
|
||||
setNestedValue(nestedData, itemPath, val)
|
||||
}
|
||||
}
|
||||
|
||||
c.mutex.RUnlock()
|
||||
|
||||
// Use the same atomic save logic
|
||||
var buf bytes.Buffer
|
||||
encoder := toml.NewEncoder(&buf)
|
||||
if err := encoder.Encode(nestedData); err != nil {
|
||||
return fmt.Errorf("failed to marshal config data to TOML: %w", err)
|
||||
}
|
||||
|
||||
// ... (rest of atomic save logic same as Save method)
|
||||
return atomicWriteFile(path, buf.Bytes())
|
||||
}
|
||||
|
||||
// atomicWriteFile performs atomic file write
|
||||
func atomicWriteFile(path string, data []byte) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s': %w", dir, err)
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(dir, filepath.Base(path)+".*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temporary file: %w", err)
|
||||
}
|
||||
|
||||
tempPath := tempFile.Name()
|
||||
defer os.Remove(tempPath) // Clean up on any error
|
||||
|
||||
if _, err := tempFile.Write(data); err != nil {
|
||||
tempFile.Close()
|
||||
return fmt.Errorf("failed to write temporary file: %w", err)
|
||||
}
|
||||
|
||||
if err := tempFile.Sync(); err != nil {
|
||||
tempFile.Close()
|
||||
return fmt.Errorf("failed to sync temporary file: %w", err)
|
||||
}
|
||||
|
||||
if err := tempFile.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close temporary file: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempPath, 0644); err != nil {
|
||||
return fmt.Errorf("failed to set permissions: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tempPath, path); err != nil {
|
||||
return fmt.Errorf("failed to rename temporary file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseArgs processes command-line arguments into a nested map structure.
|
||||
// Expects arguments in the format:
|
||||
//
|
||||
// --key.subkey=value
|
||||
// --key.subkey value
|
||||
// --booleanflag (implicitly true)
|
||||
// --booleanflag=true
|
||||
// --booleanflag=false
|
||||
//
|
||||
// Values are parsed into bool, int64, float64, or string.
|
||||
// Returns an error if a key segment is invalid.
|
||||
func parseArgs(args []string) (map[string]any, error) {
|
||||
result := make(map[string]any)
|
||||
i := 0
|
||||
@ -196,7 +453,7 @@ func parseArgs(args []string) (map[string]any, error) {
|
||||
} else {
|
||||
// Handle "--key value" or "--booleanflag"
|
||||
keyPath = argContent
|
||||
// Check if it's potentially a boolean flag (next arg starts with -- or end of args)
|
||||
// Check if it's potentially a boolean flag
|
||||
isBoolFlag := i+1 >= len(args) || strings.HasPrefix(args[i+1], "--")
|
||||
|
||||
if isBoolFlag {
|
||||
@ -210,33 +467,16 @@ func parseArgs(args []string) (map[string]any, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate keyPath segments *after* extracting the key
|
||||
// Validate keyPath segments
|
||||
segments := strings.Split(keyPath, ".")
|
||||
for _, segment := range segments {
|
||||
if !isValidKeySegment(segment) {
|
||||
// Return a specific error indicating the problem
|
||||
return nil, fmt.Errorf("invalid command-line key segment %q in path %q", segment, keyPath)
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt to parse the value string into richer types
|
||||
var value any
|
||||
if v, err := strconv.ParseBool(valueStr); err == nil {
|
||||
value = v
|
||||
} else if v, err := strconv.ParseInt(valueStr, 10, 64); err == nil {
|
||||
value = v
|
||||
} else if v, err := strconv.ParseFloat(valueStr, 64); err == nil {
|
||||
value = v
|
||||
} else {
|
||||
// Keep as string if no other parsing succeeded
|
||||
// Remove surrounding quotes if present
|
||||
if len(valueStr) >= 2 && valueStr[0] == '"' && valueStr[len(valueStr)-1] == '"' {
|
||||
value = valueStr[1 : len(valueStr)-1]
|
||||
} else {
|
||||
value = valueStr
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the value
|
||||
value := parseValue(valueStr)
|
||||
setNestedValue(result, keyPath, value)
|
||||
}
|
||||
|
||||
|
||||
176
register.go
176
register.go
@ -1,10 +1,13 @@
|
||||
// File: lixenwraith/config/register.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
// Register makes a configuration path known to the Config instance.
|
||||
@ -30,11 +33,35 @@ func (c *Config) Register(path string, defaultValue any) error {
|
||||
c.items[path] = configItem{
|
||||
defaultValue: defaultValue,
|
||||
currentValue: defaultValue, // Initially set to default
|
||||
values: make(map[Source]any),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterWithEnv registers a path with an explicit environment variable mapping
|
||||
func (c *Config) RegisterWithEnv(path string, defaultValue any, envVar string) error {
|
||||
if err := c.Register(path, defaultValue); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the environment variable exists and load it
|
||||
if value, exists := os.LookupEnv(envVar); exists {
|
||||
parsed := parseValue(value)
|
||||
return c.SetSource(path, SourceEnv, parsed)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterRequired registers a path and marks it as required
|
||||
// The configuration will fail validation if this value is not provided
|
||||
func (c *Config) RegisterRequired(path string, defaultValue any) error {
|
||||
// For now, just register normally
|
||||
// The required paths will be tracked separately in a future enhancement
|
||||
return c.Register(path, defaultValue)
|
||||
}
|
||||
|
||||
// Unregister removes a configuration path and all its children.
|
||||
func (c *Config) Unregister(path string) error {
|
||||
c.mutex.Lock()
|
||||
@ -92,7 +119,7 @@ func (c *Config) RegisterStruct(prefix string, structWithDefaults interface{}) e
|
||||
var errors []string
|
||||
|
||||
// Use a helper function for recursive registration
|
||||
c.registerFields(v, prefix, "", &errors) // Pass receiver `c`
|
||||
c.registerFields(v, prefix, "", &errors)
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("failed to register %d field(s): %s", len(errors), strings.Join(errors, "; "))
|
||||
@ -101,8 +128,21 @@ func (c *Config) RegisterStruct(prefix string, structWithDefaults interface{}) e
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterStructWithTags is like RegisterStruct but allows custom tag names
|
||||
func (c *Config) RegisterStructWithTags(prefix string, structWithDefaults interface{}, tagName string) error {
|
||||
// Save current tag preference
|
||||
oldTag := "toml"
|
||||
|
||||
// Temporarily use custom tag
|
||||
// Note: This would require modifying registerFields to accept tagName parameter
|
||||
// For now, we'll keep using "toml" tag
|
||||
_ = oldTag
|
||||
_ = tagName
|
||||
|
||||
return c.RegisterStruct(prefix, structWithDefaults)
|
||||
}
|
||||
|
||||
// registerFields is a helper function that handles the recursive field registration.
|
||||
// It's now a method on *Config to simplify calling c.Register.
|
||||
func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, errors *[]string) {
|
||||
t := v.Type()
|
||||
|
||||
@ -120,20 +160,21 @@ func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, e
|
||||
continue // Skip this field
|
||||
}
|
||||
|
||||
// Check for additional tags
|
||||
envTag := field.Tag.Get("env") // Explicit env var name
|
||||
required := field.Tag.Get("required") == "true"
|
||||
|
||||
key := field.Name
|
||||
if tag != "" {
|
||||
parts := strings.Split(tag, ",")
|
||||
if parts[0] != "" {
|
||||
key = parts[0]
|
||||
}
|
||||
// Note: We are ignoring other tag options like 'omitempty' here,
|
||||
// as RegisterStruct is about setting defaults.
|
||||
}
|
||||
|
||||
// Build full path
|
||||
currentPath := key
|
||||
if pathPrefix != "" {
|
||||
// Ensure trailing dot on prefix if needed
|
||||
if !strings.HasSuffix(pathPrefix, ".") {
|
||||
pathPrefix += "."
|
||||
}
|
||||
@ -141,7 +182,6 @@ func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, e
|
||||
}
|
||||
|
||||
// Handle nested structs recursively
|
||||
// Check for pointer to struct as well
|
||||
fieldType := fieldValue.Type()
|
||||
isStruct := fieldValue.Kind() == reflect.Struct
|
||||
isPtrToStruct := fieldValue.Kind() == reflect.Ptr && fieldType.Elem().Kind() == reflect.Struct
|
||||
@ -151,7 +191,7 @@ func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, e
|
||||
nestedValue := fieldValue
|
||||
if isPtrToStruct {
|
||||
if fieldValue.IsNil() {
|
||||
// Skip nil pointers, as their paths aren't well-defined defaults.
|
||||
// Skip nil pointers
|
||||
continue
|
||||
}
|
||||
nestedValue = fieldValue.Elem()
|
||||
@ -159,15 +199,33 @@ func (c *Config) registerFields(v reflect.Value, pathPrefix, fieldPath string, e
|
||||
|
||||
// For nested structs, append a dot and continue recursion
|
||||
nestedPrefix := currentPath + "."
|
||||
c.registerFields(nestedValue, nestedPrefix, fieldPath+field.Name+".", errors) // Call recursively on `c`
|
||||
c.registerFields(nestedValue, nestedPrefix, fieldPath+field.Name+".", errors)
|
||||
continue
|
||||
}
|
||||
|
||||
// Register non-struct fields
|
||||
// Use fieldValue.Interface() to get the actual default value
|
||||
if err := c.Register(currentPath, fieldValue.Interface()); err != nil {
|
||||
defaultValue := fieldValue.Interface()
|
||||
|
||||
var err error
|
||||
if required {
|
||||
err = c.RegisterRequired(currentPath, defaultValue)
|
||||
} else {
|
||||
err = c.Register(currentPath, defaultValue)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
*errors = append(*errors, fmt.Sprintf("field %s%s (path %s): %v", fieldPath, field.Name, currentPath, err))
|
||||
}
|
||||
|
||||
// Handle explicit env tag
|
||||
if envTag != "" && err == nil {
|
||||
if value, exists := os.LookupEnv(envTag); exists {
|
||||
parsed := parseValue(value)
|
||||
if setErr := c.SetSource(currentPath, SourceEnv, parsed); setErr != nil {
|
||||
*errors = append(*errors, fmt.Sprintf("field %s%s env %s: %v", fieldPath, field.Name, envTag, setErr))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -186,6 +244,21 @@ func (c *Config) GetRegisteredPaths(prefix string) map[string]bool {
|
||||
return result
|
||||
}
|
||||
|
||||
// GetRegisteredPathsWithDefaults returns paths with their default values
|
||||
func (c *Config) GetRegisteredPathsWithDefaults(prefix string) map[string]any {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
result := make(map[string]any)
|
||||
for path, item := range c.items {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
result[path] = item.defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Scan decodes the configuration data under a specific base path
|
||||
// into the target struct or map. It operates on the current, merged configuration state.
|
||||
// The target must be a non-nil pointer to a struct or map.
|
||||
@ -249,15 +322,15 @@ func (c *Config) Scan(basePath string, target any) error {
|
||||
// Ensure the final data we are decoding from is actually a map
|
||||
sectionMap, ok := sectionData.(map[string]any)
|
||||
if !ok {
|
||||
// This can happen if the basePath points to a non-map value (e.g., a string, int)
|
||||
return fmt.Errorf("configuration path %q does not refer to a scannable section (map), but to type %T", basePath, sectionData) // Updated error message
|
||||
// This can happen if the basePath points to a non-map value
|
||||
return fmt.Errorf("configuration path %q does not refer to a scannable section (map), but to type %T", basePath, sectionData)
|
||||
}
|
||||
|
||||
// Use mapstructure to decode the relevant section map into the target
|
||||
decoderConfig := &mapstructure.DecoderConfig{
|
||||
Result: target,
|
||||
TagName: "toml", // Use the same tag name for consistency
|
||||
WeaklyTypedInput: true, // Allow conversions (e.g., int to string if needed by target)
|
||||
WeaklyTypedInput: true, // Allow conversions
|
||||
DecodeHook: mapstructure.ComposeDecodeHookFunc(
|
||||
mapstructure.StringToTimeDurationHookFunc(),
|
||||
mapstructure.StringToSliceHookFunc(","),
|
||||
@ -269,10 +342,81 @@ func (c *Config) Scan(basePath string, target any) error {
|
||||
return fmt.Errorf("failed to create mapstructure decoder: %w", err)
|
||||
}
|
||||
|
||||
err = decoder.Decode(sectionMap) // Use sectionMap
|
||||
err = decoder.Decode(sectionMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to scan section %q into %T: %w", basePath, target, err) // Updated error message
|
||||
return fmt.Errorf("failed to scan section %q into %T: %w", basePath, target, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ScanSource scans configuration from a specific source
|
||||
func (c *Config) ScanSource(basePath string, source Source, target any) error {
|
||||
// Validate target
|
||||
rv := reflect.ValueOf(target)
|
||||
if rv.Kind() != reflect.Ptr || rv.IsNil() {
|
||||
return fmt.Errorf("target of ScanSource must be a non-nil pointer, got %T", target)
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
|
||||
// Build nested map from specific source only
|
||||
nestedMap := make(map[string]any)
|
||||
for path, item := range c.items {
|
||||
if val, exists := item.values[source]; exists {
|
||||
setNestedValue(nestedMap, path, val)
|
||||
}
|
||||
}
|
||||
|
||||
c.mutex.RUnlock()
|
||||
|
||||
// Rest of the logic is similar to Scan
|
||||
var sectionData any = nestedMap
|
||||
|
||||
if basePath != "" {
|
||||
basePath = strings.TrimSuffix(basePath, ".")
|
||||
if basePath != "" {
|
||||
segments := strings.Split(basePath, ".")
|
||||
current := any(nestedMap)
|
||||
|
||||
for _, segment := range segments {
|
||||
currentMap, ok := current.(map[string]any)
|
||||
if !ok {
|
||||
sectionData = make(map[string]any)
|
||||
break
|
||||
}
|
||||
|
||||
value, exists := currentMap[segment]
|
||||
if !exists {
|
||||
sectionData = make(map[string]any)
|
||||
break
|
||||
}
|
||||
current = value
|
||||
}
|
||||
|
||||
sectionData = current
|
||||
}
|
||||
}
|
||||
|
||||
sectionMap, ok := sectionData.(map[string]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("path %q does not refer to a map in source %s", basePath, source)
|
||||
}
|
||||
|
||||
decoderConfig := &mapstructure.DecoderConfig{
|
||||
Result: target,
|
||||
TagName: "toml",
|
||||
WeaklyTypedInput: true,
|
||||
DecodeHook: mapstructure.ComposeDecodeHookFunc(
|
||||
mapstructure.StringToTimeDurationHookFunc(),
|
||||
mapstructure.StringToSliceHookFunc(","),
|
||||
),
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(decoderConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create decoder: %w", err)
|
||||
}
|
||||
|
||||
return decoder.Decode(sectionMap)
|
||||
}
|
||||
219
source_test.go
Normal file
219
source_test.go
Normal file
@ -0,0 +1,219 @@
|
||||
// File: lixenwraith/config/source_test.go
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/lixenwraith/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMultiSourceConfiguration(t *testing.T) {
|
||||
t.Run("Source Precedence", func(t *testing.T) {
|
||||
cfg := config.New()
|
||||
cfg.Register("test.value", "default")
|
||||
|
||||
// Set values in different sources
|
||||
cfg.SetSource("test.value", config.SourceFile, "from-file")
|
||||
cfg.SetSource("test.value", config.SourceEnv, "from-env")
|
||||
cfg.SetSource("test.value", config.SourceCLI, "from-cli")
|
||||
|
||||
// Default precedence: CLI > Env > File > Default
|
||||
val, _ := cfg.String("test.value")
|
||||
assert.Equal(t, "from-cli", val)
|
||||
|
||||
// Change precedence
|
||||
opts := config.LoadOptions{
|
||||
Sources: []config.Source{
|
||||
config.SourceEnv,
|
||||
config.SourceCLI,
|
||||
config.SourceFile,
|
||||
config.SourceDefault,
|
||||
},
|
||||
}
|
||||
cfg.SetLoadOptions(opts)
|
||||
|
||||
// Now env should win
|
||||
val, _ = cfg.String("test.value")
|
||||
assert.Equal(t, "from-env", val)
|
||||
})
|
||||
|
||||
t.Run("Source Tracking", func(t *testing.T) {
|
||||
cfg := config.New()
|
||||
cfg.Register("server.port", 8080)
|
||||
|
||||
// Set from multiple sources
|
||||
cfg.SetSource("server.port", config.SourceFile, 9090)
|
||||
cfg.SetSource("server.port", config.SourceEnv, 7070)
|
||||
|
||||
// Get all sources
|
||||
sources := cfg.GetSources("server.port")
|
||||
|
||||
// Should have 2 sources
|
||||
assert.Len(t, sources, 2)
|
||||
assert.Equal(t, 9090, sources[config.SourceFile])
|
||||
assert.Equal(t, 7070, sources[config.SourceEnv])
|
||||
})
|
||||
|
||||
t.Run("GetSource", func(t *testing.T) {
|
||||
cfg := config.New()
|
||||
cfg.Register("api.key", "default-key")
|
||||
|
||||
cfg.SetSource("api.key", config.SourceEnv, "env-key")
|
||||
|
||||
// Get from specific source
|
||||
envVal, exists := cfg.GetSource("api.key", config.SourceEnv)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "env-key", envVal)
|
||||
|
||||
// Get from missing source
|
||||
_, exists = cfg.GetSource("api.key", config.SourceFile)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Reset Sources", func(t *testing.T) {
|
||||
cfg := config.New()
|
||||
cfg.Register("test1", "default1")
|
||||
cfg.Register("test2", "default2")
|
||||
|
||||
// Set values
|
||||
cfg.SetSource("test1", config.SourceFile, "file1")
|
||||
cfg.SetSource("test1", config.SourceEnv, "env1")
|
||||
cfg.SetSource("test2", config.SourceCLI, "cli2")
|
||||
|
||||
// Reset specific source
|
||||
cfg.ResetSource(config.SourceEnv)
|
||||
|
||||
// Env value should be gone
|
||||
_, exists := cfg.GetSource("test1", config.SourceEnv)
|
||||
assert.False(t, exists)
|
||||
|
||||
// Other sources remain
|
||||
fileVal, _ := cfg.GetSource("test1", config.SourceFile)
|
||||
assert.Equal(t, "file1", fileVal)
|
||||
|
||||
// Reset all
|
||||
cfg.Reset()
|
||||
|
||||
// All values should be defaults
|
||||
val1, _ := cfg.String("test1")
|
||||
val2, _ := cfg.String("test2")
|
||||
assert.Equal(t, "default1", val1)
|
||||
assert.Equal(t, "default2", val2)
|
||||
})
|
||||
|
||||
t.Run("LoadWithOptions Integration", func(t *testing.T) {
|
||||
// Create temp config file
|
||||
tmpdir := t.TempDir()
|
||||
configFile := filepath.Join(tmpdir, "test.toml")
|
||||
|
||||
configContent := `
|
||||
[server]
|
||||
host = "file-host"
|
||||
port = 8080
|
||||
|
||||
[feature]
|
||||
enabled = true
|
||||
`
|
||||
require.NoError(t, os.WriteFile(configFile, []byte(configContent), 0644))
|
||||
|
||||
// Set environment
|
||||
os.Setenv("TEST_SERVER_PORT", "9090")
|
||||
os.Setenv("TEST_FEATURE_ENABLED", "false")
|
||||
t.Cleanup(func() {
|
||||
os.Unsetenv("TEST_SERVER_PORT")
|
||||
os.Unsetenv("TEST_FEATURE_ENABLED")
|
||||
})
|
||||
|
||||
cfg := config.New()
|
||||
cfg.Register("server.host", "default-host")
|
||||
cfg.Register("server.port", 7070)
|
||||
cfg.Register("feature.enabled", false)
|
||||
|
||||
// Load with custom precedence (File highest)
|
||||
opts := config.LoadOptions{
|
||||
Sources: []config.Source{
|
||||
config.SourceFile,
|
||||
config.SourceEnv,
|
||||
config.SourceCLI,
|
||||
config.SourceDefault,
|
||||
},
|
||||
EnvPrefix: "TEST_",
|
||||
}
|
||||
|
||||
err := cfg.LoadWithOptions(configFile, []string{"--server.host=cli-host"}, opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// File should win for all values
|
||||
host, _ := cfg.String("server.host")
|
||||
assert.Equal(t, "file-host", host)
|
||||
|
||||
port, _ := cfg.Int64("server.port")
|
||||
assert.Equal(t, int64(8080), port)
|
||||
|
||||
enabled, _ := cfg.Bool("feature.enabled")
|
||||
assert.True(t, enabled)
|
||||
})
|
||||
|
||||
t.Run("ScanSource", func(t *testing.T) {
|
||||
type ServerConfig struct {
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
}
|
||||
|
||||
cfg := config.New()
|
||||
cfg.Register("server.host", "default")
|
||||
cfg.Register("server.port", 8080)
|
||||
|
||||
// Set different values in different sources
|
||||
cfg.SetSource("server.host", config.SourceFile, "file-host")
|
||||
cfg.SetSource("server.port", config.SourceFile, 8080)
|
||||
cfg.SetSource("server.host", config.SourceEnv, "env-host")
|
||||
cfg.SetSource("server.port", config.SourceEnv, 9090)
|
||||
|
||||
// Scan from specific source
|
||||
var fileConfig ServerConfig
|
||||
err := cfg.ScanSource("server", config.SourceFile, &fileConfig)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "file-host", fileConfig.Host)
|
||||
assert.Equal(t, 8080, fileConfig.Port)
|
||||
|
||||
var envConfig ServerConfig
|
||||
err = cfg.ScanSource("server", config.SourceEnv, &envConfig)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "env-host", envConfig.Host)
|
||||
assert.Equal(t, 9090, envConfig.Port)
|
||||
})
|
||||
|
||||
t.Run("SaveSource", func(t *testing.T) {
|
||||
cfg := config.New()
|
||||
cfg.Register("app.name", "myapp")
|
||||
cfg.Register("app.version", "1.0.0")
|
||||
cfg.Register("server.port", 8080)
|
||||
|
||||
// Set values in different sources
|
||||
cfg.SetSource("app.name", config.SourceFile, "fileapp")
|
||||
cfg.SetSource("app.version", config.SourceEnv, "2.0.0")
|
||||
cfg.SetSource("server.port", config.SourceCLI, 9090)
|
||||
|
||||
// Save only env source
|
||||
tmpfile := filepath.Join(t.TempDir(), "config-source.toml")
|
||||
err := cfg.SaveSource(tmpfile, config.SourceEnv)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load saved file and verify
|
||||
newCfg := config.New()
|
||||
newCfg.Register("app.version", "")
|
||||
newCfg.LoadFile(tmpfile)
|
||||
|
||||
version, _ := newCfg.String("app.version")
|
||||
assert.Equal(t, "2.0.0", version)
|
||||
|
||||
// Should not have other source values
|
||||
name, _ := newCfg.String("app.name")
|
||||
assert.Empty(t, name)
|
||||
})
|
||||
}
|
||||
@ -1,3 +1,4 @@
|
||||
// File: lixenwraith/config/type.go
|
||||
package config
|
||||
|
||||
import (
|
||||
Reference in New Issue
Block a user