e6.0.0 Added file format change and security option support.

This commit is contained in:
2025-08-26 15:07:10 -04:00
parent 3aa2ab30d6
commit 112426b43f
9 changed files with 802 additions and 67 deletions

View File

@ -15,6 +15,8 @@ type Builder struct {
opts LoadOptions opts LoadOptions
defaults any defaults any
tagName string tagName string
fileFormat string
securityOpts *SecurityOptions
prefix string prefix string
file string file string
args []string args []string
@ -50,6 +52,14 @@ func (b *Builder) Build() (*Config, error) {
tagName = "toml" tagName = "toml"
} }
// Set format and security settings
if b.fileFormat != "" {
b.cfg.fileFormat = b.fileFormat
}
if b.securityOpts != nil {
b.cfg.securityOpts = b.securityOpts
}
// 1. Register defaults // 1. Register defaults
// If WithDefaults() was called, it takes precedence. // If WithDefaults() was called, it takes precedence.
// If not, but WithTarget() was called, use the target struct for defaults. // If not, but WithTarget() was called, use the target struct for defaults.
@ -148,6 +158,23 @@ func (b *Builder) WithTagName(tagName string) *Builder {
return b return b
} }
// WithFileFormat sets the expected file format
func (b *Builder) WithFileFormat(format string) *Builder {
switch format {
case "toml", "json", "yaml", "auto":
b.fileFormat = format
default:
b.err = fmt.Errorf("unsupported file format %q", format)
}
return b
}
// WithSecurityOptions sets security options for file loading
func (b *Builder) WithSecurityOptions(opts SecurityOptions) *Builder {
b.securityOpts = &opts
return b
}
// WithPrefix sets the prefix for struct registration // WithPrefix sets the prefix for struct registration
func (b *Builder) WithPrefix(prefix string) *Builder { func (b *Builder) WithPrefix(prefix string) *Builder {
b.prefix = prefix b.prefix = prefix

View File

@ -203,7 +203,8 @@ func TestBuilder(t *testing.T) {
func TestFileDiscovery(t *testing.T) { func TestFileDiscovery(t *testing.T) {
t.Run("DiscoveryWithCLIFlag", func(t *testing.T) { t.Run("DiscoveryWithCLIFlag", func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "custom.conf") // Use .toml extension for TOML content
configFile := filepath.Join(tmpDir, "custom.toml")
os.WriteFile(configFile, []byte(`test = "value"`), 0644) os.WriteFile(configFile, []byte(`test = "value"`), 0644)
opts := DefaultDiscoveryOptions("myapp") opts := DefaultDiscoveryOptions("myapp")
@ -223,6 +224,7 @@ func TestFileDiscovery(t *testing.T) {
assert.Equal(t, "value", val) assert.Equal(t, "value", val)
}) })
// Rest of test cases remain the same...
t.Run("DiscoveryWithEnvVar", func(t *testing.T) { t.Run("DiscoveryWithEnvVar", func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "env.toml") configFile := filepath.Join(tmpDir, "env.toml")

View File

@ -47,19 +47,28 @@ type structCache struct {
mu sync.RWMutex mu sync.RWMutex
} }
// SecurityOptions for enhanced file loading security
type SecurityOptions struct {
PreventPathTraversal bool // Prevent ../ in paths
EnforceFileOwnership bool // Unix only: ensure file owned by current user
MaxFileSize int64 // Maximum config file size (0 = no limit)
}
// Config manages application configuration. It can be used in two primary ways: // Config manages application configuration. It can be used in two primary ways:
// 1. As a dynamic key-value store, accessed via methods like Get(), String(), and Int64() // 1. As a dynamic key-value store, accessed via methods like Get(), String(), and Int64()
// 2. As a source for a type-safe struct, populated via BuildAndScan() or AsStruct() // 2. As a source for a type-safe struct, populated via BuildAndScan() or AsStruct()
type Config struct { type Config struct {
items map[string]configItem items map[string]configItem
tagName string tagName string
mutex sync.RWMutex fileFormat string // Separate from tagName: "toml", "json", "yaml", or "auto"
options LoadOptions // Current load options securityOpts *SecurityOptions
fileData map[string]any // Cached file data mutex sync.RWMutex
envData map[string]any // Cached env data options LoadOptions // Current load options
cliData map[string]any // Cached CLI data fileData map[string]any // Cached file data
version atomic.Int64 envData map[string]any // Cached env data
structCache *structCache cliData map[string]any // Cached CLI data
version atomic.Int64
structCache *structCache
// File watching support // File watching support
watcher *watcher watcher *watcher
@ -69,8 +78,14 @@ type Config struct {
// New creates and initializes a new Config instance. // New creates and initializes a new Config instance.
func New() *Config { func New() *Config {
return &Config{ return &Config{
items: make(map[string]configItem), items: make(map[string]configItem),
tagName: "toml", tagName: "toml",
fileFormat: "auto",
// securityOpts: &SecurityOptions{
// PreventPathTraversal: false,
// EnforceFileOwnership: false,
// MaxFileSize: 0,
// },
options: DefaultLoadOptions(), options: DefaultLoadOptions(),
fileData: make(map[string]any), fileData: make(map[string]any),
envData: make(map[string]any), envData: make(map[string]any),
@ -114,6 +129,30 @@ func (c *Config) computeValue(item configItem) any {
return item.defaultValue return item.defaultValue
} }
// SetFileFormat sets the expected format for configuration files.
// Use "auto" to detect based on file extension.
func (c *Config) SetFileFormat(format string) error {
switch format {
case "toml", "json", "yaml", "auto":
// Valid formats
default:
return fmt.Errorf("unsupported file format %q, must be one of: toml, json, yaml, auto", format)
}
c.mutex.Lock()
defer c.mutex.Unlock()
c.fileFormat = format
return nil
}
// SetSecurityOptions configures security checks for file loading
func (c *Config) SetSecurityOptions(opts SecurityOptions) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.securityOpts = &opts
}
// Get retrieves a configuration value using the path and indicator if the path was registered // Get retrieves a configuration value using the path and indicator if the path was registered
func (c *Config) Get(path string) (any, bool) { func (c *Config) Get(path string) (any, bool) {
c.mutex.RLock() c.mutex.RLock()

View File

@ -2,6 +2,7 @@
package config package config
import ( import (
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
@ -119,6 +120,9 @@ func normalizeMap(data any) (map[string]any, error) {
// getDecodeHook returns the composite decode hook for all type conversions // getDecodeHook returns the composite decode hook for all type conversions
func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc { func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
return mapstructure.ComposeDecodeHookFunc( return mapstructure.ComposeDecodeHookFunc(
// JSON Number handling
jsonNumberHookFunc(),
// Network types // Network types
stringToNetIPHookFunc(), stringToNetIPHookFunc(),
stringToNetIPNetHookFunc(), stringToNetIPNetHookFunc(),
@ -134,6 +138,41 @@ func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
) )
} }
// jsonNumberHookFunc handles json.Number conversion to appropriate numeric types
func jsonNumberHookFunc() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
// Check if source is json.Number
if f != reflect.TypeOf(json.Number("")) {
return data, nil
}
num := data.(json.Number)
// Convert based on target type
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return num.Int64()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// Parse as int64 first, then convert
i, err := num.Int64()
if err != nil {
return nil, err
}
if i < 0 {
return nil, fmt.Errorf("cannot convert negative number to unsigned type")
}
return uint64(i), nil
case reflect.Float32, reflect.Float64:
return num.Float64()
case reflect.String:
return num.String(), nil
default:
// Return as-is for other types
return data, nil
}
}
}
// stringToNetIPHookFunc handles net.IP conversion // stringToNetIPHookFunc handles net.IP conversion
func stringToNetIPHookFunc() mapstructure.DecodeHookFunc { func stringToNetIPHookFunc() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data any) (any, error) { return func(f reflect.Type, t reflect.Type, data any) (any, error) {

418
dynamic_test.go Normal file
View File

@ -0,0 +1,418 @@
// FILE: lixenwraith/config/dynamic_test.go
package config
import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMultiFormatLoading tests loading different config formats
func TestMultiFormatLoading(t *testing.T) {
tmpDir := t.TempDir()
// Create test config in different formats
tomlConfig := `
[server]
host = "toml-host"
port = 8080
[database]
url = "postgres://localhost/toml"
`
jsonConfig := `{
"server": {
"host": "json-host",
"port": 9090
},
"database": {
"url": "postgres://localhost/json"
}
}`
yamlConfig := `
server:
host: yaml-host
port: 7070
database:
url: postgres://localhost/yaml
`
// Write config files
tomlPath := filepath.Join(tmpDir, "config.toml")
jsonPath := filepath.Join(tmpDir, "config.json")
yamlPath := filepath.Join(tmpDir, "config.yaml")
require.NoError(t, os.WriteFile(tomlPath, []byte(tomlConfig), 0644))
require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644))
require.NoError(t, os.WriteFile(yamlPath, []byte(yamlConfig), 0644))
t.Run("AutoDetectFormats", func(t *testing.T) {
cfg := New()
cfg.Register("server.host", "")
cfg.Register("server.port", 0)
cfg.Register("database.url", "")
// Test TOML
cfg.SetFileFormat("auto")
require.NoError(t, cfg.LoadFile(tomlPath))
host, _ := cfg.Get("server.host")
assert.Equal(t, "toml-host", host)
// Test JSON
require.NoError(t, cfg.LoadFile(jsonPath))
host, _ = cfg.Get("server.host")
assert.Equal(t, "json-host", host)
port, _ := cfg.Get("server.port")
// JSON number should be preserved as json.Number but convertible
switch v := port.(type) {
case json.Number:
// Expected for raw value
assert.Equal(t, json.Number("9090"), v)
case int64:
// Expected after decode hook conversion
assert.Equal(t, int64(9090), v)
case float64:
// Alternative conversion
assert.Equal(t, float64(9090), v)
default:
t.Errorf("Unexpected type for port: %T", port)
}
// Test YAML
require.NoError(t, cfg.LoadFile(yamlPath))
host, _ = cfg.Get("server.host")
assert.Equal(t, "yaml-host", host)
})
t.Run("ExplicitFormat", func(t *testing.T) {
cfg := New()
cfg.Register("server.host", "")
// Force JSON parsing on .conf file
confPath := filepath.Join(tmpDir, "config.conf")
require.NoError(t, os.WriteFile(confPath, []byte(jsonConfig), 0644))
cfg.SetFileFormat("json")
require.NoError(t, cfg.LoadFile(confPath))
host, _ := cfg.Get("server.host")
assert.Equal(t, "json-host", host)
})
t.Run("ContentDetection", func(t *testing.T) {
cfg := New()
cfg.Register("server.host", "")
// Ambiguous extension
ambigPath := filepath.Join(tmpDir, "config.conf")
require.NoError(t, os.WriteFile(ambigPath, []byte(yamlConfig), 0644))
cfg.SetFileFormat("auto")
require.NoError(t, cfg.LoadFile(ambigPath))
host, _ := cfg.Get("server.host")
assert.Equal(t, "yaml-host", host)
})
}
// TestDynamicFormatSwitching tests runtime format changes
func TestDynamicFormatSwitching(t *testing.T) {
tmpDir := t.TempDir()
// Create configs in different formats with same structure
configs := map[string]string{
"toml": `value = "from-toml"`,
"json": `{"value": "from-json"}`,
"yaml": `value: from-yaml`,
}
cfg := New()
cfg.Register("value", "default")
for format, content := range configs {
t.Run(format, func(t *testing.T) {
filePath := filepath.Join(tmpDir, "config."+format)
require.NoError(t, os.WriteFile(filePath, []byte(content), 0644))
// Set format and load
require.NoError(t, cfg.SetFileFormat(format))
require.NoError(t, cfg.LoadFile(filePath))
val, _ := cfg.Get("value")
assert.Equal(t, "from-"+format, val)
})
}
}
// TestWatchFileFormatSwitch tests watching different file formats
func TestWatchFileFormatSwitch(t *testing.T) {
tmpDir := t.TempDir()
tomlPath := filepath.Join(tmpDir, "config.toml")
jsonPath := filepath.Join(tmpDir, "config.json")
require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-1"`), 0644))
require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-1"}`), 0644))
cfg := New()
cfg.Register("value", "default")
// Configure fast polling for test
opts := WatchOptions{
PollInterval: testPollInterval, // Fast polling for tests
Debounce: testDebounce, // Short debounce
MaxWatchers: 10,
}
// Start watching TOML
cfg.SetFileFormat("auto")
require.NoError(t, cfg.LoadFile(tomlPath))
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
// Wait for watcher to start
require.Eventually(t, func() bool {
return cfg.IsWatching()
}, 4*testDebounce, 2*SpinWaitInterval)
val, _ := cfg.Get("value")
assert.Equal(t, "toml-1", val)
// Switch to JSON with format hint
require.NoError(t, cfg.WatchFile(jsonPath, "json"))
// Wait for new watcher to start
require.Eventually(t, func() bool {
return cfg.IsWatching()
}, 4*testDebounce, 2*SpinWaitInterval)
// Get watch channel AFTER switching files
changes := cfg.Watch()
val, _ = cfg.Get("value")
assert.Equal(t, "json-1", val)
// Update JSON file
require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-2"}`), 0644))
// Wait for change notification
select {
case path := <-changes:
assert.Equal(t, "value", path)
// Wait a bit for value to be updated
require.Eventually(t, func() bool {
val, _ := cfg.Get("value")
return val == "json-2"
}, testEventuallyTimeout, 2*SpinWaitInterval)
case <-time.After(testWatchTimeout):
t.Error("Timeout waiting for JSON file change")
}
// Update old TOML file - should NOT trigger notification
require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-2"`), 0644))
// Should not receive notification from old file
select {
case <-changes:
t.Error("Should not receive changes from old TOML file")
case <-time.After(testPollWindow):
// Expected - no change notification
}
}
// TestSecurityOptions tests security features
func TestSecurityOptions(t *testing.T) {
tmpDir := t.TempDir()
t.Run("PathTraversal", func(t *testing.T) {
cfg := New()
cfg.SetSecurityOptions(SecurityOptions{
PreventPathTraversal: true,
})
// Test various malicious paths
maliciousPaths := []string{
"../../../etc/passwd",
"./../etc/passwd",
"config/../../../etc/passwd",
filepath.Join("..", "..", "etc", "passwd"),
}
for _, malPath := range maliciousPaths {
err := cfg.LoadFile(malPath)
assert.Error(t, err, "Should reject path: %s", malPath)
assert.Contains(t, err.Error(), "path traversal")
}
// Valid paths should work
validPath := filepath.Join(tmpDir, "config.toml")
os.WriteFile(validPath, []byte(`test = "value"`), 0644)
cfg.Register("test", "")
err := cfg.LoadFile(validPath)
assert.NoError(t, err, "Should accept valid absolute path")
})
t.Run("FileSizeLimit", func(t *testing.T) {
cfg := New()
cfg.SetSecurityOptions(SecurityOptions{
MaxFileSize: 100, // 100 bytes limit
})
// Create large file
largePath := filepath.Join(tmpDir, "large.toml")
largeContent := make([]byte, 1024)
for i := range largeContent {
largeContent[i] = 'a'
}
require.NoError(t, os.WriteFile(largePath, largeContent, 0644))
err := cfg.LoadFile(largePath)
assert.Error(t, err)
assert.Contains(t, err.Error(), "exceeds maximum size")
})
t.Run("FileOwnership", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping ownership test on Windows")
}
cfg := New()
cfg.SetSecurityOptions(SecurityOptions{
EnforceFileOwnership: true,
})
// Create file owned by current user (should succeed)
ownedPath := filepath.Join(tmpDir, "owned.toml")
require.NoError(t, os.WriteFile(ownedPath, []byte(`test = "value"`), 0644))
cfg.Register("test", "")
err := cfg.LoadFile(ownedPath)
assert.NoError(t, err)
})
}
// waitForWatchingState waits for watcher state, preventing race conditions of goroutine start and test check
func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) {
t.Helper()
require.Eventually(t, func() bool {
return cfg.IsWatching() == expected
}, testEventuallyTimeout, 2*SpinWaitInterval, msgAndArgs...)
}
// TestBuilderWithFormat tests Builder integration
func TestBuilderWithFormat(t *testing.T) {
tmpDir := t.TempDir()
jsonPath := filepath.Join(tmpDir, "config.json")
jsonConfig := `{
"server": {
"host": "builder-host",
"port": 8080
}
}`
require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644))
type Config struct {
Server struct {
Host string `json:"host" toml:"host"`
Port int `json:"port" toml:"port"`
} `json:"server" toml:"server"`
}
defaults := &Config{}
defaults.Server.Host = "default-host"
defaults.Server.Port = 3000
cfg, err := NewBuilder().
WithDefaults(defaults).
WithFile(jsonPath).
WithFileFormat("json").
WithTagName("toml"). // Use toml tags for registration
WithSecurityOptions(SecurityOptions{
PreventPathTraversal: true,
MaxFileSize: 1024 * 1024, // 1MB
}).
Build()
require.NoError(t, err)
// Check the value was loaded
host, exists := cfg.Get("server.host")
assert.True(t, exists, "server.host should exist")
assert.Equal(t, "builder-host", host)
port, exists := cfg.Get("server.port")
assert.True(t, exists, "server.port should exist")
// Handle json.Number or converted int
switch v := port.(type) {
case json.Number:
p, _ := v.Int64()
assert.Equal(t, int64(8080), p)
case int64:
assert.Equal(t, int64(8080), v)
case float64:
assert.Equal(t, float64(8080), v)
default:
t.Errorf("Unexpected type for port: %T", port)
}
}
// BenchmarkFormatParsing benchmarks different format parsing speeds
func BenchmarkFormatParsing(b *testing.B) {
tmpDir := b.TempDir()
// Create test data
configs := map[string]string{
"toml": `
[server]
host = "localhost"
port = 8080
[database]
url = "postgres://localhost/db"
[cache]
ttl = 300
`,
"json": `{
"server": {"host": "localhost", "port": 8080},
"database": {"url": "postgres://localhost/db"},
"cache": {"ttl": 300}
}`,
"yaml": `
server:
host: localhost
port: 8080
database:
url: postgres://localhost/db
cache:
ttl: 300
`,
}
for format, content := range configs {
b.Run(format, func(b *testing.B) {
path := filepath.Join(tmpDir, "bench."+format)
os.WriteFile(path, []byte(content), 0644)
cfg := New()
cfg.Register("server.host", "")
cfg.Register("server.port", 0)
cfg.Register("database.url", "")
cfg.Register("cache.ttl", 0)
cfg.SetFileFormat(format)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cfg.LoadFile(path)
}
})
}
}

141
loader.go
View File

@ -3,13 +3,18 @@ package config
import ( import (
"bytes" "bytes"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"syscall"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"gopkg.in/yaml.v3"
) )
// Source represents a configuration source, used to define load precedence // Source represents a configuration source, used to define load precedence
@ -143,18 +148,103 @@ func (c *Config) LoadFile(filePath string) error {
// loadFile reads and parses a TOML configuration file // loadFile reads and parses a TOML configuration file
func (c *Config) loadFile(path string) error { func (c *Config) loadFile(path string) error {
// 1. Read and Parse (No Lock) // Security: Path traversal check
fileData, err := os.ReadFile(path) if c.securityOpts != nil && c.securityOpts.PreventPathTraversal {
// Clean the path and check for traversal attempts
cleanPath := filepath.Clean(path)
// Check if cleaned path tries to go outside current directory
if strings.HasPrefix(cleanPath, ".."+string(filepath.Separator)) || cleanPath == ".." {
return fmt.Errorf("potential path traversal detected in config path: %s", path)
}
// Also check for absolute paths that might escape jail
if filepath.IsAbs(cleanPath) && filepath.IsAbs(path) {
// Absolute paths are OK if that's what was provided
} else if filepath.IsAbs(cleanPath) && !filepath.IsAbs(path) {
// Relative path became absolute after cleaning - suspicious
return fmt.Errorf("potential path traversal detected in config path: %s", path)
}
}
// Read file with size limit
fileInfo, err := os.Stat(path)
if err != nil { if err != nil {
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
return ErrConfigNotFound return ErrConfigNotFound
} }
return fmt.Errorf("failed to stat config file '%s': %w", path, err)
}
// Security: File size check
if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 {
if fileInfo.Size() > c.securityOpts.MaxFileSize {
return fmt.Errorf("config file '%s' exceeds maximum size %d bytes", path, c.securityOpts.MaxFileSize)
}
}
// Security: File ownership check (Unix only)
if c.securityOpts != nil && c.securityOpts.EnforceFileOwnership && runtime.GOOS != "windows" {
if stat, ok := fileInfo.Sys().(*syscall.Stat_t); ok {
if stat.Uid != uint32(os.Geteuid()) {
return fmt.Errorf("config file '%s' is not owned by current user (file UID: %d, process UID: %d)",
path, stat.Uid, os.Geteuid())
}
}
}
// 1. Read and parse file data
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("failed to open config file '%s': %w", path, err)
}
defer file.Close()
// Use LimitedReader for additional safety
var reader io.Reader = file
if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 {
reader = io.LimitReader(file, c.securityOpts.MaxFileSize)
}
fileData, err := io.ReadAll(reader)
if err != nil {
return fmt.Errorf("failed to read config file '%s': %w", path, err) return fmt.Errorf("failed to read config file '%s': %w", path, err)
} }
// Determine format
format := c.fileFormat
if format == "" || format == "auto" {
// Try extension first
format = detectFileFormat(path)
if format == "" {
// Fall back to content detection
format = detectFormatFromContent(fileData)
if format == "" {
// Last resort: use tagName as hint
format = c.tagName
}
}
}
// Parse based on detected/specified format
fileConfig := make(map[string]any) fileConfig := make(map[string]any)
if err := toml.Unmarshal(fileData, &fileConfig); err != nil { switch format {
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err) case "toml":
if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err)
}
case "json":
decoder := json.NewDecoder(bytes.NewReader(fileData))
decoder.UseNumber() // Preserve number precision
if err := decoder.Decode(&fileConfig); err != nil {
return fmt.Errorf("failed to parse JSON config file '%s': %w", path, err)
}
case "yaml":
if err := yaml.Unmarshal(fileData, &fileConfig); err != nil {
return fmt.Errorf("failed to parse YAML config file '%s': %w", path, err)
}
default:
return fmt.Errorf("unable to determine config format for file '%s'", path)
} }
// 2. Prepare New State (Read-Lock Only) // 2. Prepare New State (Read-Lock Only)
@ -185,7 +275,7 @@ func (c *Config) loadFile(path string) error {
} }
apply("", fileConfig) apply("", fileConfig)
// -- 3. Atomically Update Config (Write-Lock) // 3. Atomically Update Config (Write-Lock)
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
@ -578,4 +668,45 @@ func parseArgs(args []string) (map[string]any, error) {
} }
return result, nil return result, nil
}
// detectFileFormat determines format from file extension
func detectFileFormat(path string) string {
ext := strings.ToLower(filepath.Ext(path))
switch ext {
case ".toml", ".tml":
return "toml"
case ".json":
return "json"
case ".yaml", ".yml":
return "yaml"
case ".conf", ".config":
// Try to detect from content
return ""
default:
return ""
}
}
// detectFormatFromContent attempts to detect format by parsing
func detectFormatFromContent(data []byte) string {
// Try JSON first (strict format)
var jsonTest any
if err := json.Unmarshal(data, &jsonTest); err == nil {
return "json"
}
// Try YAML (superset of JSON, so check after JSON)
var yamlTest any
if err := yaml.Unmarshal(data, &yamlTest); err == nil {
return "yaml"
}
// Try TOML last
var tomlTest any
if err := toml.Unmarshal(data, &tomlTest); err == nil {
return "toml"
}
return ""
} }

