e6.0.0 Added file format change and security option support.

This commit is contained in:
2025-08-26 15:07:10 -04:00
parent 3aa2ab30d6
commit 112426b43f
9 changed files with 802 additions and 67 deletions

View File

@ -15,6 +15,8 @@ type Builder struct {
opts LoadOptions
defaults any
tagName string
fileFormat string
securityOpts *SecurityOptions
prefix string
file string
args []string
@ -50,6 +52,14 @@ func (b *Builder) Build() (*Config, error) {
tagName = "toml"
}
// Set format and security settings
if b.fileFormat != "" {
b.cfg.fileFormat = b.fileFormat
}
if b.securityOpts != nil {
b.cfg.securityOpts = b.securityOpts
}
// 1. Register defaults
// If WithDefaults() was called, it takes precedence.
// If not, but WithTarget() was called, use the target struct for defaults.
@ -148,6 +158,23 @@ func (b *Builder) WithTagName(tagName string) *Builder {
return b
}
// WithFileFormat sets the expected file format
func (b *Builder) WithFileFormat(format string) *Builder {
switch format {
case "toml", "json", "yaml", "auto":
b.fileFormat = format
default:
b.err = fmt.Errorf("unsupported file format %q", format)
}
return b
}
// WithSecurityOptions sets security options for file loading
func (b *Builder) WithSecurityOptions(opts SecurityOptions) *Builder {
b.securityOpts = &opts
return b
}
// WithPrefix sets the prefix for struct registration
func (b *Builder) WithPrefix(prefix string) *Builder {
b.prefix = prefix

View File

@ -203,7 +203,8 @@ func TestBuilder(t *testing.T) {
func TestFileDiscovery(t *testing.T) {
t.Run("DiscoveryWithCLIFlag", func(t *testing.T) {
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "custom.conf")
// Use .toml extension for TOML content
configFile := filepath.Join(tmpDir, "custom.toml")
os.WriteFile(configFile, []byte(`test = "value"`), 0644)
opts := DefaultDiscoveryOptions("myapp")
@ -223,6 +224,7 @@ func TestFileDiscovery(t *testing.T) {
assert.Equal(t, "value", val)
})
// Rest of test cases remain the same...
t.Run("DiscoveryWithEnvVar", func(t *testing.T) {
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "env.toml")

View File

@ -47,19 +47,28 @@ type structCache struct {
mu sync.RWMutex
}
// SecurityOptions for enhanced file loading security
type SecurityOptions struct {
PreventPathTraversal bool // Prevent ../ in paths
EnforceFileOwnership bool // Unix only: ensure file owned by current user
MaxFileSize int64 // Maximum config file size (0 = no limit)
}
// Config manages application configuration. It can be used in two primary ways:
// 1. As a dynamic key-value store, accessed via methods like Get(), String(), and Int64()
// 2. As a source for a type-safe struct, populated via BuildAndScan() or AsStruct()
type Config struct {
items map[string]configItem
tagName string
mutex sync.RWMutex
options LoadOptions // Current load options
fileData map[string]any // Cached file data
envData map[string]any // Cached env data
cliData map[string]any // Cached CLI data
version atomic.Int64
structCache *structCache
items map[string]configItem
tagName string
fileFormat string // Separate from tagName: "toml", "json", "yaml", or "auto"
securityOpts *SecurityOptions
mutex sync.RWMutex
options LoadOptions // Current load options
fileData map[string]any // Cached file data
envData map[string]any // Cached env data
cliData map[string]any // Cached CLI data
version atomic.Int64
structCache *structCache
// File watching support
watcher *watcher
@ -69,8 +78,14 @@ type Config struct {
// New creates and initializes a new Config instance.
func New() *Config {
return &Config{
items: make(map[string]configItem),
tagName: "toml",
items: make(map[string]configItem),
tagName: "toml",
fileFormat: "auto",
// securityOpts: &SecurityOptions{
// PreventPathTraversal: false,
// EnforceFileOwnership: false,
// MaxFileSize: 0,
// },
options: DefaultLoadOptions(),
fileData: make(map[string]any),
envData: make(map[string]any),
@ -114,6 +129,30 @@ func (c *Config) computeValue(item configItem) any {
return item.defaultValue
}
// SetFileFormat sets the expected format for configuration files.
// Use "auto" to detect based on file extension.
func (c *Config) SetFileFormat(format string) error {
switch format {
case "toml", "json", "yaml", "auto":
// Valid formats
default:
return fmt.Errorf("unsupported file format %q, must be one of: toml, json, yaml, auto", format)
}
c.mutex.Lock()
defer c.mutex.Unlock()
c.fileFormat = format
return nil
}
// SetSecurityOptions configures security checks for file loading
func (c *Config) SetSecurityOptions(opts SecurityOptions) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.securityOpts = &opts
}
// Get retrieves a configuration value using the path and indicator if the path was registered
func (c *Config) Get(path string) (any, bool) {
c.mutex.RLock()

View File

@ -2,6 +2,7 @@
package config
import (
"encoding/json"
"fmt"
"net"
"net/url"
@ -119,6 +120,9 @@ func normalizeMap(data any) (map[string]any, error) {
// getDecodeHook returns the composite decode hook for all type conversions
func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
return mapstructure.ComposeDecodeHookFunc(
// JSON Number handling
jsonNumberHookFunc(),
// Network types
stringToNetIPHookFunc(),
stringToNetIPNetHookFunc(),
@ -134,6 +138,41 @@ func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
)
}
// jsonNumberHookFunc handles json.Number conversion to appropriate numeric types
func jsonNumberHookFunc() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
// Check if source is json.Number
if f != reflect.TypeOf(json.Number("")) {
return data, nil
}
num := data.(json.Number)
// Convert based on target type
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return num.Int64()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// Parse as int64 first, then convert
i, err := num.Int64()
if err != nil {
return nil, err
}
if i < 0 {
return nil, fmt.Errorf("cannot convert negative number to unsigned type")
}
return uint64(i), nil
case reflect.Float32, reflect.Float64:
return num.Float64()
case reflect.String:
return num.String(), nil
default:
// Return as-is for other types
return data, nil
}
}
}
// stringToNetIPHookFunc handles net.IP conversion
func stringToNetIPHookFunc() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data any) (any, error) {

418
dynamic_test.go Normal file
View File

@ -0,0 +1,418 @@
// FILE: lixenwraith/config/dynamic_test.go
package config
import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMultiFormatLoading tests loading different config formats
func TestMultiFormatLoading(t *testing.T) {
tmpDir := t.TempDir()
// Create test config in different formats
tomlConfig := `
[server]
host = "toml-host"
port = 8080
[database]
url = "postgres://localhost/toml"
`
jsonConfig := `{
"server": {
"host": "json-host",
"port": 9090
},
"database": {
"url": "postgres://localhost/json"
}
}`
yamlConfig := `
server:
host: yaml-host
port: 7070
database:
url: postgres://localhost/yaml
`
// Write config files
tomlPath := filepath.Join(tmpDir, "config.toml")
jsonPath := filepath.Join(tmpDir, "config.json")
yamlPath := filepath.Join(tmpDir, "config.yaml")
require.NoError(t, os.WriteFile(tomlPath, []byte(tomlConfig), 0644))
require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644))
require.NoError(t, os.WriteFile(yamlPath, []byte(yamlConfig), 0644))
t.Run("AutoDetectFormats", func(t *testing.T) {
cfg := New()
cfg.Register("server.host", "")
cfg.Register("server.port", 0)
cfg.Register("database.url", "")
// Test TOML
cfg.SetFileFormat("auto")
require.NoError(t, cfg.LoadFile(tomlPath))
host, _ := cfg.Get("server.host")
assert.Equal(t, "toml-host", host)
// Test JSON
require.NoError(t, cfg.LoadFile(jsonPath))
host, _ = cfg.Get("server.host")
assert.Equal(t, "json-host", host)
port, _ := cfg.Get("server.port")
// JSON number should be preserved as json.Number but convertible
switch v := port.(type) {
case json.Number:
// Expected for raw value
assert.Equal(t, json.Number("9090"), v)
case int64:
// Expected after decode hook conversion
assert.Equal(t, int64(9090), v)
case float64:
// Alternative conversion
assert.Equal(t, float64(9090), v)
default:
t.Errorf("Unexpected type for port: %T", port)
}
// Test YAML
require.NoError(t, cfg.LoadFile(yamlPath))
host, _ = cfg.Get("server.host")
assert.Equal(t, "yaml-host", host)
})
t.Run("ExplicitFormat", func(t *testing.T) {
cfg := New()
cfg.Register("server.host", "")
// Force JSON parsing on .conf file
confPath := filepath.Join(tmpDir, "config.conf")
require.NoError(t, os.WriteFile(confPath, []byte(jsonConfig), 0644))
cfg.SetFileFormat("json")
require.NoError(t, cfg.LoadFile(confPath))
host, _ := cfg.Get("server.host")
assert.Equal(t, "json-host", host)
})
t.Run("ContentDetection", func(t *testing.T) {
cfg := New()
cfg.Register("server.host", "")
// Ambiguous extension
ambigPath := filepath.Join(tmpDir, "config.conf")
require.NoError(t, os.WriteFile(ambigPath, []byte(yamlConfig), 0644))
cfg.SetFileFormat("auto")
require.NoError(t, cfg.LoadFile(ambigPath))
host, _ := cfg.Get("server.host")
assert.Equal(t, "yaml-host", host)
})
}
// TestDynamicFormatSwitching tests runtime format changes
func TestDynamicFormatSwitching(t *testing.T) {
tmpDir := t.TempDir()
// Create configs in different formats with same structure
configs := map[string]string{
"toml": `value = "from-toml"`,
"json": `{"value": "from-json"}`,
"yaml": `value: from-yaml`,
}
cfg := New()
cfg.Register("value", "default")
for format, content := range configs {
t.Run(format, func(t *testing.T) {
filePath := filepath.Join(tmpDir, "config."+format)
require.NoError(t, os.WriteFile(filePath, []byte(content), 0644))
// Set format and load
require.NoError(t, cfg.SetFileFormat(format))
require.NoError(t, cfg.LoadFile(filePath))
val, _ := cfg.Get("value")
assert.Equal(t, "from-"+format, val)
})
}
}
// TestWatchFileFormatSwitch tests watching different file formats
func TestWatchFileFormatSwitch(t *testing.T) {
tmpDir := t.TempDir()
tomlPath := filepath.Join(tmpDir, "config.toml")
jsonPath := filepath.Join(tmpDir, "config.json")
require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-1"`), 0644))
require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-1"}`), 0644))
cfg := New()
cfg.Register("value", "default")
// Configure fast polling for test
opts := WatchOptions{
PollInterval: testPollInterval, // Fast polling for tests
Debounce: testDebounce, // Short debounce
MaxWatchers: 10,
}
// Start watching TOML
cfg.SetFileFormat("auto")
require.NoError(t, cfg.LoadFile(tomlPath))
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
// Wait for watcher to start
require.Eventually(t, func() bool {
return cfg.IsWatching()
}, 4*testDebounce, 2*SpinWaitInterval)
val, _ := cfg.Get("value")
assert.Equal(t, "toml-1", val)
// Switch to JSON with format hint
require.NoError(t, cfg.WatchFile(jsonPath, "json"))
// Wait for new watcher to start
require.Eventually(t, func() bool {
return cfg.IsWatching()
}, 4*testDebounce, 2*SpinWaitInterval)
// Get watch channel AFTER switching files
changes := cfg.Watch()
val, _ = cfg.Get("value")
assert.Equal(t, "json-1", val)
// Update JSON file
require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-2"}`), 0644))
// Wait for change notification
select {
case path := <-changes:
assert.Equal(t, "value", path)
// Wait a bit for value to be updated
require.Eventually(t, func() bool {
val, _ := cfg.Get("value")
return val == "json-2"
}, testEventuallyTimeout, 2*SpinWaitInterval)
case <-time.After(testWatchTimeout):
t.Error("Timeout waiting for JSON file change")
}
// Update old TOML file - should NOT trigger notification
require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-2"`), 0644))
// Should not receive notification from old file
select {
case <-changes:
t.Error("Should not receive changes from old TOML file")
case <-time.After(testPollWindow):
// Expected - no change notification
}
}
// TestSecurityOptions tests security features
func TestSecurityOptions(t *testing.T) {
tmpDir := t.TempDir()
t.Run("PathTraversal", func(t *testing.T) {
cfg := New()
cfg.SetSecurityOptions(SecurityOptions{
PreventPathTraversal: true,
})
// Test various malicious paths
maliciousPaths := []string{
"../../../etc/passwd",
"./../etc/passwd",
"config/../../../etc/passwd",
filepath.Join("..", "..", "etc", "passwd"),
}
for _, malPath := range maliciousPaths {
err := cfg.LoadFile(malPath)
assert.Error(t, err, "Should reject path: %s", malPath)
assert.Contains(t, err.Error(), "path traversal")
}
// Valid paths should work
validPath := filepath.Join(tmpDir, "config.toml")
os.WriteFile(validPath, []byte(`test = "value"`), 0644)
cfg.Register("test", "")
err := cfg.LoadFile(validPath)
assert.NoError(t, err, "Should accept valid absolute path")
})
t.Run("FileSizeLimit", func(t *testing.T) {
cfg := New()
cfg.SetSecurityOptions(SecurityOptions{
MaxFileSize: 100, // 100 bytes limit
})
// Create large file
largePath := filepath.Join(tmpDir, "large.toml")
largeContent := make([]byte, 1024)
for i := range largeContent {
largeContent[i] = 'a'
}
require.NoError(t, os.WriteFile(largePath, largeContent, 0644))
err := cfg.LoadFile(largePath)
assert.Error(t, err)
assert.Contains(t, err.Error(), "exceeds maximum size")
})
t.Run("FileOwnership", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping ownership test on Windows")
}
cfg := New()
cfg.SetSecurityOptions(SecurityOptions{
EnforceFileOwnership: true,
})
// Create file owned by current user (should succeed)
ownedPath := filepath.Join(tmpDir, "owned.toml")
require.NoError(t, os.WriteFile(ownedPath, []byte(`test = "value"`), 0644))
cfg.Register("test", "")
err := cfg.LoadFile(ownedPath)
assert.NoError(t, err)
})
}
// waitForWatchingState waits for watcher state, preventing race conditions of goroutine start and test check
func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) {
t.Helper()
require.Eventually(t, func() bool {
return cfg.IsWatching() == expected
}, testEventuallyTimeout, 2*SpinWaitInterval, msgAndArgs...)
}
// TestBuilderWithFormat tests Builder integration
func TestBuilderWithFormat(t *testing.T) {
tmpDir := t.TempDir()
jsonPath := filepath.Join(tmpDir, "config.json")
jsonConfig := `{
"server": {
"host": "builder-host",
"port": 8080
}
}`
require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644))
type Config struct {
Server struct {
Host string `json:"host" toml:"host"`
Port int `json:"port" toml:"port"`
} `json:"server" toml:"server"`
}
defaults := &Config{}
defaults.Server.Host = "default-host"
defaults.Server.Port = 3000
cfg, err := NewBuilder().
WithDefaults(defaults).
WithFile(jsonPath).
WithFileFormat("json").
WithTagName("toml"). // Use toml tags for registration
WithSecurityOptions(SecurityOptions{
PreventPathTraversal: true,
MaxFileSize: 1024 * 1024, // 1MB
}).
Build()
require.NoError(t, err)
// Check the value was loaded
host, exists := cfg.Get("server.host")
assert.True(t, exists, "server.host should exist")
assert.Equal(t, "builder-host", host)
port, exists := cfg.Get("server.port")
assert.True(t, exists, "server.port should exist")
// Handle json.Number or converted int
switch v := port.(type) {
case json.Number:
p, _ := v.Int64()
assert.Equal(t, int64(8080), p)
case int64:
assert.Equal(t, int64(8080), v)
case float64:
assert.Equal(t, float64(8080), v)
default:
t.Errorf("Unexpected type for port: %T", port)
}
}
// BenchmarkFormatParsing benchmarks different format parsing speeds
func BenchmarkFormatParsing(b *testing.B) {
tmpDir := b.TempDir()
// Create test data
configs := map[string]string{
"toml": `
[server]
host = "localhost"
port = 8080
[database]
url = "postgres://localhost/db"
[cache]
ttl = 300
`,
"json": `{
"server": {"host": "localhost", "port": 8080},
"database": {"url": "postgres://localhost/db"},
"cache": {"ttl": 300}
}`,
"yaml": `
server:
host: localhost
port: 8080
database:
url: postgres://localhost/db
cache:
ttl: 300
`,
}
for format, content := range configs {
b.Run(format, func(b *testing.B) {
path := filepath.Join(tmpDir, "bench."+format)
os.WriteFile(path, []byte(content), 0644)
cfg := New()
cfg.Register("server.host", "")
cfg.Register("server.port", 0)
cfg.Register("database.url", "")
cfg.Register("cache.ttl", 0)
cfg.SetFileFormat(format)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cfg.LoadFile(path)
}
})
}
}

