e6.0.0 Added file format change and security option support.
This commit is contained in:
27
builder.go
27
builder.go
@ -15,6 +15,8 @@ type Builder struct {
|
|||||||
opts LoadOptions
|
opts LoadOptions
|
||||||
defaults any
|
defaults any
|
||||||
tagName string
|
tagName string
|
||||||
|
fileFormat string
|
||||||
|
securityOpts *SecurityOptions
|
||||||
prefix string
|
prefix string
|
||||||
file string
|
file string
|
||||||
args []string
|
args []string
|
||||||
@ -50,6 +52,14 @@ func (b *Builder) Build() (*Config, error) {
|
|||||||
tagName = "toml"
|
tagName = "toml"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set format and security settings
|
||||||
|
if b.fileFormat != "" {
|
||||||
|
b.cfg.fileFormat = b.fileFormat
|
||||||
|
}
|
||||||
|
if b.securityOpts != nil {
|
||||||
|
b.cfg.securityOpts = b.securityOpts
|
||||||
|
}
|
||||||
|
|
||||||
// 1. Register defaults
|
// 1. Register defaults
|
||||||
// If WithDefaults() was called, it takes precedence.
|
// If WithDefaults() was called, it takes precedence.
|
||||||
// If not, but WithTarget() was called, use the target struct for defaults.
|
// If not, but WithTarget() was called, use the target struct for defaults.
|
||||||
@ -148,6 +158,23 @@ func (b *Builder) WithTagName(tagName string) *Builder {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithFileFormat sets the expected file format
|
||||||
|
func (b *Builder) WithFileFormat(format string) *Builder {
|
||||||
|
switch format {
|
||||||
|
case "toml", "json", "yaml", "auto":
|
||||||
|
b.fileFormat = format
|
||||||
|
default:
|
||||||
|
b.err = fmt.Errorf("unsupported file format %q", format)
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSecurityOptions sets security options for file loading
|
||||||
|
func (b *Builder) WithSecurityOptions(opts SecurityOptions) *Builder {
|
||||||
|
b.securityOpts = &opts
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// WithPrefix sets the prefix for struct registration
|
// WithPrefix sets the prefix for struct registration
|
||||||
func (b *Builder) WithPrefix(prefix string) *Builder {
|
func (b *Builder) WithPrefix(prefix string) *Builder {
|
||||||
b.prefix = prefix
|
b.prefix = prefix
|
||||||
|
|||||||
@ -203,7 +203,8 @@ func TestBuilder(t *testing.T) {
|
|||||||
func TestFileDiscovery(t *testing.T) {
|
func TestFileDiscovery(t *testing.T) {
|
||||||
t.Run("DiscoveryWithCLIFlag", func(t *testing.T) {
|
t.Run("DiscoveryWithCLIFlag", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
configFile := filepath.Join(tmpDir, "custom.conf")
|
// Use .toml extension for TOML content
|
||||||
|
configFile := filepath.Join(tmpDir, "custom.toml")
|
||||||
os.WriteFile(configFile, []byte(`test = "value"`), 0644)
|
os.WriteFile(configFile, []byte(`test = "value"`), 0644)
|
||||||
|
|
||||||
opts := DefaultDiscoveryOptions("myapp")
|
opts := DefaultDiscoveryOptions("myapp")
|
||||||
@ -223,6 +224,7 @@ func TestFileDiscovery(t *testing.T) {
|
|||||||
assert.Equal(t, "value", val)
|
assert.Equal(t, "value", val)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Rest of test cases remain the same...
|
||||||
t.Run("DiscoveryWithEnvVar", func(t *testing.T) {
|
t.Run("DiscoveryWithEnvVar", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
configFile := filepath.Join(tmpDir, "env.toml")
|
configFile := filepath.Join(tmpDir, "env.toml")
|
||||||
|
|||||||
61
config.go
61
config.go
@ -47,19 +47,28 @@ type structCache struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SecurityOptions for enhanced file loading security
|
||||||
|
type SecurityOptions struct {
|
||||||
|
PreventPathTraversal bool // Prevent ../ in paths
|
||||||
|
EnforceFileOwnership bool // Unix only: ensure file owned by current user
|
||||||
|
MaxFileSize int64 // Maximum config file size (0 = no limit)
|
||||||
|
}
|
||||||
|
|
||||||
// Config manages application configuration. It can be used in two primary ways:
|
// Config manages application configuration. It can be used in two primary ways:
|
||||||
// 1. As a dynamic key-value store, accessed via methods like Get(), String(), and Int64()
|
// 1. As a dynamic key-value store, accessed via methods like Get(), String(), and Int64()
|
||||||
// 2. As a source for a type-safe struct, populated via BuildAndScan() or AsStruct()
|
// 2. As a source for a type-safe struct, populated via BuildAndScan() or AsStruct()
|
||||||
type Config struct {
|
type Config struct {
|
||||||
items map[string]configItem
|
items map[string]configItem
|
||||||
tagName string
|
tagName string
|
||||||
mutex sync.RWMutex
|
fileFormat string // Separate from tagName: "toml", "json", "yaml", or "auto"
|
||||||
options LoadOptions // Current load options
|
securityOpts *SecurityOptions
|
||||||
fileData map[string]any // Cached file data
|
mutex sync.RWMutex
|
||||||
envData map[string]any // Cached env data
|
options LoadOptions // Current load options
|
||||||
cliData map[string]any // Cached CLI data
|
fileData map[string]any // Cached file data
|
||||||
version atomic.Int64
|
envData map[string]any // Cached env data
|
||||||
structCache *structCache
|
cliData map[string]any // Cached CLI data
|
||||||
|
version atomic.Int64
|
||||||
|
structCache *structCache
|
||||||
|
|
||||||
// File watching support
|
// File watching support
|
||||||
watcher *watcher
|
watcher *watcher
|
||||||
@ -69,8 +78,14 @@ type Config struct {
|
|||||||
// New creates and initializes a new Config instance.
|
// New creates and initializes a new Config instance.
|
||||||
func New() *Config {
|
func New() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
items: make(map[string]configItem),
|
items: make(map[string]configItem),
|
||||||
tagName: "toml",
|
tagName: "toml",
|
||||||
|
fileFormat: "auto",
|
||||||
|
// securityOpts: &SecurityOptions{
|
||||||
|
// PreventPathTraversal: false,
|
||||||
|
// EnforceFileOwnership: false,
|
||||||
|
// MaxFileSize: 0,
|
||||||
|
// },
|
||||||
options: DefaultLoadOptions(),
|
options: DefaultLoadOptions(),
|
||||||
fileData: make(map[string]any),
|
fileData: make(map[string]any),
|
||||||
envData: make(map[string]any),
|
envData: make(map[string]any),
|
||||||
@ -114,6 +129,30 @@ func (c *Config) computeValue(item configItem) any {
|
|||||||
return item.defaultValue
|
return item.defaultValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFileFormat sets the expected format for configuration files.
|
||||||
|
// Use "auto" to detect based on file extension.
|
||||||
|
func (c *Config) SetFileFormat(format string) error {
|
||||||
|
switch format {
|
||||||
|
case "toml", "json", "yaml", "auto":
|
||||||
|
// Valid formats
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported file format %q, must be one of: toml, json, yaml, auto", format)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
c.fileFormat = format
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSecurityOptions configures security checks for file loading
|
||||||
|
func (c *Config) SetSecurityOptions(opts SecurityOptions) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
c.securityOpts = &opts
|
||||||
|
}
|
||||||
|
|
||||||
// Get retrieves a configuration value using the path and indicator if the path was registered
|
// Get retrieves a configuration value using the path and indicator if the path was registered
|
||||||
func (c *Config) Get(path string) (any, bool) {
|
func (c *Config) Get(path string) (any, bool) {
|
||||||
c.mutex.RLock()
|
c.mutex.RLock()
|
||||||
|
|||||||
39
decode.go
39
decode.go
@ -2,6 +2,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -119,6 +120,9 @@ func normalizeMap(data any) (map[string]any, error) {
|
|||||||
// getDecodeHook returns the composite decode hook for all type conversions
|
// getDecodeHook returns the composite decode hook for all type conversions
|
||||||
func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
|
func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
|
||||||
return mapstructure.ComposeDecodeHookFunc(
|
return mapstructure.ComposeDecodeHookFunc(
|
||||||
|
// JSON Number handling
|
||||||
|
jsonNumberHookFunc(),
|
||||||
|
|
||||||
// Network types
|
// Network types
|
||||||
stringToNetIPHookFunc(),
|
stringToNetIPHookFunc(),
|
||||||
stringToNetIPNetHookFunc(),
|
stringToNetIPNetHookFunc(),
|
||||||
@ -134,6 +138,41 @@ func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// jsonNumberHookFunc handles json.Number conversion to appropriate numeric types
|
||||||
|
func jsonNumberHookFunc() mapstructure.DecodeHookFunc {
|
||||||
|
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
|
||||||
|
// Check if source is json.Number
|
||||||
|
if f != reflect.TypeOf(json.Number("")) {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
num := data.(json.Number)
|
||||||
|
|
||||||
|
// Convert based on target type
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return num.Int64()
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
// Parse as int64 first, then convert
|
||||||
|
i, err := num.Int64()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if i < 0 {
|
||||||
|
return nil, fmt.Errorf("cannot convert negative number to unsigned type")
|
||||||
|
}
|
||||||
|
return uint64(i), nil
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return num.Float64()
|
||||||
|
case reflect.String:
|
||||||
|
return num.String(), nil
|
||||||
|
default:
|
||||||
|
// Return as-is for other types
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// stringToNetIPHookFunc handles net.IP conversion
|
// stringToNetIPHookFunc handles net.IP conversion
|
||||||
func stringToNetIPHookFunc() mapstructure.DecodeHookFunc {
|
func stringToNetIPHookFunc() mapstructure.DecodeHookFunc {
|
||||||
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
|
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
|
||||||
|
|||||||
418
dynamic_test.go
Normal file
418
dynamic_test.go
Normal file
@ -0,0 +1,418 @@
|
|||||||
|
// FILE: lixenwraith/config/dynamic_test.go
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestMultiFormatLoading tests loading different config formats
|
||||||
|
func TestMultiFormatLoading(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create test config in different formats
|
||||||
|
tomlConfig := `
|
||||||
|
[server]
|
||||||
|
host = "toml-host"
|
||||||
|
port = 8080
|
||||||
|
|
||||||
|
[database]
|
||||||
|
url = "postgres://localhost/toml"
|
||||||
|
`
|
||||||
|
|
||||||
|
jsonConfig := `{
|
||||||
|
"server": {
|
||||||
|
"host": "json-host",
|
||||||
|
"port": 9090
|
||||||
|
},
|
||||||
|
"database": {
|
||||||
|
"url": "postgres://localhost/json"
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
yamlConfig := `
|
||||||
|
server:
|
||||||
|
host: yaml-host
|
||||||
|
port: 7070
|
||||||
|
database:
|
||||||
|
url: postgres://localhost/yaml
|
||||||
|
`
|
||||||
|
|
||||||
|
// Write config files
|
||||||
|
tomlPath := filepath.Join(tmpDir, "config.toml")
|
||||||
|
jsonPath := filepath.Join(tmpDir, "config.json")
|
||||||
|
yamlPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
|
||||||
|
require.NoError(t, os.WriteFile(tomlPath, []byte(tomlConfig), 0644))
|
||||||
|
require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644))
|
||||||
|
require.NoError(t, os.WriteFile(yamlPath, []byte(yamlConfig), 0644))
|
||||||
|
|
||||||
|
t.Run("AutoDetectFormats", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("server.host", "")
|
||||||
|
cfg.Register("server.port", 0)
|
||||||
|
cfg.Register("database.url", "")
|
||||||
|
|
||||||
|
// Test TOML
|
||||||
|
cfg.SetFileFormat("auto")
|
||||||
|
require.NoError(t, cfg.LoadFile(tomlPath))
|
||||||
|
host, _ := cfg.Get("server.host")
|
||||||
|
assert.Equal(t, "toml-host", host)
|
||||||
|
|
||||||
|
// Test JSON
|
||||||
|
require.NoError(t, cfg.LoadFile(jsonPath))
|
||||||
|
host, _ = cfg.Get("server.host")
|
||||||
|
assert.Equal(t, "json-host", host)
|
||||||
|
port, _ := cfg.Get("server.port")
|
||||||
|
// JSON number should be preserved as json.Number but convertible
|
||||||
|
switch v := port.(type) {
|
||||||
|
case json.Number:
|
||||||
|
// Expected for raw value
|
||||||
|
assert.Equal(t, json.Number("9090"), v)
|
||||||
|
case int64:
|
||||||
|
// Expected after decode hook conversion
|
||||||
|
assert.Equal(t, int64(9090), v)
|
||||||
|
case float64:
|
||||||
|
// Alternative conversion
|
||||||
|
assert.Equal(t, float64(9090), v)
|
||||||
|
default:
|
||||||
|
t.Errorf("Unexpected type for port: %T", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test YAML
|
||||||
|
require.NoError(t, cfg.LoadFile(yamlPath))
|
||||||
|
host, _ = cfg.Get("server.host")
|
||||||
|
assert.Equal(t, "yaml-host", host)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ExplicitFormat", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("server.host", "")
|
||||||
|
|
||||||
|
// Force JSON parsing on .conf file
|
||||||
|
confPath := filepath.Join(tmpDir, "config.conf")
|
||||||
|
require.NoError(t, os.WriteFile(confPath, []byte(jsonConfig), 0644))
|
||||||
|
|
||||||
|
cfg.SetFileFormat("json")
|
||||||
|
require.NoError(t, cfg.LoadFile(confPath))
|
||||||
|
|
||||||
|
host, _ := cfg.Get("server.host")
|
||||||
|
assert.Equal(t, "json-host", host)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ContentDetection", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("server.host", "")
|
||||||
|
|
||||||
|
// Ambiguous extension
|
||||||
|
ambigPath := filepath.Join(tmpDir, "config.conf")
|
||||||
|
require.NoError(t, os.WriteFile(ambigPath, []byte(yamlConfig), 0644))
|
||||||
|
|
||||||
|
cfg.SetFileFormat("auto")
|
||||||
|
require.NoError(t, cfg.LoadFile(ambigPath))
|
||||||
|
|
||||||
|
host, _ := cfg.Get("server.host")
|
||||||
|
assert.Equal(t, "yaml-host", host)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDynamicFormatSwitching tests runtime format changes
|
||||||
|
func TestDynamicFormatSwitching(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create configs in different formats with same structure
|
||||||
|
configs := map[string]string{
|
||||||
|
"toml": `value = "from-toml"`,
|
||||||
|
"json": `{"value": "from-json"}`,
|
||||||
|
"yaml": `value: from-yaml`,
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("value", "default")
|
||||||
|
|
||||||
|
for format, content := range configs {
|
||||||
|
t.Run(format, func(t *testing.T) {
|
||||||
|
filePath := filepath.Join(tmpDir, "config."+format)
|
||||||
|
require.NoError(t, os.WriteFile(filePath, []byte(content), 0644))
|
||||||
|
|
||||||
|
// Set format and load
|
||||||
|
require.NoError(t, cfg.SetFileFormat(format))
|
||||||
|
require.NoError(t, cfg.LoadFile(filePath))
|
||||||
|
|
||||||
|
val, _ := cfg.Get("value")
|
||||||
|
assert.Equal(t, "from-"+format, val)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWatchFileFormatSwitch tests watching different file formats
|
||||||
|
func TestWatchFileFormatSwitch(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
tomlPath := filepath.Join(tmpDir, "config.toml")
|
||||||
|
jsonPath := filepath.Join(tmpDir, "config.json")
|
||||||
|
|
||||||
|
require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-1"`), 0644))
|
||||||
|
require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-1"}`), 0644))
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("value", "default")
|
||||||
|
|
||||||
|
// Configure fast polling for test
|
||||||
|
opts := WatchOptions{
|
||||||
|
PollInterval: testPollInterval, // Fast polling for tests
|
||||||
|
Debounce: testDebounce, // Short debounce
|
||||||
|
MaxWatchers: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start watching TOML
|
||||||
|
cfg.SetFileFormat("auto")
|
||||||
|
require.NoError(t, cfg.LoadFile(tomlPath))
|
||||||
|
cfg.AutoUpdateWithOptions(opts)
|
||||||
|
defer cfg.StopAutoUpdate()
|
||||||
|
|
||||||
|
// Wait for watcher to start
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return cfg.IsWatching()
|
||||||
|
}, 4*testDebounce, 2*SpinWaitInterval)
|
||||||
|
|
||||||
|
val, _ := cfg.Get("value")
|
||||||
|
assert.Equal(t, "toml-1", val)
|
||||||
|
|
||||||
|
// Switch to JSON with format hint
|
||||||
|
require.NoError(t, cfg.WatchFile(jsonPath, "json"))
|
||||||
|
|
||||||
|
// Wait for new watcher to start
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return cfg.IsWatching()
|
||||||
|
}, 4*testDebounce, 2*SpinWaitInterval)
|
||||||
|
|
||||||
|
// Get watch channel AFTER switching files
|
||||||
|
changes := cfg.Watch()
|
||||||
|
|
||||||
|
val, _ = cfg.Get("value")
|
||||||
|
assert.Equal(t, "json-1", val)
|
||||||
|
|
||||||
|
// Update JSON file
|
||||||
|
require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-2"}`), 0644))
|
||||||
|
|
||||||
|
// Wait for change notification
|
||||||
|
select {
|
||||||
|
case path := <-changes:
|
||||||
|
assert.Equal(t, "value", path)
|
||||||
|
// Wait a bit for value to be updated
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
val, _ := cfg.Get("value")
|
||||||
|
return val == "json-2"
|
||||||
|
}, testEventuallyTimeout, 2*SpinWaitInterval)
|
||||||
|
case <-time.After(testWatchTimeout):
|
||||||
|
t.Error("Timeout waiting for JSON file change")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update old TOML file - should NOT trigger notification
|
||||||
|
require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-2"`), 0644))
|
||||||
|
|
||||||
|
// Should not receive notification from old file
|
||||||
|
select {
|
||||||
|
case <-changes:
|
||||||
|
t.Error("Should not receive changes from old TOML file")
|
||||||
|
case <-time.After(testPollWindow):
|
||||||
|
// Expected - no change notification
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSecurityOptions tests security features
|
||||||
|
func TestSecurityOptions(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
t.Run("PathTraversal", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.SetSecurityOptions(SecurityOptions{
|
||||||
|
PreventPathTraversal: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test various malicious paths
|
||||||
|
maliciousPaths := []string{
|
||||||
|
"../../../etc/passwd",
|
||||||
|
"./../etc/passwd",
|
||||||
|
"config/../../../etc/passwd",
|
||||||
|
filepath.Join("..", "..", "etc", "passwd"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, malPath := range maliciousPaths {
|
||||||
|
err := cfg.LoadFile(malPath)
|
||||||
|
assert.Error(t, err, "Should reject path: %s", malPath)
|
||||||
|
assert.Contains(t, err.Error(), "path traversal")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid paths should work
|
||||||
|
validPath := filepath.Join(tmpDir, "config.toml")
|
||||||
|
os.WriteFile(validPath, []byte(`test = "value"`), 0644)
|
||||||
|
cfg.Register("test", "")
|
||||||
|
|
||||||
|
err := cfg.LoadFile(validPath)
|
||||||
|
assert.NoError(t, err, "Should accept valid absolute path")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("FileSizeLimit", func(t *testing.T) {
|
||||||
|
cfg := New()
|
||||||
|
cfg.SetSecurityOptions(SecurityOptions{
|
||||||
|
MaxFileSize: 100, // 100 bytes limit
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create large file
|
||||||
|
largePath := filepath.Join(tmpDir, "large.toml")
|
||||||
|
largeContent := make([]byte, 1024)
|
||||||
|
for i := range largeContent {
|
||||||
|
largeContent[i] = 'a'
|
||||||
|
}
|
||||||
|
require.NoError(t, os.WriteFile(largePath, largeContent, 0644))
|
||||||
|
|
||||||
|
err := cfg.LoadFile(largePath)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "exceeds maximum size")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("FileOwnership", func(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("Skipping ownership test on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.SetSecurityOptions(SecurityOptions{
|
||||||
|
EnforceFileOwnership: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create file owned by current user (should succeed)
|
||||||
|
ownedPath := filepath.Join(tmpDir, "owned.toml")
|
||||||
|
require.NoError(t, os.WriteFile(ownedPath, []byte(`test = "value"`), 0644))
|
||||||
|
|
||||||
|
cfg.Register("test", "")
|
||||||
|
err := cfg.LoadFile(ownedPath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForWatchingState waits for watcher state, preventing race conditions of goroutine start and test check
|
||||||
|
func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) {
|
||||||
|
t.Helper()
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return cfg.IsWatching() == expected
|
||||||
|
}, testEventuallyTimeout, 2*SpinWaitInterval, msgAndArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuilderWithFormat tests Builder integration
|
||||||
|
func TestBuilderWithFormat(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
jsonPath := filepath.Join(tmpDir, "config.json")
|
||||||
|
|
||||||
|
jsonConfig := `{
|
||||||
|
"server": {
|
||||||
|
"host": "builder-host",
|
||||||
|
"port": 8080
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644))
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Server struct {
|
||||||
|
Host string `json:"host" toml:"host"`
|
||||||
|
Port int `json:"port" toml:"port"`
|
||||||
|
} `json:"server" toml:"server"`
|
||||||
|
}
|
||||||
|
|
||||||
|
defaults := &Config{}
|
||||||
|
defaults.Server.Host = "default-host"
|
||||||
|
defaults.Server.Port = 3000
|
||||||
|
|
||||||
|
cfg, err := NewBuilder().
|
||||||
|
WithDefaults(defaults).
|
||||||
|
WithFile(jsonPath).
|
||||||
|
WithFileFormat("json").
|
||||||
|
WithTagName("toml"). // Use toml tags for registration
|
||||||
|
WithSecurityOptions(SecurityOptions{
|
||||||
|
PreventPathTraversal: true,
|
||||||
|
MaxFileSize: 1024 * 1024, // 1MB
|
||||||
|
}).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check the value was loaded
|
||||||
|
host, exists := cfg.Get("server.host")
|
||||||
|
assert.True(t, exists, "server.host should exist")
|
||||||
|
assert.Equal(t, "builder-host", host)
|
||||||
|
|
||||||
|
port, exists := cfg.Get("server.port")
|
||||||
|
assert.True(t, exists, "server.port should exist")
|
||||||
|
// Handle json.Number or converted int
|
||||||
|
switch v := port.(type) {
|
||||||
|
case json.Number:
|
||||||
|
p, _ := v.Int64()
|
||||||
|
assert.Equal(t, int64(8080), p)
|
||||||
|
case int64:
|
||||||
|
assert.Equal(t, int64(8080), v)
|
||||||
|
case float64:
|
||||||
|
assert.Equal(t, float64(8080), v)
|
||||||
|
default:
|
||||||
|
t.Errorf("Unexpected type for port: %T", port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkFormatParsing benchmarks different format parsing speeds
|
||||||
|
func BenchmarkFormatParsing(b *testing.B) {
|
||||||
|
tmpDir := b.TempDir()
|
||||||
|
|
||||||
|
// Create test data
|
||||||
|
configs := map[string]string{
|
||||||
|
"toml": `
|
||||||
|
[server]
|
||||||
|
host = "localhost"
|
||||||
|
port = 8080
|
||||||
|
[database]
|
||||||
|
url = "postgres://localhost/db"
|
||||||
|
[cache]
|
||||||
|
ttl = 300
|
||||||
|
`,
|
||||||
|
"json": `{
|
||||||
|
"server": {"host": "localhost", "port": 8080},
|
||||||
|
"database": {"url": "postgres://localhost/db"},
|
||||||
|
"cache": {"ttl": 300}
|
||||||
|
}`,
|
||||||
|
"yaml": `
|
||||||
|
server:
|
||||||
|
host: localhost
|
||||||
|
port: 8080
|
||||||
|
database:
|
||||||
|
url: postgres://localhost/db
|
||||||
|
cache:
|
||||||
|
ttl: 300
|
||||||
|
`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for format, content := range configs {
|
||||||
|
b.Run(format, func(b *testing.B) {
|
||||||
|
path := filepath.Join(tmpDir, "bench."+format)
|
||||||
|
os.WriteFile(path, []byte(content), 0644)
|
||||||
|
|
||||||
|
cfg := New()
|
||||||
|
cfg.Register("server.host", "")
|
||||||
|
cfg.Register("server.port", 0)
|
||||||
|
cfg.Register("database.url", "")
|
||||||
|
cfg.Register("cache.ttl", 0)
|
||||||
|
cfg.SetFileFormat(format)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
cfg.LoadFile(path)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
141
loader.go
141
loader.go
@ -3,13 +3,18 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Source represents a configuration source, used to define load precedence
|
// Source represents a configuration source, used to define load precedence
|
||||||
@ -143,18 +148,103 @@ func (c *Config) LoadFile(filePath string) error {
|
|||||||
|
|
||||||
// loadFile reads and parses a TOML configuration file
|
// loadFile reads and parses a TOML configuration file
|
||||||
func (c *Config) loadFile(path string) error {
|
func (c *Config) loadFile(path string) error {
|
||||||
// 1. Read and Parse (No Lock)
|
// Security: Path traversal check
|
||||||
fileData, err := os.ReadFile(path)
|
if c.securityOpts != nil && c.securityOpts.PreventPathTraversal {
|
||||||
|
// Clean the path and check for traversal attempts
|
||||||
|
cleanPath := filepath.Clean(path)
|
||||||
|
|
||||||
|
// Check if cleaned path tries to go outside current directory
|
||||||
|
if strings.HasPrefix(cleanPath, ".."+string(filepath.Separator)) || cleanPath == ".." {
|
||||||
|
return fmt.Errorf("potential path traversal detected in config path: %s", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also check for absolute paths that might escape jail
|
||||||
|
if filepath.IsAbs(cleanPath) && filepath.IsAbs(path) {
|
||||||
|
// Absolute paths are OK if that's what was provided
|
||||||
|
} else if filepath.IsAbs(cleanPath) && !filepath.IsAbs(path) {
|
||||||
|
// Relative path became absolute after cleaning - suspicious
|
||||||
|
return fmt.Errorf("potential path traversal detected in config path: %s", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read file with size limit
|
||||||
|
fileInfo, err := os.Stat(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
return ErrConfigNotFound
|
return ErrConfigNotFound
|
||||||
}
|
}
|
||||||
|
return fmt.Errorf("failed to stat config file '%s': %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Security: File size check
|
||||||
|
if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 {
|
||||||
|
if fileInfo.Size() > c.securityOpts.MaxFileSize {
|
||||||
|
return fmt.Errorf("config file '%s' exceeds maximum size %d bytes", path, c.securityOpts.MaxFileSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Security: File ownership check (Unix only)
|
||||||
|
if c.securityOpts != nil && c.securityOpts.EnforceFileOwnership && runtime.GOOS != "windows" {
|
||||||
|
if stat, ok := fileInfo.Sys().(*syscall.Stat_t); ok {
|
||||||
|
if stat.Uid != uint32(os.Geteuid()) {
|
||||||
|
return fmt.Errorf("config file '%s' is not owned by current user (file UID: %d, process UID: %d)",
|
||||||
|
path, stat.Uid, os.Geteuid())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. Read and parse file data
|
||||||
|
file, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open config file '%s': %w", path, err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// Use LimitedReader for additional safety
|
||||||
|
var reader io.Reader = file
|
||||||
|
if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 {
|
||||||
|
reader = io.LimitReader(file, c.securityOpts.MaxFileSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileData, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read config file '%s': %w", path, err)
|
return fmt.Errorf("failed to read config file '%s': %w", path, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine format
|
||||||
|
format := c.fileFormat
|
||||||
|
if format == "" || format == "auto" {
|
||||||
|
// Try extension first
|
||||||
|
format = detectFileFormat(path)
|
||||||
|
if format == "" {
|
||||||
|
// Fall back to content detection
|
||||||
|
format = detectFormatFromContent(fileData)
|
||||||
|
if format == "" {
|
||||||
|
// Last resort: use tagName as hint
|
||||||
|
format = c.tagName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse based on detected/specified format
|
||||||
fileConfig := make(map[string]any)
|
fileConfig := make(map[string]any)
|
||||||
if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
|
switch format {
|
||||||
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err)
|
case "toml":
|
||||||
|
if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err)
|
||||||
|
}
|
||||||
|
case "json":
|
||||||
|
decoder := json.NewDecoder(bytes.NewReader(fileData))
|
||||||
|
decoder.UseNumber() // Preserve number precision
|
||||||
|
if err := decoder.Decode(&fileConfig); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse JSON config file '%s': %w", path, err)
|
||||||
|
}
|
||||||
|
case "yaml":
|
||||||
|
if err := yaml.Unmarshal(fileData, &fileConfig); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse YAML config file '%s': %w", path, err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unable to determine config format for file '%s'", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Prepare New State (Read-Lock Only)
|
// 2. Prepare New State (Read-Lock Only)
|
||||||
@ -185,7 +275,7 @@ func (c *Config) loadFile(path string) error {
|
|||||||
}
|
}
|
||||||
apply("", fileConfig)
|
apply("", fileConfig)
|
||||||
|
|
||||||
// -- 3. Atomically Update Config (Write-Lock)
|
// 3. Atomically Update Config (Write-Lock)
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
@ -578,4 +668,45 @@ func parseArgs(args []string) (map[string]any, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectFileFormat determines format from file extension
|
||||||
|
func detectFileFormat(path string) string {
|
||||||
|
ext := strings.ToLower(filepath.Ext(path))
|
||||||
|
switch ext {
|
||||||
|
case ".toml", ".tml":
|
||||||
|
return "toml"
|
||||||
|
case ".json":
|
||||||
|
return "json"
|
||||||
|
case ".yaml", ".yml":
|
||||||
|
return "yaml"
|
||||||
|
case ".conf", ".config":
|
||||||
|
// Try to detect from content
|
||||||
|
return ""
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectFormatFromContent attempts to detect format by parsing
|
||||||
|
func detectFormatFromContent(data []byte) string {
|
||||||
|
// Try JSON first (strict format)
|
||||||
|
var jsonTest any
|
||||||
|
if err := json.Unmarshal(data, &jsonTest); err == nil {
|
||||||
|
return "json"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try YAML (superset of JSON, so check after JSON)
|
||||||
|
var yamlTest any
|
||||||
|
if err := yaml.Unmarshal(data, &yamlTest); err == nil {
|
||||||
|
return "yaml"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try TOML last
|
||||||
|
var tomlTest any
|
||||||
|
if err := toml.Unmarshal(data, &tomlTest); err == nil {
|
||||||
|
return "toml"
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
26
timing.go
Normal file
26
timing.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
// FILE: lixenwraith/config/timing.go
|
||||||
|
package config
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// Core timing constants for production use.
|
||||||
|
// These define the fundamental timing behavior of the config package.
|
||||||
|
const (
|
||||||
|
// File watching intervals (ordered by frequency)
|
||||||
|
SpinWaitInterval = 5 * time.Millisecond // CPU-friendly busy-wait quantum
|
||||||
|
MinPollInterval = 100 * time.Millisecond // Hard floor for file stat polling
|
||||||
|
ShutdownTimeout = 100 * time.Millisecond // Graceful watcher termination window
|
||||||
|
DefaultDebounce = 500 * time.Millisecond // File change coalescence period
|
||||||
|
DefaultPollInterval = time.Second // Standard file monitoring frequency
|
||||||
|
DefaultReloadTimeout = 5 * time.Second // Maximum duration for reload operations
|
||||||
|
)
|
||||||
|
|
||||||
|
// Derived timing relationships for internal use.
|
||||||
|
// These maintain consistent ratios between related timers.
|
||||||
|
const (
|
||||||
|
// shutdownPollCycles defines how many spin-wait cycles comprise a shutdown timeout
|
||||||
|
shutdownPollCycles = ShutdownTimeout / SpinWaitInterval // = 20 cycles
|
||||||
|
|
||||||
|
// debounceSettleMultiplier ensures sufficient time for debounce to complete
|
||||||
|
debounceSettleMultiplier = 3 // Wait 3x debounce period for value stabilization
|
||||||
|
)
|
||||||
75
watch.go
75
watch.go
@ -11,6 +11,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const DefaultMaxWatchers = 100 // Prevent resource exhaustion
|
||||||
|
|
||||||
// WatchOptions configures file watching behavior
|
// WatchOptions configures file watching behavior
|
||||||
type WatchOptions struct {
|
type WatchOptions struct {
|
||||||
// PollInterval for file stat checks (minimum 100ms)
|
// PollInterval for file stat checks (minimum 100ms)
|
||||||
@ -32,10 +34,10 @@ type WatchOptions struct {
|
|||||||
// DefaultWatchOptions returns sensible defaults for file watching
|
// DefaultWatchOptions returns sensible defaults for file watching
|
||||||
func DefaultWatchOptions() WatchOptions {
|
func DefaultWatchOptions() WatchOptions {
|
||||||
return WatchOptions{
|
return WatchOptions{
|
||||||
PollInterval: time.Second, // Check every second
|
PollInterval: DefaultPollInterval,
|
||||||
Debounce: 500 * time.Millisecond,
|
Debounce: DefaultDebounce,
|
||||||
MaxWatchers: 100, // Prevent resource exhaustion
|
MaxWatchers: DefaultMaxWatchers,
|
||||||
ReloadTimeout: 5 * time.Second,
|
ReloadTimeout: DefaultReloadTimeout,
|
||||||
VerifyPermissions: true,
|
VerifyPermissions: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -71,26 +73,32 @@ func (c *Config) AutoUpdate() {
|
|||||||
// AutoUpdateWithOptions enables automatic configuration reloading with custom options
|
// AutoUpdateWithOptions enables automatic configuration reloading with custom options
|
||||||
func (c *Config) AutoUpdateWithOptions(opts WatchOptions) {
|
func (c *Config) AutoUpdateWithOptions(opts WatchOptions) {
|
||||||
// Validate options
|
// Validate options
|
||||||
if opts.PollInterval < 100*time.Millisecond {
|
if opts.PollInterval < MinPollInterval {
|
||||||
opts.PollInterval = 100 * time.Millisecond // Minimum poll interval
|
opts.PollInterval = MinPollInterval
|
||||||
}
|
}
|
||||||
if opts.MaxWatchers <= 0 {
|
if opts.MaxWatchers <= 0 {
|
||||||
opts.MaxWatchers = 100
|
opts.MaxWatchers = 100
|
||||||
}
|
}
|
||||||
if opts.ReloadTimeout <= 0 {
|
if opts.ReloadTimeout <= 0 {
|
||||||
opts.ReloadTimeout = 5 * time.Second
|
opts.ReloadTimeout = DefaultReloadTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
// Check if we have a file to watch
|
// Get path of current file to watch
|
||||||
filePath := c.getConfigFilePath()
|
filePath := c.getConfigFilePath()
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
// No file configured, nothing to watch
|
// No file configured, nothing to watch
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop existing watcher if path changed
|
||||||
|
if c.watcher != nil && c.watcher.filePath != filePath {
|
||||||
|
c.watcher.stop()
|
||||||
|
c.watcher = nil
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize watcher if needed
|
// Initialize watcher if needed
|
||||||
if c.watcher == nil {
|
if c.watcher == nil {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
@ -131,17 +139,24 @@ func (c *Config) Watch() <-chan string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WatchFile stops any existing file watcher, loads a new configuration file,
|
// WatchFile stops any existing file watcher, loads a new configuration file,
|
||||||
// and starts a new watcher on that file path.
|
// and starts a new watcher on that file path. Optionally accepts format hint.
|
||||||
func (c *Config) WatchFile(filePath string) error {
|
func (c *Config) WatchFile(filePath string, formatHint ...string) error {
|
||||||
// Stop any currently running watcher to prevent orphaned goroutines.
|
// Stop any currently running watcher
|
||||||
c.StopAutoUpdate()
|
c.StopAutoUpdate()
|
||||||
|
|
||||||
// Load the new file and set `configFilePath` to the new path
|
// Set format hint if provided
|
||||||
|
if len(formatHint) > 0 {
|
||||||
|
if err := c.SetFileFormat(formatHint[0]); err != nil {
|
||||||
|
return fmt.Errorf("invalid format hint: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the new file
|
||||||
if err := c.LoadFile(filePath); err != nil {
|
if err := c.LoadFile(filePath); err != nil {
|
||||||
return fmt.Errorf("failed to load new file for watching: %w", err)
|
return fmt.Errorf("failed to load new file for watching: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a new watcher on the new file
|
// Get previous watcher options if available
|
||||||
c.mutex.RLock()
|
c.mutex.RLock()
|
||||||
opts := DefaultWatchOptions()
|
opts := DefaultWatchOptions()
|
||||||
if c.watcher != nil {
|
if c.watcher != nil {
|
||||||
@ -149,18 +164,36 @@ func (c *Config) WatchFile(filePath string) error {
|
|||||||
}
|
}
|
||||||
c.mutex.RUnlock()
|
c.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Start new watcher (AutoUpdateWithOptions will create a new watcher with the new file path)
|
||||||
c.AutoUpdateWithOptions(opts)
|
c.AutoUpdateWithOptions(opts)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WatchWithOptions returns a channel with custom watch options
|
// WatchWithOptions returns a channel with custom watch options
|
||||||
|
// should not restart the watcher if it's already running with the same file
|
||||||
func (c *Config) WatchWithOptions(opts WatchOptions) <-chan string {
|
func (c *Config) WatchWithOptions(opts WatchOptions) <-chan string {
|
||||||
|
c.mutex.RLock()
|
||||||
|
watcher := c.watcher
|
||||||
|
filePath := c.configFilePath
|
||||||
|
c.mutex.RUnlock()
|
||||||
|
|
||||||
|
// If no file configured, return closed channel
|
||||||
|
if filePath == "" {
|
||||||
|
ch := make(chan string)
|
||||||
|
close(ch)
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// If watcher exists and is watching the current file, just subscribe
|
||||||
|
if watcher != nil && watcher.filePath == filePath && watcher.watching.Load() {
|
||||||
|
return watcher.subscribe()
|
||||||
|
}
|
||||||
|
|
||||||
// First ensure auto-update is running
|
// First ensure auto-update is running
|
||||||
c.AutoUpdateWithOptions(opts)
|
c.AutoUpdateWithOptions(opts)
|
||||||
|
|
||||||
c.mutex.RLock()
|
c.mutex.RLock()
|
||||||
watcher := c.watcher
|
watcher = c.watcher
|
||||||
c.mutex.RUnlock()
|
c.mutex.RUnlock()
|
||||||
|
|
||||||
if watcher == nil {
|
if watcher == nil {
|
||||||
@ -363,18 +396,22 @@ func (w *watcher) notifyWatchers(path string) {
|
|||||||
|
|
||||||
// stop terminates the watcher
|
// stop terminates the watcher
|
||||||
func (w *watcher) stop() {
|
func (w *watcher) stop() {
|
||||||
w.cancel()
|
if w.cancel != nil {
|
||||||
|
w.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
// Stop debounce timer
|
// Stop debounce timer
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
if w.debounceTimer != nil {
|
if w.debounceTimer != nil {
|
||||||
w.debounceTimer.Stop()
|
w.debounceTimer.Stop()
|
||||||
|
w.debounceTimer = nil
|
||||||
}
|
}
|
||||||
w.mu.Unlock()
|
w.mu.Unlock()
|
||||||
|
|
||||||
// Wait for watch loop to exit
|
// Wait for watch loop to exit with timeout
|
||||||
for w.watching.Load() {
|
deadline := time.Now().Add(ShutdownTimeout)
|
||||||
time.Sleep(10 * time.Millisecond)
|
for w.watching.Load() && time.Now().Before(deadline) {
|
||||||
|
time.Sleep(SpinWaitInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -14,6 +14,29 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Test-specific timing constants derived from production values.
|
||||||
|
// These accelerate test execution while maintaining timing relationships.
|
||||||
|
const (
|
||||||
|
// testAcceleration reduces all intervals by this factor for faster tests
|
||||||
|
testAcceleration = 10
|
||||||
|
|
||||||
|
// Accelerated test timings
|
||||||
|
testPollInterval = DefaultPollInterval / testAcceleration // 100ms (from 1s)
|
||||||
|
testDebounce = DefaultDebounce / testAcceleration // 50ms (from 500ms)
|
||||||
|
testReloadTimeout = DefaultReloadTimeout / testAcceleration // 500ms (from 5s)
|
||||||
|
testShutdownTimeout = ShutdownTimeout // Keep original for safety
|
||||||
|
testSpinWaitInterval = SpinWaitInterval // Keep original for CPU efficiency
|
||||||
|
|
||||||
|
// Test assertion timeouts
|
||||||
|
testEventuallyTimeout = testReloadTimeout // Aligns with reload timing
|
||||||
|
testWatchTimeout = 2 * DefaultPollInterval // 2s for change propagation
|
||||||
|
|
||||||
|
// Derived test multipliers with clear purpose
|
||||||
|
testDebounceSettle = debounceSettleMultiplier * testDebounce // 150ms for debounce verification
|
||||||
|
testPollWindow = 3 * testPollInterval // 300ms change detection window
|
||||||
|
testStateStabilize = 4 * testDebounce // 200ms for state convergence
|
||||||
|
)
|
||||||
|
|
||||||
// TestAutoUpdate tests automatic configuration reloading
|
// TestAutoUpdate tests automatic configuration reloading
|
||||||
func TestAutoUpdate(t *testing.T) {
|
func TestAutoUpdate(t *testing.T) {
|
||||||
// Setup
|
// Setup
|
||||||
@ -59,8 +82,8 @@ enabled = true
|
|||||||
|
|
||||||
// Enable auto-update with fast polling
|
// Enable auto-update with fast polling
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
PollInterval: 100 * time.Millisecond,
|
PollInterval: testPollInterval,
|
||||||
Debounce: 50 * time.Millisecond,
|
Debounce: testDebounce,
|
||||||
MaxWatchers: 10,
|
MaxWatchers: 10,
|
||||||
}
|
}
|
||||||
cfg.AutoUpdateWithOptions(opts)
|
cfg.AutoUpdateWithOptions(opts)
|
||||||
@ -93,7 +116,7 @@ enabled = false
|
|||||||
require.NoError(t, os.WriteFile(configPath, []byte(updatedConfig), 0644))
|
require.NoError(t, os.WriteFile(configPath, []byte(updatedConfig), 0644))
|
||||||
|
|
||||||
// Wait for changes to be detected
|
// Wait for changes to be detected
|
||||||
time.Sleep(300 * time.Millisecond)
|
time.Sleep(testPollWindow)
|
||||||
|
|
||||||
// Verify new values
|
// Verify new values
|
||||||
port, _ = cfg.Get("server.port")
|
port, _ = cfg.Get("server.port")
|
||||||
@ -130,8 +153,8 @@ func TestWatchFileDeleted(t *testing.T) {
|
|||||||
|
|
||||||
// Enable watching
|
// Enable watching
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
PollInterval: 100 * time.Millisecond,
|
PollInterval: testPollInterval,
|
||||||
Debounce: 50 * time.Millisecond,
|
Debounce: testDebounce,
|
||||||
}
|
}
|
||||||
cfg.AutoUpdateWithOptions(opts)
|
cfg.AutoUpdateWithOptions(opts)
|
||||||
defer cfg.StopAutoUpdate()
|
defer cfg.StopAutoUpdate()
|
||||||
@ -145,7 +168,7 @@ func TestWatchFileDeleted(t *testing.T) {
|
|||||||
select {
|
select {
|
||||||
case path := <-changes:
|
case path := <-changes:
|
||||||
assert.Equal(t, "file_deleted", path)
|
assert.Equal(t, "file_deleted", path)
|
||||||
case <-time.After(500 * time.Millisecond):
|
case <-time.After(testEventuallyTimeout):
|
||||||
t.Error("Timeout waiting for deletion notification")
|
t.Error("Timeout waiting for deletion notification")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -169,8 +192,8 @@ func TestWatchPermissionChange(t *testing.T) {
|
|||||||
|
|
||||||
// Enable watching with permission verification
|
// Enable watching with permission verification
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
PollInterval: 100 * time.Millisecond,
|
PollInterval: testPollInterval,
|
||||||
Debounce: 50 * time.Millisecond,
|
Debounce: testDebounce,
|
||||||
VerifyPermissions: true,
|
VerifyPermissions: true,
|
||||||
}
|
}
|
||||||
cfg.AutoUpdateWithOptions(opts)
|
cfg.AutoUpdateWithOptions(opts)
|
||||||
@ -185,7 +208,7 @@ func TestWatchPermissionChange(t *testing.T) {
|
|||||||
select {
|
select {
|
||||||
case path := <-changes:
|
case path := <-changes:
|
||||||
assert.Equal(t, "permissions_changed", path)
|
assert.Equal(t, "permissions_changed", path)
|
||||||
case <-time.After(500 * time.Millisecond):
|
case <-time.After(testEventuallyTimeout):
|
||||||
t.Error("Timeout waiting for permission change notification")
|
t.Error("Timeout waiting for permission change notification")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -203,7 +226,7 @@ func TestMaxWatchers(t *testing.T) {
|
|||||||
|
|
||||||
// Enable watching with low max watchers
|
// Enable watching with low max watchers
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
PollInterval: 100 * time.Millisecond,
|
PollInterval: testPollInterval,
|
||||||
MaxWatchers: 3,
|
MaxWatchers: 3,
|
||||||
}
|
}
|
||||||
cfg.AutoUpdateWithOptions(opts)
|
cfg.AutoUpdateWithOptions(opts)
|
||||||
@ -229,7 +252,7 @@ func TestMaxWatchers(t *testing.T) {
|
|||||||
select {
|
select {
|
||||||
case _, ok := <-ch:
|
case _, ok := <-ch:
|
||||||
assert.False(t, ok, "Channel 3 should be closed (max watchers exceeded)")
|
assert.False(t, ok, "Channel 3 should be closed (max watchers exceeded)")
|
||||||
case <-time.After(10 * time.Millisecond):
|
case <-time.After(testEventuallyTimeout):
|
||||||
t.Error("Channel 3 should be closed immediately")
|
t.Error("Channel 3 should be closed immediately")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -239,8 +262,8 @@ func TestMaxWatchers(t *testing.T) {
|
|||||||
assert.Equal(t, 3, cfg.WatcherCount())
|
assert.Equal(t, 3, cfg.WatcherCount())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDebounce tests that rapid changes are debounced
|
// TestRapidDebounce tests that rapid changes are debounced
|
||||||
func TestDebounce(t *testing.T) {
|
func TestRapidDebounce(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
configPath := filepath.Join(tmpDir, "test.toml")
|
configPath := filepath.Join(tmpDir, "test.toml")
|
||||||
|
|
||||||
@ -253,8 +276,8 @@ func TestDebounce(t *testing.T) {
|
|||||||
|
|
||||||
// Enable watching with longer debounce
|
// Enable watching with longer debounce
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
PollInterval: 50 * time.Millisecond,
|
PollInterval: testDebounce,
|
||||||
Debounce: 200 * time.Millisecond,
|
Debounce: testStateStabilize,
|
||||||
}
|
}
|
||||||
cfg.AutoUpdateWithOptions(opts)
|
cfg.AutoUpdateWithOptions(opts)
|
||||||
defer cfg.StopAutoUpdate()
|
defer cfg.StopAutoUpdate()
|
||||||
@ -282,11 +305,11 @@ func TestDebounce(t *testing.T) {
|
|||||||
for i := 2; i <= 5; i++ {
|
for i := 2; i <= 5; i++ {
|
||||||
content := fmt.Sprintf(`value = %d`, i)
|
content := fmt.Sprintf(`value = %d`, i)
|
||||||
require.NoError(t, os.WriteFile(configPath, []byte(content), 0644))
|
require.NoError(t, os.WriteFile(configPath, []byte(content), 0644))
|
||||||
time.Sleep(50 * time.Millisecond) // Less than debounce period
|
time.Sleep(testDebounce) // Less than debounce period
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for debounce to complete
|
// Wait for debounce to complete
|
||||||
time.Sleep(300 * time.Millisecond)
|
time.Sleep(2 * testStateStabilize)
|
||||||
done <- true
|
done <- true
|
||||||
|
|
||||||
// Should only see one change due to debounce
|
// Should only see one change due to debounce
|
||||||
@ -328,7 +351,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
|
|||||||
require.NoError(t, cfg.LoadFile(configPath))
|
require.NoError(t, cfg.LoadFile(configPath))
|
||||||
|
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
PollInterval: 50 * time.Millisecond,
|
PollInterval: testDebounce,
|
||||||
MaxWatchers: 50,
|
MaxWatchers: 50,
|
||||||
}
|
}
|
||||||
cfg.AutoUpdateWithOptions(opts)
|
cfg.AutoUpdateWithOptions(opts)
|
||||||
@ -353,7 +376,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
|
|||||||
select {
|
select {
|
||||||
case <-ch:
|
case <-ch:
|
||||||
// OK, got a change
|
// OK, got a change
|
||||||
case <-time.After(10 * time.Millisecond):
|
case <-time.After(2 * SpinWaitInterval):
|
||||||
// OK, no changes yet
|
// OK, no changes yet
|
||||||
}
|
}
|
||||||
}(i)
|
}(i)
|
||||||
@ -384,7 +407,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
|
|||||||
isWatching = true
|
isWatching = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(2 * SpinWaitInterval)
|
||||||
}
|
}
|
||||||
if !isWatching {
|
if !isWatching {
|
||||||
errors <- fmt.Errorf("checker %d: IsWatching returned false", id)
|
errors <- fmt.Errorf("checker %d: IsWatching returned false", id)
|
||||||
@ -418,8 +441,8 @@ func TestReloadTimeout(t *testing.T) {
|
|||||||
|
|
||||||
// Very short timeout
|
// Very short timeout
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
PollInterval: 100 * time.Millisecond,
|
PollInterval: testPollInterval,
|
||||||
ReloadTimeout: 1 * time.Nanosecond, // Extremely short
|
ReloadTimeout: 1 * time.Nanosecond,
|
||||||
}
|
}
|
||||||
cfg.AutoUpdateWithOptions(opts)
|
cfg.AutoUpdateWithOptions(opts)
|
||||||
defer cfg.StopAutoUpdate()
|
defer cfg.StopAutoUpdate()
|
||||||
@ -454,7 +477,7 @@ func TestStopAutoUpdate(t *testing.T) {
|
|||||||
select {
|
select {
|
||||||
case _, ok := <-ch:
|
case _, ok := <-ch:
|
||||||
assert.False(t, ok, "Channel should be closed after stop")
|
assert.False(t, ok, "Channel should be closed after stop")
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(ShutdownTimeout):
|
||||||
// OK, channel might not close immediately
|
// OK, channel might not close immediately
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -484,7 +507,7 @@ func BenchmarkWatchOverhead(b *testing.B) {
|
|||||||
|
|
||||||
// Enable watching
|
// Enable watching
|
||||||
opts := WatchOptions{
|
opts := WatchOptions{
|
||||||
PollInterval: 100 * time.Millisecond,
|
PollInterval: testPollInterval,
|
||||||
}
|
}
|
||||||
cfg.AutoUpdateWithOptions(opts)
|
cfg.AutoUpdateWithOptions(opts)
|
||||||
defer cfg.StopAutoUpdate()
|
defer cfg.StopAutoUpdate()
|
||||||
@ -494,11 +517,4 @@ func BenchmarkWatchOverhead(b *testing.B) {
|
|||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, _ = cfg.Get(fmt.Sprintf("value%d", i%100))
|
_, _ = cfg.Get(fmt.Sprintf("value%d", i%100))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// helper function to wait for watcher state, preventing race conditions of goroutine start and test check
|
|
||||||
func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) {
|
|
||||||
require.Eventually(t, func() bool {
|
|
||||||
return cfg.IsWatching() == expected
|
|
||||||
}, 200*time.Millisecond, 10*time.Millisecond, msgAndArgs...)
|
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user