26
timing.go Normal file
View File

@ -0,0 +1,26 @@
// FILE: lixenwraith/config/timing.go
package config
import "time"
// Core timing constants for production use.
// These define the fundamental timing behavior of the config package.
const (
// File watching intervals (ordered by frequency)
SpinWaitInterval = 5 * time.Millisecond // CPU-friendly busy-wait quantum
MinPollInterval = 100 * time.Millisecond // Hard floor for file stat polling
ShutdownTimeout = 100 * time.Millisecond // Graceful watcher termination window
DefaultDebounce = 500 * time.Millisecond // File change coalescence period
DefaultPollInterval = time.Second // Standard file monitoring frequency
DefaultReloadTimeout = 5 * time.Second // Maximum duration for reload operations
)
// Derived timing relationships for internal use.
// These maintain consistent ratios between related timers.
const (
// shutdownPollCycles defines how many spin-wait cycles comprise a shutdown timeout
shutdownPollCycles = ShutdownTimeout / SpinWaitInterval // = 20 cycles
// debounceSettleMultiplier ensures sufficient time for debounce to complete
debounceSettleMultiplier = 3 // Wait 3x debounce period for value stabilization
)

View File

@ -11,6 +11,8 @@ import (
"time" "time"
) )
const DefaultMaxWatchers = 100 // Prevent resource exhaustion
// WatchOptions configures file watching behavior // WatchOptions configures file watching behavior
type WatchOptions struct { type WatchOptions struct {
// PollInterval for file stat checks (minimum 100ms) // PollInterval for file stat checks (minimum 100ms)
@ -32,10 +34,10 @@ type WatchOptions struct {
// DefaultWatchOptions returns sensible defaults for file watching // DefaultWatchOptions returns sensible defaults for file watching
func DefaultWatchOptions() WatchOptions { func DefaultWatchOptions() WatchOptions {
return WatchOptions{ return WatchOptions{
PollInterval: time.Second, // Check every second PollInterval: DefaultPollInterval,
Debounce: 500 * time.Millisecond, Debounce: DefaultDebounce,
MaxWatchers: 100, // Prevent resource exhaustion MaxWatchers: DefaultMaxWatchers,
ReloadTimeout: 5 * time.Second, ReloadTimeout: DefaultReloadTimeout,
VerifyPermissions: true, VerifyPermissions: true,
} }
} }
@ -71,26 +73,32 @@ func (c *Config) AutoUpdate() {
// AutoUpdateWithOptions enables automatic configuration reloading with custom options // AutoUpdateWithOptions enables automatic configuration reloading with custom options
func (c *Config) AutoUpdateWithOptions(opts WatchOptions) { func (c *Config) AutoUpdateWithOptions(opts WatchOptions) {
// Validate options // Validate options
if opts.PollInterval < 100*time.Millisecond { if opts.PollInterval < MinPollInterval {
opts.PollInterval = 100 * time.Millisecond // Minimum poll interval opts.PollInterval = MinPollInterval
} }
if opts.MaxWatchers <= 0 { if opts.MaxWatchers <= 0 {
opts.MaxWatchers = 100 opts.MaxWatchers = 100
} }
if opts.ReloadTimeout <= 0 { if opts.ReloadTimeout <= 0 {
opts.ReloadTimeout = 5 * time.Second opts.ReloadTimeout = DefaultReloadTimeout
} }
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
// Check if we have a file to watch // Get path of current file to watch
filePath := c.getConfigFilePath() filePath := c.getConfigFilePath()
if filePath == "" { if filePath == "" {
// No file configured, nothing to watch // No file configured, nothing to watch
return return
} }
// Stop existing watcher if path changed
if c.watcher != nil && c.watcher.filePath != filePath {
c.watcher.stop()
c.watcher = nil
}
// Initialize watcher if needed // Initialize watcher if needed
if c.watcher == nil { if c.watcher == nil {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -131,17 +139,24 @@ func (c *Config) Watch() <-chan string {
} }
// WatchFile stops any existing file watcher, loads a new configuration file, // WatchFile stops any existing file watcher, loads a new configuration file,
// and starts a new watcher on that file path. // and starts a new watcher on that file path. Optionally accepts format hint.
func (c *Config) WatchFile(filePath string) error { func (c *Config) WatchFile(filePath string, formatHint ...string) error {
// Stop any currently running watcher to prevent orphaned goroutines. // Stop any currently running watcher
c.StopAutoUpdate() c.StopAutoUpdate()
// Load the new file and set `configFilePath` to the new path // Set format hint if provided
if len(formatHint) > 0 {
if err := c.SetFileFormat(formatHint[0]); err != nil {
return fmt.Errorf("invalid format hint: %w", err)
}
}
// Load the new file
if err := c.LoadFile(filePath); err != nil { if err := c.LoadFile(filePath); err != nil {
return fmt.Errorf("failed to load new file for watching: %w", err) return fmt.Errorf("failed to load new file for watching: %w", err)
} }
// Start a new watcher on the new file // Get previous watcher options if available
c.mutex.RLock() c.mutex.RLock()
opts := DefaultWatchOptions() opts := DefaultWatchOptions()
if c.watcher != nil { if c.watcher != nil {
@ -149,18 +164,36 @@ func (c *Config) WatchFile(filePath string) error {
} }
c.mutex.RUnlock() c.mutex.RUnlock()
// Start new watcher (AutoUpdateWithOptions will create a new watcher with the new file path)
c.AutoUpdateWithOptions(opts) c.AutoUpdateWithOptions(opts)
return nil return nil
} }
// WatchWithOptions returns a channel with custom watch options // WatchWithOptions returns a channel with custom watch options
// should not restart the watcher if it's already running with the same file
func (c *Config) WatchWithOptions(opts WatchOptions) <-chan string { func (c *Config) WatchWithOptions(opts WatchOptions) <-chan string {
c.mutex.RLock()
watcher := c.watcher
filePath := c.configFilePath
c.mutex.RUnlock()
// If no file configured, return closed channel
if filePath == "" {
ch := make(chan string)
close(ch)
return ch
}
// If watcher exists and is watching the current file, just subscribe
if watcher != nil && watcher.filePath == filePath && watcher.watching.Load() {
return watcher.subscribe()
}
// First ensure auto-update is running // First ensure auto-update is running
c.AutoUpdateWithOptions(opts) c.AutoUpdateWithOptions(opts)
c.mutex.RLock() c.mutex.RLock()
watcher := c.watcher watcher = c.watcher
c.mutex.RUnlock() c.mutex.RUnlock()
if watcher == nil { if watcher == nil {
@ -363,18 +396,22 @@ func (w *watcher) notifyWatchers(path string) {
// stop terminates the watcher // stop terminates the watcher
func (w *watcher) stop() { func (w *watcher) stop() {
w.cancel() if w.cancel != nil {
w.cancel()
}
// Stop debounce timer // Stop debounce timer
w.mu.Lock() w.mu.Lock()
if w.debounceTimer != nil { if w.debounceTimer != nil {
w.debounceTimer.Stop() w.debounceTimer.Stop()
w.debounceTimer = nil
} }
w.mu.Unlock() w.mu.Unlock()
// Wait for watch loop to exit // Wait for watch loop to exit with timeout
for w.watching.Load() { deadline := time.Now().Add(ShutdownTimeout)
time.Sleep(10 * time.Millisecond) for w.watching.Load() && time.Now().Before(deadline) {
time.Sleep(SpinWaitInterval)
} }
} }

View File

@ -14,6 +14,29 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// Test-specific timing constants derived from production values.
// These accelerate test execution while maintaining timing relationships.
const (
// testAcceleration reduces all intervals by this factor for faster tests
testAcceleration = 10
// Accelerated test timings
testPollInterval = DefaultPollInterval / testAcceleration // 100ms (from 1s)
testDebounce = DefaultDebounce / testAcceleration // 50ms (from 500ms)
testReloadTimeout = DefaultReloadTimeout / testAcceleration // 500ms (from 5s)
testShutdownTimeout = ShutdownTimeout // Keep original for safety
testSpinWaitInterval = SpinWaitInterval // Keep original for CPU efficiency
// Test assertion timeouts
testEventuallyTimeout = testReloadTimeout // Aligns with reload timing
testWatchTimeout = 2 * DefaultPollInterval // 2s for change propagation
// Derived test multipliers with clear purpose
testDebounceSettle = debounceSettleMultiplier * testDebounce // 150ms for debounce verification
testPollWindow = 3 * testPollInterval // 300ms change detection window
testStateStabilize = 4 * testDebounce // 200ms for state convergence
)
// TestAutoUpdate tests automatic configuration reloading // TestAutoUpdate tests automatic configuration reloading
func TestAutoUpdate(t *testing.T) { func TestAutoUpdate(t *testing.T) {
// Setup // Setup
@ -59,8 +82,8 @@ enabled = true
// Enable auto-update with fast polling // Enable auto-update with fast polling
opts := WatchOptions{ opts := WatchOptions{
PollInterval: 100 * time.Millisecond, PollInterval: testPollInterval,
Debounce: 50 * time.Millisecond, Debounce: testDebounce,
MaxWatchers: 10, MaxWatchers: 10,
} }
cfg.AutoUpdateWithOptions(opts) cfg.AutoUpdateWithOptions(opts)
@ -93,7 +116,7 @@ enabled = false
require.NoError(t, os.WriteFile(configPath, []byte(updatedConfig), 0644)) require.NoError(t, os.WriteFile(configPath, []byte(updatedConfig), 0644))
// Wait for changes to be detected // Wait for changes to be detected
time.Sleep(300 * time.Millisecond) time.Sleep(testPollWindow)
// Verify new values // Verify new values
port, _ = cfg.Get("server.port") port, _ = cfg.Get("server.port")
@ -130,8 +153,8 @@ func TestWatchFileDeleted(t *testing.T) {
// Enable watching // Enable watching
opts := WatchOptions{ opts := WatchOptions{
PollInterval: 100 * time.Millisecond, PollInterval: testPollInterval,
Debounce: 50 * time.Millisecond, Debounce: testDebounce,
} }
cfg.AutoUpdateWithOptions(opts) cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate() defer cfg.StopAutoUpdate()
@ -145,7 +168,7 @@ func TestWatchFileDeleted(t *testing.T) {
select { select {
case path := <-changes: case path := <-changes:
assert.Equal(t, "file_deleted", path) assert.Equal(t, "file_deleted", path)
case <-time.After(500 * time.Millisecond): case <-time.After(testEventuallyTimeout):
t.Error("Timeout waiting for deletion notification") t.Error("Timeout waiting for deletion notification")
} }
} }
@ -169,8 +192,8 @@ func TestWatchPermissionChange(t *testing.T) {
// Enable watching with permission verification // Enable watching with permission verification
opts := WatchOptions{ opts := WatchOptions{
PollInterval: 100 * time.Millisecond, PollInterval: testPollInterval,
Debounce: 50 * time.Millisecond, Debounce: testDebounce,
VerifyPermissions: true, VerifyPermissions: true,
} }
cfg.AutoUpdateWithOptions(opts) cfg.AutoUpdateWithOptions(opts)
@ -185,7 +208,7 @@ func TestWatchPermissionChange(t *testing.T) {
select { select {
case path := <-changes: case path := <-changes:
assert.Equal(t, "permissions_changed", path) assert.Equal(t, "permissions_changed", path)
case <-time.After(500 * time.Millisecond): case <-time.After(testEventuallyTimeout):
t.Error("Timeout waiting for permission change notification") t.Error("Timeout waiting for permission change notification")
} }
} }
@ -203,7 +226,7 @@ func TestMaxWatchers(t *testing.T) {
// Enable watching with low max watchers // Enable watching with low max watchers
opts := WatchOptions{ opts := WatchOptions{
PollInterval: 100 * time.Millisecond, PollInterval: testPollInterval,
MaxWatchers: 3, MaxWatchers: 3,
} }
cfg.AutoUpdateWithOptions(opts) cfg.AutoUpdateWithOptions(opts)
@ -229,7 +252,7 @@ func TestMaxWatchers(t *testing.T) {
select { select {
case _, ok := <-ch: case _, ok := <-ch:
assert.False(t, ok, "Channel 3 should be closed (max watchers exceeded)") assert.False(t, ok, "Channel 3 should be closed (max watchers exceeded)")
case <-time.After(10 * time.Millisecond): case <-time.After(testEventuallyTimeout):
t.Error("Channel 3 should be closed immediately") t.Error("Channel 3 should be closed immediately")
} }
} }
@ -239,8 +262,8 @@ func TestMaxWatchers(t *testing.T) {
assert.Equal(t, 3, cfg.WatcherCount()) assert.Equal(t, 3, cfg.WatcherCount())
} }
// TestDebounce tests that rapid changes are debounced // TestRapidDebounce tests that rapid changes are debounced
func TestDebounce(t *testing.T) { func TestRapidDebounce(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test.toml") configPath := filepath.Join(tmpDir, "test.toml")
@ -253,8 +276,8 @@ func TestDebounce(t *testing.T) {
// Enable watching with longer debounce // Enable watching with longer debounce
opts := WatchOptions{ opts := WatchOptions{
PollInterval: 50 * time.Millisecond, PollInterval: testDebounce,
Debounce: 200 * time.Millisecond, Debounce: testStateStabilize,
} }
cfg.AutoUpdateWithOptions(opts) cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate() defer cfg.StopAutoUpdate()
@ -282,11 +305,11 @@ func TestDebounce(t *testing.T) {
for i := 2; i <= 5; i++ { for i := 2; i <= 5; i++ {
content := fmt.Sprintf(`value = %d`, i) content := fmt.Sprintf(`value = %d`, i)
require.NoError(t, os.WriteFile(configPath, []byte(content), 0644)) require.NoError(t, os.WriteFile(configPath, []byte(content), 0644))
time.Sleep(50 * time.Millisecond) // Less than debounce period time.Sleep(testDebounce) // Less than debounce period
} }
// Wait for debounce to complete // Wait for debounce to complete
time.Sleep(300 * time.Millisecond) time.Sleep(2 * testStateStabilize)
done <- true done <- true
// Should only see one change due to debounce // Should only see one change due to debounce
@ -328,7 +351,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
require.NoError(t, cfg.LoadFile(configPath)) require.NoError(t, cfg.LoadFile(configPath))
opts := WatchOptions{ opts := WatchOptions{
PollInterval: 50 * time.Millisecond, PollInterval: testDebounce,
MaxWatchers: 50, MaxWatchers: 50,
} }
cfg.AutoUpdateWithOptions(opts) cfg.AutoUpdateWithOptions(opts)
@ -353,7 +376,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
select { select {
case <-ch: case <-ch:
// OK, got a change // OK, got a change
case <-time.After(10 * time.Millisecond): case <-time.After(2 * SpinWaitInterval):
// OK, no changes yet // OK, no changes yet
} }
}(i) }(i)
@ -384,7 +407,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
isWatching = true isWatching = true
break break
} }
time.Sleep(10 * time.Millisecond) time.Sleep(2 * SpinWaitInterval)
} }
if !isWatching { if !isWatching {
errors <- fmt.Errorf("checker %d: IsWatching returned false", id) errors <- fmt.Errorf("checker %d: IsWatching returned false", id)
@ -418,8 +441,8 @@ func TestReloadTimeout(t *testing.T) {
// Very short timeout // Very short timeout
opts := WatchOptions{ opts := WatchOptions{
PollInterval: 100 * time.Millisecond, PollInterval: testPollInterval,
ReloadTimeout: 1 * time.Nanosecond, // Extremely short ReloadTimeout: 1 * time.Nanosecond,
} }
cfg.AutoUpdateWithOptions(opts) cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate() defer cfg.StopAutoUpdate()
@ -454,7 +477,7 @@ func TestStopAutoUpdate(t *testing.T) {
select { select {
case _, ok := <-ch: case _, ok := <-ch:
assert.False(t, ok, "Channel should be closed after stop") assert.False(t, ok, "Channel should be closed after stop")
case <-time.After(100 * time.Millisecond): case <-time.After(ShutdownTimeout):
// OK, channel might not close immediately // OK, channel might not close immediately
} }
@ -484,7 +507,7 @@ func BenchmarkWatchOverhead(b *testing.B) {
// Enable watching // Enable watching
opts := WatchOptions{ opts := WatchOptions{
PollInterval: 100 * time.Millisecond, PollInterval: testPollInterval,
} }
cfg.AutoUpdateWithOptions(opts) cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate() defer cfg.StopAutoUpdate()
@ -494,11 +517,4 @@ func BenchmarkWatchOverhead(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, _ = cfg.Get(fmt.Sprintf("value%d", i%100)) _, _ = cfg.Get(fmt.Sprintf("value%d", i%100))
} }
}
// helper function to wait for watcher state, preventing race conditions of goroutine start and test check
func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) {
require.Eventually(t, func() bool {
return cfg.IsWatching() == expected
}, 200*time.Millisecond, 10*time.Millisecond, msgAndArgs...)
} }