141
loader.go
View File

@ -3,13 +3,18 @@ package config
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"syscall"
"github.com/BurntSushi/toml"
"gopkg.in/yaml.v3"
)
// Source represents a configuration source, used to define load precedence
@ -143,18 +148,103 @@ func (c *Config) LoadFile(filePath string) error {
// loadFile reads and parses a TOML configuration file
func (c *Config) loadFile(path string) error {
// 1. Read and Parse (No Lock)
fileData, err := os.ReadFile(path)
// Security: Path traversal check
if c.securityOpts != nil && c.securityOpts.PreventPathTraversal {
// Clean the path and check for traversal attempts
cleanPath := filepath.Clean(path)
// Check if cleaned path tries to go outside current directory
if strings.HasPrefix(cleanPath, ".."+string(filepath.Separator)) || cleanPath == ".." {
return fmt.Errorf("potential path traversal detected in config path: %s", path)
}
// Also check for absolute paths that might escape jail
if filepath.IsAbs(cleanPath) && filepath.IsAbs(path) {
// Absolute paths are OK if that's what was provided
} else if filepath.IsAbs(cleanPath) && !filepath.IsAbs(path) {
// Relative path became absolute after cleaning - suspicious
return fmt.Errorf("potential path traversal detected in config path: %s", path)
}
}
// Read file with size limit
fileInfo, err := os.Stat(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return ErrConfigNotFound
}
return fmt.Errorf("failed to stat config file '%s': %w", path, err)
}
// Security: File size check
if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 {
if fileInfo.Size() > c.securityOpts.MaxFileSize {
return fmt.Errorf("config file '%s' exceeds maximum size %d bytes", path, c.securityOpts.MaxFileSize)
}
}
// Security: File ownership check (Unix only)
if c.securityOpts != nil && c.securityOpts.EnforceFileOwnership && runtime.GOOS != "windows" {
if stat, ok := fileInfo.Sys().(*syscall.Stat_t); ok {
if stat.Uid != uint32(os.Geteuid()) {
return fmt.Errorf("config file '%s' is not owned by current user (file UID: %d, process UID: %d)",
path, stat.Uid, os.Geteuid())
}
}
}
// 1. Read and parse file data
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("failed to open config file '%s': %w", path, err)
}
defer file.Close()
// Use LimitedReader for additional safety
var reader io.Reader = file
if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 {
reader = io.LimitReader(file, c.securityOpts.MaxFileSize)
}
fileData, err := io.ReadAll(reader)
if err != nil {
return fmt.Errorf("failed to read config file '%s': %w", path, err)
}
// Determine format
format := c.fileFormat
if format == "" || format == "auto" {
// Try extension first
format = detectFileFormat(path)
if format == "" {
// Fall back to content detection
format = detectFormatFromContent(fileData)
if format == "" {
// Last resort: use tagName as hint
format = c.tagName
}
}
}
// Parse based on detected/specified format
fileConfig := make(map[string]any)
if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err)
switch format {
case "toml":
if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err)
}
case "json":
decoder := json.NewDecoder(bytes.NewReader(fileData))
decoder.UseNumber() // Preserve number precision
if err := decoder.Decode(&fileConfig); err != nil {
return fmt.Errorf("failed to parse JSON config file '%s': %w", path, err)
}
case "yaml":
if err := yaml.Unmarshal(fileData, &fileConfig); err != nil {
return fmt.Errorf("failed to parse YAML config file '%s': %w", path, err)
}
default:
return fmt.Errorf("unable to determine config format for file '%s'", path)
}
// 2. Prepare New State (Read-Lock Only)
@ -185,7 +275,7 @@ func (c *Config) loadFile(path string) error {
}
apply("", fileConfig)
// -- 3. Atomically Update Config (Write-Lock)
// 3. Atomically Update Config (Write-Lock)
c.mutex.Lock()
defer c.mutex.Unlock()
@ -578,4 +668,45 @@ func parseArgs(args []string) (map[string]any, error) {
}
return result, nil
}
// detectFileFormat determines format from file extension
func detectFileFormat(path string) string {
ext := strings.ToLower(filepath.Ext(path))
switch ext {
case ".toml", ".tml":
return "toml"
case ".json":
return "json"
case ".yaml", ".yml":
return "yaml"
case ".conf", ".config":
// Try to detect from content
return ""
default:
return ""
}
}
// detectFormatFromContent attempts to detect format by parsing
func detectFormatFromContent(data []byte) string {
// Try JSON first (strict format)
var jsonTest any
if err := json.Unmarshal(data, &jsonTest); err == nil {
return "json"
}
// Try YAML (superset of JSON, so check after JSON)
var yamlTest any
if err := yaml.Unmarshal(data, &yamlTest); err == nil {
return "yaml"
}
// Try TOML last
var tomlTest any
if err := toml.Unmarshal(data, &tomlTest); err == nil {
return "toml"
}
return ""
}

26
timing.go Normal file
View File

@ -0,0 +1,26 @@
// FILE: lixenwraith/config/timing.go
package config
import "time"
// Core timing constants for production use.
// These define the fundamental timing behavior of the config package.
const (
// File watching intervals (ordered by frequency)
SpinWaitInterval = 5 * time.Millisecond // CPU-friendly busy-wait quantum
MinPollInterval = 100 * time.Millisecond // Hard floor for file stat polling
ShutdownTimeout = 100 * time.Millisecond // Graceful watcher termination window
DefaultDebounce = 500 * time.Millisecond // File change coalescence period
DefaultPollInterval = time.Second // Standard file monitoring frequency
DefaultReloadTimeout = 5 * time.Second // Maximum duration for reload operations
)
// Derived timing relationships for internal use.
// These maintain consistent ratios between related timers.
const (
// shutdownPollCycles defines how many spin-wait cycles comprise a shutdown timeout
shutdownPollCycles = ShutdownTimeout / SpinWaitInterval // = 20 cycles
// debounceSettleMultiplier ensures sufficient time for debounce to complete
debounceSettleMultiplier = 3 // Wait 3x debounce period for value stabilization
)

View File

@ -11,6 +11,8 @@ import (
"time"
)
const DefaultMaxWatchers = 100 // Prevent resource exhaustion
// WatchOptions configures file watching behavior
type WatchOptions struct {
// PollInterval for file stat checks (minimum 100ms)
@ -32,10 +34,10 @@ type WatchOptions struct {
// DefaultWatchOptions returns sensible defaults for file watching
func DefaultWatchOptions() WatchOptions {
return WatchOptions{
PollInterval: time.Second, // Check every second
Debounce: 500 * time.Millisecond,
MaxWatchers: 100, // Prevent resource exhaustion
ReloadTimeout: 5 * time.Second,
PollInterval: DefaultPollInterval,
Debounce: DefaultDebounce,
MaxWatchers: DefaultMaxWatchers,
ReloadTimeout: DefaultReloadTimeout,
VerifyPermissions: true,
}
}
@ -71,26 +73,32 @@ func (c *Config) AutoUpdate() {
// AutoUpdateWithOptions enables automatic configuration reloading with custom options
func (c *Config) AutoUpdateWithOptions(opts WatchOptions) {
// Validate options
if opts.PollInterval < 100*time.Millisecond {
opts.PollInterval = 100 * time.Millisecond // Minimum poll interval
if opts.PollInterval < MinPollInterval {
opts.PollInterval = MinPollInterval
}
if opts.MaxWatchers <= 0 {
opts.MaxWatchers = 100
}
if opts.ReloadTimeout <= 0 {
opts.ReloadTimeout = 5 * time.Second
opts.ReloadTimeout = DefaultReloadTimeout
}
c.mutex.Lock()
defer c.mutex.Unlock()
// Check if we have a file to watch
// Get path of current file to watch
filePath := c.getConfigFilePath()
if filePath == "" {
// No file configured, nothing to watch
return
}
// Stop existing watcher if path changed
if c.watcher != nil && c.watcher.filePath != filePath {
c.watcher.stop()
c.watcher = nil
}
// Initialize watcher if needed
if c.watcher == nil {
ctx, cancel := context.WithCancel(context.Background())
@ -131,17 +139,24 @@ func (c *Config) Watch() <-chan string {
}
// WatchFile stops any existing file watcher, loads a new configuration file,
// and starts a new watcher on that file path.
func (c *Config) WatchFile(filePath string) error {
// Stop any currently running watcher to prevent orphaned goroutines.
// and starts a new watcher on that file path. Optionally accepts format hint.
func (c *Config) WatchFile(filePath string, formatHint ...string) error {
// Stop any currently running watcher
c.StopAutoUpdate()
// Load the new file and set `configFilePath` to the new path
// Set format hint if provided
if len(formatHint) > 0 {
if err := c.SetFileFormat(formatHint[0]); err != nil {
return fmt.Errorf("invalid format hint: %w", err)
}
}
// Load the new file
if err := c.LoadFile(filePath); err != nil {
return fmt.Errorf("failed to load new file for watching: %w", err)
}
// Start a new watcher on the new file
// Get previous watcher options if available
c.mutex.RLock()
opts := DefaultWatchOptions()
if c.watcher != nil {
@ -149,18 +164,36 @@ func (c *Config) WatchFile(filePath string) error {
}
c.mutex.RUnlock()
// Start new watcher (AutoUpdateWithOptions will create a new watcher with the new file path)
c.AutoUpdateWithOptions(opts)
return nil
}
// WatchWithOptions returns a channel with custom watch options
// should not restart the watcher if it's already running with the same file
func (c *Config) WatchWithOptions(opts WatchOptions) <-chan string {
c.mutex.RLock()
watcher := c.watcher
filePath := c.configFilePath
c.mutex.RUnlock()
// If no file configured, return closed channel
if filePath == "" {
ch := make(chan string)
close(ch)
return ch
}
// If watcher exists and is watching the current file, just subscribe
if watcher != nil && watcher.filePath == filePath && watcher.watching.Load() {
return watcher.subscribe()
}
// First ensure auto-update is running
c.AutoUpdateWithOptions(opts)
c.mutex.RLock()
watcher := c.watcher
watcher = c.watcher
c.mutex.RUnlock()
if watcher == nil {
@ -363,18 +396,22 @@ func (w *watcher) notifyWatchers(path string) {
// stop terminates the watcher
func (w *watcher) stop() {
w.cancel()
if w.cancel != nil {
w.cancel()
}
// Stop debounce timer
w.mu.Lock()
if w.debounceTimer != nil {
w.debounceTimer.Stop()
w.debounceTimer = nil
}
w.mu.Unlock()
// Wait for watch loop to exit
for w.watching.Load() {
time.Sleep(10 * time.Millisecond)
// Wait for watch loop to exit with timeout
deadline := time.Now().Add(ShutdownTimeout)
for w.watching.Load() && time.Now().Before(deadline) {
time.Sleep(SpinWaitInterval)
}
}

View File

@ -14,6 +14,29 @@ import (
"github.com/stretchr/testify/require"
)
// Test-specific timing constants derived from production values.
// These accelerate test execution while maintaining timing relationships.
const (
// testAcceleration reduces all intervals by this factor for faster tests
testAcceleration = 10
// Accelerated test timings
testPollInterval = DefaultPollInterval / testAcceleration // 100ms (from 1s)
testDebounce = DefaultDebounce / testAcceleration // 50ms (from 500ms)
testReloadTimeout = DefaultReloadTimeout / testAcceleration // 500ms (from 5s)
testShutdownTimeout = ShutdownTimeout // Keep original for safety
testSpinWaitInterval = SpinWaitInterval // Keep original for CPU efficiency
// Test assertion timeouts
testEventuallyTimeout = testReloadTimeout // Aligns with reload timing
testWatchTimeout = 2 * DefaultPollInterval // 2s for change propagation
// Derived test multipliers with clear purpose
testDebounceSettle = debounceSettleMultiplier * testDebounce // 150ms for debounce verification
testPollWindow = 3 * testPollInterval // 300ms change detection window
testStateStabilize = 4 * testDebounce // 200ms for state convergence
)
// TestAutoUpdate tests automatic configuration reloading
func TestAutoUpdate(t *testing.T) {
// Setup
@ -59,8 +82,8 @@ enabled = true
// Enable auto-update with fast polling
opts := WatchOptions{
PollInterval: 100 * time.Millisecond,
Debounce: 50 * time.Millisecond,
PollInterval: testPollInterval,
Debounce: testDebounce,
MaxWatchers: 10,
}
cfg.AutoUpdateWithOptions(opts)
@ -93,7 +116,7 @@ enabled = false
require.NoError(t, os.WriteFile(configPath, []byte(updatedConfig), 0644))
// Wait for changes to be detected
time.Sleep(300 * time.Millisecond)
time.Sleep(testPollWindow)
// Verify new values
port, _ = cfg.Get("server.port")
@ -130,8 +153,8 @@ func TestWatchFileDeleted(t *testing.T) {
// Enable watching
opts := WatchOptions{
PollInterval: 100 * time.Millisecond,
Debounce: 50 * time.Millisecond,
PollInterval: testPollInterval,
Debounce: testDebounce,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
@ -145,7 +168,7 @@ func TestWatchFileDeleted(t *testing.T) {
select {
case path := <-changes:
assert.Equal(t, "file_deleted", path)
case <-time.After(500 * time.Millisecond):
case <-time.After(testEventuallyTimeout):
t.Error("Timeout waiting for deletion notification")
}
}
@ -169,8 +192,8 @@ func TestWatchPermissionChange(t *testing.T) {
// Enable watching with permission verification
opts := WatchOptions{
PollInterval: 100 * time.Millisecond,
Debounce: 50 * time.Millisecond,
PollInterval: testPollInterval,
Debounce: testDebounce,
VerifyPermissions: true,
}
cfg.AutoUpdateWithOptions(opts)
@ -185,7 +208,7 @@ func TestWatchPermissionChange(t *testing.T) {
select {
case path := <-changes:
assert.Equal(t, "permissions_changed", path)
case <-time.After(500 * time.Millisecond):
case <-time.After(testEventuallyTimeout):
t.Error("Timeout waiting for permission change notification")
}
}
@ -203,7 +226,7 @@ func TestMaxWatchers(t *testing.T) {
// Enable watching with low max watchers
opts := WatchOptions{
PollInterval: 100 * time.Millisecond,
PollInterval: testPollInterval,
MaxWatchers: 3,
}
cfg.AutoUpdateWithOptions(opts)
@ -229,7 +252,7 @@ func TestMaxWatchers(t *testing.T) {
select {
case _, ok := <-ch:
assert.False(t, ok, "Channel 3 should be closed (max watchers exceeded)")
case <-time.After(10 * time.Millisecond):
case <-time.After(testEventuallyTimeout):
t.Error("Channel 3 should be closed immediately")
}
}
@ -239,8 +262,8 @@ func TestMaxWatchers(t *testing.T) {
assert.Equal(t, 3, cfg.WatcherCount())
}
// TestDebounce tests that rapid changes are debounced
func TestDebounce(t *testing.T) {
// TestRapidDebounce tests that rapid changes are debounced
func TestRapidDebounce(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test.toml")
@ -253,8 +276,8 @@ func TestDebounce(t *testing.T) {
// Enable watching with longer debounce
opts := WatchOptions{
PollInterval: 50 * time.Millisecond,
Debounce: 200 * time.Millisecond,
PollInterval: testDebounce,
Debounce: testStateStabilize,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
@ -282,11 +305,11 @@ func TestDebounce(t *testing.T) {
for i := 2; i <= 5; i++ {
content := fmt.Sprintf(`value = %d`, i)
require.NoError(t, os.WriteFile(configPath, []byte(content), 0644))
time.Sleep(50 * time.Millisecond) // Less than debounce period
time.Sleep(testDebounce) // Less than debounce period
}
// Wait for debounce to complete
time.Sleep(300 * time.Millisecond)
time.Sleep(2 * testStateStabilize)
done <- true
// Should only see one change due to debounce
@ -328,7 +351,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
require.NoError(t, cfg.LoadFile(configPath))
opts := WatchOptions{
PollInterval: 50 * time.Millisecond,
PollInterval: testDebounce,
MaxWatchers: 50,
}
cfg.AutoUpdateWithOptions(opts)
@ -353,7 +376,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
select {
case <-ch:
// OK, got a change
case <-time.After(10 * time.Millisecond):
case <-time.After(2 * SpinWaitInterval):
// OK, no changes yet
}
}(i)
@ -384,7 +407,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
isWatching = true
break
}
time.Sleep(10 * time.Millisecond)
time.Sleep(2 * SpinWaitInterval)
}
if !isWatching {
errors <- fmt.Errorf("checker %d: IsWatching returned false", id)
@ -418,8 +441,8 @@ func TestReloadTimeout(t *testing.T) {
// Very short timeout
opts := WatchOptions{
PollInterval: 100 * time.Millisecond,
ReloadTimeout: 1 * time.Nanosecond, // Extremely short
PollInterval: testPollInterval,
ReloadTimeout: 1 * time.Nanosecond,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
@ -454,7 +477,7 @@ func TestStopAutoUpdate(t *testing.T) {
select {
case _, ok := <-ch:
assert.False(t, ok, "Channel should be closed after stop")
case <-time.After(100 * time.Millisecond):
case <-time.After(ShutdownTimeout):
// OK, channel might not close immediately
}
@ -484,7 +507,7 @@ func BenchmarkWatchOverhead(b *testing.B) {
// Enable watching
opts := WatchOptions{
PollInterval: 100 * time.Millisecond,
PollInterval: testPollInterval,
}
cfg.AutoUpdateWithOptions(opts)
defer cfg.StopAutoUpdate()
@ -494,11 +517,4 @@ func BenchmarkWatchOverhead(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _ = cfg.Get(fmt.Sprintf("value%d", i%100))
}
}
// helper function to wait for watcher state, preventing race conditions of goroutine start and test check
func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) {
require.Eventually(t, func() bool {
return cfg.IsWatching() == expected
}, 200*time.Millisecond, 10*time.Millisecond, msgAndArgs...)
}