e6.0.0 Added file format change and security option support.
This commit is contained in:
27
builder.go
27
builder.go
@ -15,6 +15,8 @@ type Builder struct {
|
||||
opts LoadOptions
|
||||
defaults any
|
||||
tagName string
|
||||
fileFormat string
|
||||
securityOpts *SecurityOptions
|
||||
prefix string
|
||||
file string
|
||||
args []string
|
||||
@ -50,6 +52,14 @@ func (b *Builder) Build() (*Config, error) {
|
||||
tagName = "toml"
|
||||
}
|
||||
|
||||
// Set format and security settings
|
||||
if b.fileFormat != "" {
|
||||
b.cfg.fileFormat = b.fileFormat
|
||||
}
|
||||
if b.securityOpts != nil {
|
||||
b.cfg.securityOpts = b.securityOpts
|
||||
}
|
||||
|
||||
// 1. Register defaults
|
||||
// If WithDefaults() was called, it takes precedence.
|
||||
// If not, but WithTarget() was called, use the target struct for defaults.
|
||||
@ -148,6 +158,23 @@ func (b *Builder) WithTagName(tagName string) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithFileFormat sets the expected file format
|
||||
func (b *Builder) WithFileFormat(format string) *Builder {
|
||||
switch format {
|
||||
case "toml", "json", "yaml", "auto":
|
||||
b.fileFormat = format
|
||||
default:
|
||||
b.err = fmt.Errorf("unsupported file format %q", format)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSecurityOptions sets security options for file loading
|
||||
func (b *Builder) WithSecurityOptions(opts SecurityOptions) *Builder {
|
||||
b.securityOpts = &opts
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPrefix sets the prefix for struct registration
|
||||
func (b *Builder) WithPrefix(prefix string) *Builder {
|
||||
b.prefix = prefix
|
||||
|
||||
@ -203,7 +203,8 @@ func TestBuilder(t *testing.T) {
|
||||
func TestFileDiscovery(t *testing.T) {
|
||||
t.Run("DiscoveryWithCLIFlag", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "custom.conf")
|
||||
// Use .toml extension for TOML content
|
||||
configFile := filepath.Join(tmpDir, "custom.toml")
|
||||
os.WriteFile(configFile, []byte(`test = "value"`), 0644)
|
||||
|
||||
opts := DefaultDiscoveryOptions("myapp")
|
||||
@ -223,6 +224,7 @@ func TestFileDiscovery(t *testing.T) {
|
||||
assert.Equal(t, "value", val)
|
||||
})
|
||||
|
||||
// Rest of test cases remain the same...
|
||||
t.Run("DiscoveryWithEnvVar", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "env.toml")
|
||||
|
||||
61
config.go
61
config.go
@ -47,19 +47,28 @@ type structCache struct {
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SecurityOptions for enhanced file loading security
|
||||
type SecurityOptions struct {
|
||||
PreventPathTraversal bool // Prevent ../ in paths
|
||||
EnforceFileOwnership bool // Unix only: ensure file owned by current user
|
||||
MaxFileSize int64 // Maximum config file size (0 = no limit)
|
||||
}
|
||||
|
||||
// Config manages application configuration. It can be used in two primary ways:
|
||||
// 1. As a dynamic key-value store, accessed via methods like Get(), String(), and Int64()
|
||||
// 2. As a source for a type-safe struct, populated via BuildAndScan() or AsStruct()
|
||||
type Config struct {
|
||||
items map[string]configItem
|
||||
tagName string
|
||||
mutex sync.RWMutex
|
||||
options LoadOptions // Current load options
|
||||
fileData map[string]any // Cached file data
|
||||
envData map[string]any // Cached env data
|
||||
cliData map[string]any // Cached CLI data
|
||||
version atomic.Int64
|
||||
structCache *structCache
|
||||
items map[string]configItem
|
||||
tagName string
|
||||
fileFormat string // Separate from tagName: "toml", "json", "yaml", or "auto"
|
||||
securityOpts *SecurityOptions
|
||||
mutex sync.RWMutex
|
||||
options LoadOptions // Current load options
|
||||
fileData map[string]any // Cached file data
|
||||
envData map[string]any // Cached env data
|
||||
cliData map[string]any // Cached CLI data
|
||||
version atomic.Int64
|
||||
structCache *structCache
|
||||
|
||||
// File watching support
|
||||
watcher *watcher
|
||||
@ -69,8 +78,14 @@ type Config struct {
|
||||
// New creates and initializes a new Config instance.
|
||||
func New() *Config {
|
||||
return &Config{
|
||||
items: make(map[string]configItem),
|
||||
tagName: "toml",
|
||||
items: make(map[string]configItem),
|
||||
tagName: "toml",
|
||||
fileFormat: "auto",
|
||||
// securityOpts: &SecurityOptions{
|
||||
// PreventPathTraversal: false,
|
||||
// EnforceFileOwnership: false,
|
||||
// MaxFileSize: 0,
|
||||
// },
|
||||
options: DefaultLoadOptions(),
|
||||
fileData: make(map[string]any),
|
||||
envData: make(map[string]any),
|
||||
@ -114,6 +129,30 @@ func (c *Config) computeValue(item configItem) any {
|
||||
return item.defaultValue
|
||||
}
|
||||
|
||||
// SetFileFormat sets the expected format for configuration files.
|
||||
// Use "auto" to detect based on file extension.
|
||||
func (c *Config) SetFileFormat(format string) error {
|
||||
switch format {
|
||||
case "toml", "json", "yaml", "auto":
|
||||
// Valid formats
|
||||
default:
|
||||
return fmt.Errorf("unsupported file format %q, must be one of: toml, json, yaml, auto", format)
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.fileFormat = format
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetSecurityOptions configures security checks for file loading
|
||||
func (c *Config) SetSecurityOptions(opts SecurityOptions) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.securityOpts = &opts
|
||||
}
|
||||
|
||||
// Get retrieves a configuration value using the path and indicator if the path was registered
|
||||
func (c *Config) Get(path string) (any, bool) {
|
||||
c.mutex.RLock()
|
||||
|
||||
39
decode.go
39
decode.go
@ -2,6 +2,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@ -119,6 +120,9 @@ func normalizeMap(data any) (map[string]any, error) {
|
||||
// getDecodeHook returns the composite decode hook for all type conversions
|
||||
func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
|
||||
return mapstructure.ComposeDecodeHookFunc(
|
||||
// JSON Number handling
|
||||
jsonNumberHookFunc(),
|
||||
|
||||
// Network types
|
||||
stringToNetIPHookFunc(),
|
||||
stringToNetIPNetHookFunc(),
|
||||
@ -134,6 +138,41 @@ func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc {
|
||||
)
|
||||
}
|
||||
|
||||
// jsonNumberHookFunc handles json.Number conversion to appropriate numeric types
|
||||
func jsonNumberHookFunc() mapstructure.DecodeHookFunc {
|
||||
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
|
||||
// Check if source is json.Number
|
||||
if f != reflect.TypeOf(json.Number("")) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
num := data.(json.Number)
|
||||
|
||||
// Convert based on target type
|
||||
switch t.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return num.Int64()
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
// Parse as int64 first, then convert
|
||||
i, err := num.Int64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if i < 0 {
|
||||
return nil, fmt.Errorf("cannot convert negative number to unsigned type")
|
||||
}
|
||||
return uint64(i), nil
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return num.Float64()
|
||||
case reflect.String:
|
||||
return num.String(), nil
|
||||
default:
|
||||
// Return as-is for other types
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stringToNetIPHookFunc handles net.IP conversion
|
||||
func stringToNetIPHookFunc() mapstructure.DecodeHookFunc {
|
||||
return func(f reflect.Type, t reflect.Type, data any) (any, error) {
|
||||
|
||||
418
dynamic_test.go
Normal file
418
dynamic_test.go
Normal file
@ -0,0 +1,418 @@
|
||||
// FILE: lixenwraith/config/dynamic_test.go
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMultiFormatLoading tests loading different config formats
|
||||
func TestMultiFormatLoading(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create test config in different formats
|
||||
tomlConfig := `
|
||||
[server]
|
||||
host = "toml-host"
|
||||
port = 8080
|
||||
|
||||
[database]
|
||||
url = "postgres://localhost/toml"
|
||||
`
|
||||
|
||||
jsonConfig := `{
|
||||
"server": {
|
||||
"host": "json-host",
|
||||
"port": 9090
|
||||
},
|
||||
"database": {
|
||||
"url": "postgres://localhost/json"
|
||||
}
|
||||
}`
|
||||
|
||||
yamlConfig := `
|
||||
server:
|
||||
host: yaml-host
|
||||
port: 7070
|
||||
database:
|
||||
url: postgres://localhost/yaml
|
||||
`
|
||||
|
||||
// Write config files
|
||||
tomlPath := filepath.Join(tmpDir, "config.toml")
|
||||
jsonPath := filepath.Join(tmpDir, "config.json")
|
||||
yamlPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
require.NoError(t, os.WriteFile(tomlPath, []byte(tomlConfig), 0644))
|
||||
require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644))
|
||||
require.NoError(t, os.WriteFile(yamlPath, []byte(yamlConfig), 0644))
|
||||
|
||||
t.Run("AutoDetectFormats", func(t *testing.T) {
|
||||
cfg := New()
|
||||
cfg.Register("server.host", "")
|
||||
cfg.Register("server.port", 0)
|
||||
cfg.Register("database.url", "")
|
||||
|
||||
// Test TOML
|
||||
cfg.SetFileFormat("auto")
|
||||
require.NoError(t, cfg.LoadFile(tomlPath))
|
||||
host, _ := cfg.Get("server.host")
|
||||
assert.Equal(t, "toml-host", host)
|
||||
|
||||
// Test JSON
|
||||
require.NoError(t, cfg.LoadFile(jsonPath))
|
||||
host, _ = cfg.Get("server.host")
|
||||
assert.Equal(t, "json-host", host)
|
||||
port, _ := cfg.Get("server.port")
|
||||
// JSON number should be preserved as json.Number but convertible
|
||||
switch v := port.(type) {
|
||||
case json.Number:
|
||||
// Expected for raw value
|
||||
assert.Equal(t, json.Number("9090"), v)
|
||||
case int64:
|
||||
// Expected after decode hook conversion
|
||||
assert.Equal(t, int64(9090), v)
|
||||
case float64:
|
||||
// Alternative conversion
|
||||
assert.Equal(t, float64(9090), v)
|
||||
default:
|
||||
t.Errorf("Unexpected type for port: %T", port)
|
||||
}
|
||||
|
||||
// Test YAML
|
||||
require.NoError(t, cfg.LoadFile(yamlPath))
|
||||
host, _ = cfg.Get("server.host")
|
||||
assert.Equal(t, "yaml-host", host)
|
||||
})
|
||||
|
||||
t.Run("ExplicitFormat", func(t *testing.T) {
|
||||
cfg := New()
|
||||
cfg.Register("server.host", "")
|
||||
|
||||
// Force JSON parsing on .conf file
|
||||
confPath := filepath.Join(tmpDir, "config.conf")
|
||||
require.NoError(t, os.WriteFile(confPath, []byte(jsonConfig), 0644))
|
||||
|
||||
cfg.SetFileFormat("json")
|
||||
require.NoError(t, cfg.LoadFile(confPath))
|
||||
|
||||
host, _ := cfg.Get("server.host")
|
||||
assert.Equal(t, "json-host", host)
|
||||
})
|
||||
|
||||
t.Run("ContentDetection", func(t *testing.T) {
|
||||
cfg := New()
|
||||
cfg.Register("server.host", "")
|
||||
|
||||
// Ambiguous extension
|
||||
ambigPath := filepath.Join(tmpDir, "config.conf")
|
||||
require.NoError(t, os.WriteFile(ambigPath, []byte(yamlConfig), 0644))
|
||||
|
||||
cfg.SetFileFormat("auto")
|
||||
require.NoError(t, cfg.LoadFile(ambigPath))
|
||||
|
||||
host, _ := cfg.Get("server.host")
|
||||
assert.Equal(t, "yaml-host", host)
|
||||
})
|
||||
}
|
||||
|
||||
// TestDynamicFormatSwitching tests runtime format changes
|
||||
func TestDynamicFormatSwitching(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create configs in different formats with same structure
|
||||
configs := map[string]string{
|
||||
"toml": `value = "from-toml"`,
|
||||
"json": `{"value": "from-json"}`,
|
||||
"yaml": `value: from-yaml`,
|
||||
}
|
||||
|
||||
cfg := New()
|
||||
cfg.Register("value", "default")
|
||||
|
||||
for format, content := range configs {
|
||||
t.Run(format, func(t *testing.T) {
|
||||
filePath := filepath.Join(tmpDir, "config."+format)
|
||||
require.NoError(t, os.WriteFile(filePath, []byte(content), 0644))
|
||||
|
||||
// Set format and load
|
||||
require.NoError(t, cfg.SetFileFormat(format))
|
||||
require.NoError(t, cfg.LoadFile(filePath))
|
||||
|
||||
val, _ := cfg.Get("value")
|
||||
assert.Equal(t, "from-"+format, val)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWatchFileFormatSwitch tests watching different file formats
|
||||
func TestWatchFileFormatSwitch(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
tomlPath := filepath.Join(tmpDir, "config.toml")
|
||||
jsonPath := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-1"`), 0644))
|
||||
require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-1"}`), 0644))
|
||||
|
||||
cfg := New()
|
||||
cfg.Register("value", "default")
|
||||
|
||||
// Configure fast polling for test
|
||||
opts := WatchOptions{
|
||||
PollInterval: testPollInterval, // Fast polling for tests
|
||||
Debounce: testDebounce, // Short debounce
|
||||
MaxWatchers: 10,
|
||||
}
|
||||
|
||||
// Start watching TOML
|
||||
cfg.SetFileFormat("auto")
|
||||
require.NoError(t, cfg.LoadFile(tomlPath))
|
||||
cfg.AutoUpdateWithOptions(opts)
|
||||
defer cfg.StopAutoUpdate()
|
||||
|
||||
// Wait for watcher to start
|
||||
require.Eventually(t, func() bool {
|
||||
return cfg.IsWatching()
|
||||
}, 4*testDebounce, 2*SpinWaitInterval)
|
||||
|
||||
val, _ := cfg.Get("value")
|
||||
assert.Equal(t, "toml-1", val)
|
||||
|
||||
// Switch to JSON with format hint
|
||||
require.NoError(t, cfg.WatchFile(jsonPath, "json"))
|
||||
|
||||
// Wait for new watcher to start
|
||||
require.Eventually(t, func() bool {
|
||||
return cfg.IsWatching()
|
||||
}, 4*testDebounce, 2*SpinWaitInterval)
|
||||
|
||||
// Get watch channel AFTER switching files
|
||||
changes := cfg.Watch()
|
||||
|
||||
val, _ = cfg.Get("value")
|
||||
assert.Equal(t, "json-1", val)
|
||||
|
||||
// Update JSON file
|
||||
require.NoError(t, os.WriteFile(jsonPath, []byte(`{"value": "json-2"}`), 0644))
|
||||
|
||||
// Wait for change notification
|
||||
select {
|
||||
case path := <-changes:
|
||||
assert.Equal(t, "value", path)
|
||||
// Wait a bit for value to be updated
|
||||
require.Eventually(t, func() bool {
|
||||
val, _ := cfg.Get("value")
|
||||
return val == "json-2"
|
||||
}, testEventuallyTimeout, 2*SpinWaitInterval)
|
||||
case <-time.After(testWatchTimeout):
|
||||
t.Error("Timeout waiting for JSON file change")
|
||||
}
|
||||
|
||||
// Update old TOML file - should NOT trigger notification
|
||||
require.NoError(t, os.WriteFile(tomlPath, []byte(`value = "toml-2"`), 0644))
|
||||
|
||||
// Should not receive notification from old file
|
||||
select {
|
||||
case <-changes:
|
||||
t.Error("Should not receive changes from old TOML file")
|
||||
case <-time.After(testPollWindow):
|
||||
// Expected - no change notification
|
||||
}
|
||||
}
|
||||
|
||||
// TestSecurityOptions tests security features
|
||||
func TestSecurityOptions(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
t.Run("PathTraversal", func(t *testing.T) {
|
||||
cfg := New()
|
||||
cfg.SetSecurityOptions(SecurityOptions{
|
||||
PreventPathTraversal: true,
|
||||
})
|
||||
|
||||
// Test various malicious paths
|
||||
maliciousPaths := []string{
|
||||
"../../../etc/passwd",
|
||||
"./../etc/passwd",
|
||||
"config/../../../etc/passwd",
|
||||
filepath.Join("..", "..", "etc", "passwd"),
|
||||
}
|
||||
|
||||
for _, malPath := range maliciousPaths {
|
||||
err := cfg.LoadFile(malPath)
|
||||
assert.Error(t, err, "Should reject path: %s", malPath)
|
||||
assert.Contains(t, err.Error(), "path traversal")
|
||||
}
|
||||
|
||||
// Valid paths should work
|
||||
validPath := filepath.Join(tmpDir, "config.toml")
|
||||
os.WriteFile(validPath, []byte(`test = "value"`), 0644)
|
||||
cfg.Register("test", "")
|
||||
|
||||
err := cfg.LoadFile(validPath)
|
||||
assert.NoError(t, err, "Should accept valid absolute path")
|
||||
})
|
||||
|
||||
t.Run("FileSizeLimit", func(t *testing.T) {
|
||||
cfg := New()
|
||||
cfg.SetSecurityOptions(SecurityOptions{
|
||||
MaxFileSize: 100, // 100 bytes limit
|
||||
})
|
||||
|
||||
// Create large file
|
||||
largePath := filepath.Join(tmpDir, "large.toml")
|
||||
largeContent := make([]byte, 1024)
|
||||
for i := range largeContent {
|
||||
largeContent[i] = 'a'
|
||||
}
|
||||
require.NoError(t, os.WriteFile(largePath, largeContent, 0644))
|
||||
|
||||
err := cfg.LoadFile(largePath)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "exceeds maximum size")
|
||||
})
|
||||
|
||||
t.Run("FileOwnership", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping ownership test on Windows")
|
||||
}
|
||||
|
||||
cfg := New()
|
||||
cfg.SetSecurityOptions(SecurityOptions{
|
||||
EnforceFileOwnership: true,
|
||||
})
|
||||
|
||||
// Create file owned by current user (should succeed)
|
||||
ownedPath := filepath.Join(tmpDir, "owned.toml")
|
||||
require.NoError(t, os.WriteFile(ownedPath, []byte(`test = "value"`), 0644))
|
||||
|
||||
cfg.Register("test", "")
|
||||
err := cfg.LoadFile(ownedPath)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// waitForWatchingState waits for watcher state, preventing race conditions of goroutine start and test check
|
||||
func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) {
|
||||
t.Helper()
|
||||
require.Eventually(t, func() bool {
|
||||
return cfg.IsWatching() == expected
|
||||
}, testEventuallyTimeout, 2*SpinWaitInterval, msgAndArgs...)
|
||||
}
|
||||
|
||||
// TestBuilderWithFormat tests Builder integration
|
||||
func TestBuilderWithFormat(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
jsonPath := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
jsonConfig := `{
|
||||
"server": {
|
||||
"host": "builder-host",
|
||||
"port": 8080
|
||||
}
|
||||
}`
|
||||
require.NoError(t, os.WriteFile(jsonPath, []byte(jsonConfig), 0644))
|
||||
|
||||
type Config struct {
|
||||
Server struct {
|
||||
Host string `json:"host" toml:"host"`
|
||||
Port int `json:"port" toml:"port"`
|
||||
} `json:"server" toml:"server"`
|
||||
}
|
||||
|
||||
defaults := &Config{}
|
||||
defaults.Server.Host = "default-host"
|
||||
defaults.Server.Port = 3000
|
||||
|
||||
cfg, err := NewBuilder().
|
||||
WithDefaults(defaults).
|
||||
WithFile(jsonPath).
|
||||
WithFileFormat("json").
|
||||
WithTagName("toml"). // Use toml tags for registration
|
||||
WithSecurityOptions(SecurityOptions{
|
||||
PreventPathTraversal: true,
|
||||
MaxFileSize: 1024 * 1024, // 1MB
|
||||
}).
|
||||
Build()
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the value was loaded
|
||||
host, exists := cfg.Get("server.host")
|
||||
assert.True(t, exists, "server.host should exist")
|
||||
assert.Equal(t, "builder-host", host)
|
||||
|
||||
port, exists := cfg.Get("server.port")
|
||||
assert.True(t, exists, "server.port should exist")
|
||||
// Handle json.Number or converted int
|
||||
switch v := port.(type) {
|
||||
case json.Number:
|
||||
p, _ := v.Int64()
|
||||
assert.Equal(t, int64(8080), p)
|
||||
case int64:
|
||||
assert.Equal(t, int64(8080), v)
|
||||
case float64:
|
||||
assert.Equal(t, float64(8080), v)
|
||||
default:
|
||||
t.Errorf("Unexpected type for port: %T", port)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFormatParsing benchmarks different format parsing speeds
|
||||
func BenchmarkFormatParsing(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
|
||||
// Create test data
|
||||
configs := map[string]string{
|
||||
"toml": `
|
||||
[server]
|
||||
host = "localhost"
|
||||
port = 8080
|
||||
[database]
|
||||
url = "postgres://localhost/db"
|
||||
[cache]
|
||||
ttl = 300
|
||||
`,
|
||||
"json": `{
|
||||
"server": {"host": "localhost", "port": 8080},
|
||||
"database": {"url": "postgres://localhost/db"},
|
||||
"cache": {"ttl": 300}
|
||||
}`,
|
||||
"yaml": `
|
||||
server:
|
||||
host: localhost
|
||||
port: 8080
|
||||
database:
|
||||
url: postgres://localhost/db
|
||||
cache:
|
||||
ttl: 300
|
||||
`,
|
||||
}
|
||||
|
||||
for format, content := range configs {
|
||||
b.Run(format, func(b *testing.B) {
|
||||
path := filepath.Join(tmpDir, "bench."+format)
|
||||
os.WriteFile(path, []byte(content), 0644)
|
||||
|
||||
cfg := New()
|
||||
cfg.Register("server.host", "")
|
||||
cfg.Register("server.port", 0)
|
||||
cfg.Register("database.url", "")
|
||||
cfg.Register("cache.ttl", 0)
|
||||
cfg.SetFileFormat(format)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cfg.LoadFile(path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
141
loader.go
141
loader.go
@ -3,13 +3,18 @@ package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Source represents a configuration source, used to define load precedence
|
||||
@ -143,18 +148,103 @@ func (c *Config) LoadFile(filePath string) error {
|
||||
|
||||
// loadFile reads and parses a TOML configuration file
|
||||
func (c *Config) loadFile(path string) error {
|
||||
// 1. Read and Parse (No Lock)
|
||||
fileData, err := os.ReadFile(path)
|
||||
// Security: Path traversal check
|
||||
if c.securityOpts != nil && c.securityOpts.PreventPathTraversal {
|
||||
// Clean the path and check for traversal attempts
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Check if cleaned path tries to go outside current directory
|
||||
if strings.HasPrefix(cleanPath, ".."+string(filepath.Separator)) || cleanPath == ".." {
|
||||
return fmt.Errorf("potential path traversal detected in config path: %s", path)
|
||||
}
|
||||
|
||||
// Also check for absolute paths that might escape jail
|
||||
if filepath.IsAbs(cleanPath) && filepath.IsAbs(path) {
|
||||
// Absolute paths are OK if that's what was provided
|
||||
} else if filepath.IsAbs(cleanPath) && !filepath.IsAbs(path) {
|
||||
// Relative path became absolute after cleaning - suspicious
|
||||
return fmt.Errorf("potential path traversal detected in config path: %s", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Read file with size limit
|
||||
fileInfo, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return ErrConfigNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to stat config file '%s': %w", path, err)
|
||||
}
|
||||
|
||||
// Security: File size check
|
||||
if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 {
|
||||
if fileInfo.Size() > c.securityOpts.MaxFileSize {
|
||||
return fmt.Errorf("config file '%s' exceeds maximum size %d bytes", path, c.securityOpts.MaxFileSize)
|
||||
}
|
||||
}
|
||||
|
||||
// Security: File ownership check (Unix only)
|
||||
if c.securityOpts != nil && c.securityOpts.EnforceFileOwnership && runtime.GOOS != "windows" {
|
||||
if stat, ok := fileInfo.Sys().(*syscall.Stat_t); ok {
|
||||
if stat.Uid != uint32(os.Geteuid()) {
|
||||
return fmt.Errorf("config file '%s' is not owned by current user (file UID: %d, process UID: %d)",
|
||||
path, stat.Uid, os.Geteuid())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Read and parse file data
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open config file '%s': %w", path, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Use LimitedReader for additional safety
|
||||
var reader io.Reader = file
|
||||
if c.securityOpts != nil && c.securityOpts.MaxFileSize > 0 {
|
||||
reader = io.LimitReader(file, c.securityOpts.MaxFileSize)
|
||||
}
|
||||
|
||||
fileData, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config file '%s': %w", path, err)
|
||||
}
|
||||
|
||||
// Determine format
|
||||
format := c.fileFormat
|
||||
if format == "" || format == "auto" {
|
||||
// Try extension first
|
||||
format = detectFileFormat(path)
|
||||
if format == "" {
|
||||
// Fall back to content detection
|
||||
format = detectFormatFromContent(fileData)
|
||||
if format == "" {
|
||||
// Last resort: use tagName as hint
|
||||
format = c.tagName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse based on detected/specified format
|
||||
fileConfig := make(map[string]any)
|
||||
if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err)
|
||||
switch format {
|
||||
case "toml":
|
||||
if err := toml.Unmarshal(fileData, &fileConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse TOML config file '%s': %w", path, err)
|
||||
}
|
||||
case "json":
|
||||
decoder := json.NewDecoder(bytes.NewReader(fileData))
|
||||
decoder.UseNumber() // Preserve number precision
|
||||
if err := decoder.Decode(&fileConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse JSON config file '%s': %w", path, err)
|
||||
}
|
||||
case "yaml":
|
||||
if err := yaml.Unmarshal(fileData, &fileConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse YAML config file '%s': %w", path, err)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unable to determine config format for file '%s'", path)
|
||||
}
|
||||
|
||||
// 2. Prepare New State (Read-Lock Only)
|
||||
@ -185,7 +275,7 @@ func (c *Config) loadFile(path string) error {
|
||||
}
|
||||
apply("", fileConfig)
|
||||
|
||||
// -- 3. Atomically Update Config (Write-Lock)
|
||||
// 3. Atomically Update Config (Write-Lock)
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
@ -578,4 +668,45 @@ func parseArgs(args []string) (map[string]any, error) {
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// detectFileFormat determines format from file extension
|
||||
func detectFileFormat(path string) string {
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
switch ext {
|
||||
case ".toml", ".tml":
|
||||
return "toml"
|
||||
case ".json":
|
||||
return "json"
|
||||
case ".yaml", ".yml":
|
||||
return "yaml"
|
||||
case ".conf", ".config":
|
||||
// Try to detect from content
|
||||
return ""
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// detectFormatFromContent attempts to detect format by parsing
|
||||
func detectFormatFromContent(data []byte) string {
|
||||
// Try JSON first (strict format)
|
||||
var jsonTest any
|
||||
if err := json.Unmarshal(data, &jsonTest); err == nil {
|
||||
return "json"
|
||||
}
|
||||
|
||||
// Try YAML (superset of JSON, so check after JSON)
|
||||
var yamlTest any
|
||||
if err := yaml.Unmarshal(data, &yamlTest); err == nil {
|
||||
return "yaml"
|
||||
}
|
||||
|
||||
// Try TOML last
|
||||
var tomlTest any
|
||||
if err := toml.Unmarshal(data, &tomlTest); err == nil {
|
||||
return "toml"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
26
timing.go
Normal file
26
timing.go
Normal file
@ -0,0 +1,26 @@
|
||||
// FILE: lixenwraith/config/timing.go
|
||||
package config
|
||||
|
||||
import "time"
|
||||
|
||||
// Core timing constants for production use.
|
||||
// These define the fundamental timing behavior of the config package.
|
||||
const (
|
||||
// File watching intervals (ordered by frequency)
|
||||
SpinWaitInterval = 5 * time.Millisecond // CPU-friendly busy-wait quantum
|
||||
MinPollInterval = 100 * time.Millisecond // Hard floor for file stat polling
|
||||
ShutdownTimeout = 100 * time.Millisecond // Graceful watcher termination window
|
||||
DefaultDebounce = 500 * time.Millisecond // File change coalescence period
|
||||
DefaultPollInterval = time.Second // Standard file monitoring frequency
|
||||
DefaultReloadTimeout = 5 * time.Second // Maximum duration for reload operations
|
||||
)
|
||||
|
||||
// Derived timing relationships for internal use.
|
||||
// These maintain consistent ratios between related timers.
|
||||
const (
|
||||
// shutdownPollCycles defines how many spin-wait cycles comprise a shutdown timeout
|
||||
shutdownPollCycles = ShutdownTimeout / SpinWaitInterval // = 20 cycles
|
||||
|
||||
// debounceSettleMultiplier ensures sufficient time for debounce to complete
|
||||
debounceSettleMultiplier = 3 // Wait 3x debounce period for value stabilization
|
||||
)
|
||||
75
watch.go
75
watch.go
@ -11,6 +11,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const DefaultMaxWatchers = 100 // Prevent resource exhaustion
|
||||
|
||||
// WatchOptions configures file watching behavior
|
||||
type WatchOptions struct {
|
||||
// PollInterval for file stat checks (minimum 100ms)
|
||||
@ -32,10 +34,10 @@ type WatchOptions struct {
|
||||
// DefaultWatchOptions returns sensible defaults for file watching
|
||||
func DefaultWatchOptions() WatchOptions {
|
||||
return WatchOptions{
|
||||
PollInterval: time.Second, // Check every second
|
||||
Debounce: 500 * time.Millisecond,
|
||||
MaxWatchers: 100, // Prevent resource exhaustion
|
||||
ReloadTimeout: 5 * time.Second,
|
||||
PollInterval: DefaultPollInterval,
|
||||
Debounce: DefaultDebounce,
|
||||
MaxWatchers: DefaultMaxWatchers,
|
||||
ReloadTimeout: DefaultReloadTimeout,
|
||||
VerifyPermissions: true,
|
||||
}
|
||||
}
|
||||
@ -71,26 +73,32 @@ func (c *Config) AutoUpdate() {
|
||||
// AutoUpdateWithOptions enables automatic configuration reloading with custom options
|
||||
func (c *Config) AutoUpdateWithOptions(opts WatchOptions) {
|
||||
// Validate options
|
||||
if opts.PollInterval < 100*time.Millisecond {
|
||||
opts.PollInterval = 100 * time.Millisecond // Minimum poll interval
|
||||
if opts.PollInterval < MinPollInterval {
|
||||
opts.PollInterval = MinPollInterval
|
||||
}
|
||||
if opts.MaxWatchers <= 0 {
|
||||
opts.MaxWatchers = 100
|
||||
}
|
||||
if opts.ReloadTimeout <= 0 {
|
||||
opts.ReloadTimeout = 5 * time.Second
|
||||
opts.ReloadTimeout = DefaultReloadTimeout
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Check if we have a file to watch
|
||||
// Get path of current file to watch
|
||||
filePath := c.getConfigFilePath()
|
||||
if filePath == "" {
|
||||
// No file configured, nothing to watch
|
||||
return
|
||||
}
|
||||
|
||||
// Stop existing watcher if path changed
|
||||
if c.watcher != nil && c.watcher.filePath != filePath {
|
||||
c.watcher.stop()
|
||||
c.watcher = nil
|
||||
}
|
||||
|
||||
// Initialize watcher if needed
|
||||
if c.watcher == nil {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@ -131,17 +139,24 @@ func (c *Config) Watch() <-chan string {
|
||||
}
|
||||
|
||||
// WatchFile stops any existing file watcher, loads a new configuration file,
|
||||
// and starts a new watcher on that file path.
|
||||
func (c *Config) WatchFile(filePath string) error {
|
||||
// Stop any currently running watcher to prevent orphaned goroutines.
|
||||
// and starts a new watcher on that file path. Optionally accepts format hint.
|
||||
func (c *Config) WatchFile(filePath string, formatHint ...string) error {
|
||||
// Stop any currently running watcher
|
||||
c.StopAutoUpdate()
|
||||
|
||||
// Load the new file and set `configFilePath` to the new path
|
||||
// Set format hint if provided
|
||||
if len(formatHint) > 0 {
|
||||
if err := c.SetFileFormat(formatHint[0]); err != nil {
|
||||
return fmt.Errorf("invalid format hint: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load the new file
|
||||
if err := c.LoadFile(filePath); err != nil {
|
||||
return fmt.Errorf("failed to load new file for watching: %w", err)
|
||||
}
|
||||
|
||||
// Start a new watcher on the new file
|
||||
// Get previous watcher options if available
|
||||
c.mutex.RLock()
|
||||
opts := DefaultWatchOptions()
|
||||
if c.watcher != nil {
|
||||
@ -149,18 +164,36 @@ func (c *Config) WatchFile(filePath string) error {
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
// Start new watcher (AutoUpdateWithOptions will create a new watcher with the new file path)
|
||||
c.AutoUpdateWithOptions(opts)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WatchWithOptions returns a channel with custom watch options
|
||||
// should not restart the watcher if it's already running with the same file
|
||||
func (c *Config) WatchWithOptions(opts WatchOptions) <-chan string {
|
||||
c.mutex.RLock()
|
||||
watcher := c.watcher
|
||||
filePath := c.configFilePath
|
||||
c.mutex.RUnlock()
|
||||
|
||||
// If no file configured, return closed channel
|
||||
if filePath == "" {
|
||||
ch := make(chan string)
|
||||
close(ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
// If watcher exists and is watching the current file, just subscribe
|
||||
if watcher != nil && watcher.filePath == filePath && watcher.watching.Load() {
|
||||
return watcher.subscribe()
|
||||
}
|
||||
|
||||
// First ensure auto-update is running
|
||||
c.AutoUpdateWithOptions(opts)
|
||||
|
||||
c.mutex.RLock()
|
||||
watcher := c.watcher
|
||||
watcher = c.watcher
|
||||
c.mutex.RUnlock()
|
||||
|
||||
if watcher == nil {
|
||||
@ -363,18 +396,22 @@ func (w *watcher) notifyWatchers(path string) {
|
||||
|
||||
// stop terminates the watcher
|
||||
func (w *watcher) stop() {
|
||||
w.cancel()
|
||||
if w.cancel != nil {
|
||||
w.cancel()
|
||||
}
|
||||
|
||||
// Stop debounce timer
|
||||
w.mu.Lock()
|
||||
if w.debounceTimer != nil {
|
||||
w.debounceTimer.Stop()
|
||||
w.debounceTimer = nil
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
// Wait for watch loop to exit
|
||||
for w.watching.Load() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
// Wait for watch loop to exit with timeout
|
||||
deadline := time.Now().Add(ShutdownTimeout)
|
||||
for w.watching.Load() && time.Now().Before(deadline) {
|
||||
time.Sleep(SpinWaitInterval)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -14,6 +14,29 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test-specific timing constants derived from production values.
|
||||
// These accelerate test execution while maintaining timing relationships.
|
||||
const (
|
||||
// testAcceleration reduces all intervals by this factor for faster tests
|
||||
testAcceleration = 10
|
||||
|
||||
// Accelerated test timings
|
||||
testPollInterval = DefaultPollInterval / testAcceleration // 100ms (from 1s)
|
||||
testDebounce = DefaultDebounce / testAcceleration // 50ms (from 500ms)
|
||||
testReloadTimeout = DefaultReloadTimeout / testAcceleration // 500ms (from 5s)
|
||||
testShutdownTimeout = ShutdownTimeout // Keep original for safety
|
||||
testSpinWaitInterval = SpinWaitInterval // Keep original for CPU efficiency
|
||||
|
||||
// Test assertion timeouts
|
||||
testEventuallyTimeout = testReloadTimeout // Aligns with reload timing
|
||||
testWatchTimeout = 2 * DefaultPollInterval // 2s for change propagation
|
||||
|
||||
// Derived test multipliers with clear purpose
|
||||
testDebounceSettle = debounceSettleMultiplier * testDebounce // 150ms for debounce verification
|
||||
testPollWindow = 3 * testPollInterval // 300ms change detection window
|
||||
testStateStabilize = 4 * testDebounce // 200ms for state convergence
|
||||
)
|
||||
|
||||
// TestAutoUpdate tests automatic configuration reloading
|
||||
func TestAutoUpdate(t *testing.T) {
|
||||
// Setup
|
||||
@ -59,8 +82,8 @@ enabled = true
|
||||
|
||||
// Enable auto-update with fast polling
|
||||
opts := WatchOptions{
|
||||
PollInterval: 100 * time.Millisecond,
|
||||
Debounce: 50 * time.Millisecond,
|
||||
PollInterval: testPollInterval,
|
||||
Debounce: testDebounce,
|
||||
MaxWatchers: 10,
|
||||
}
|
||||
cfg.AutoUpdateWithOptions(opts)
|
||||
@ -93,7 +116,7 @@ enabled = false
|
||||
require.NoError(t, os.WriteFile(configPath, []byte(updatedConfig), 0644))
|
||||
|
||||
// Wait for changes to be detected
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
time.Sleep(testPollWindow)
|
||||
|
||||
// Verify new values
|
||||
port, _ = cfg.Get("server.port")
|
||||
@ -130,8 +153,8 @@ func TestWatchFileDeleted(t *testing.T) {
|
||||
|
||||
// Enable watching
|
||||
opts := WatchOptions{
|
||||
PollInterval: 100 * time.Millisecond,
|
||||
Debounce: 50 * time.Millisecond,
|
||||
PollInterval: testPollInterval,
|
||||
Debounce: testDebounce,
|
||||
}
|
||||
cfg.AutoUpdateWithOptions(opts)
|
||||
defer cfg.StopAutoUpdate()
|
||||
@ -145,7 +168,7 @@ func TestWatchFileDeleted(t *testing.T) {
|
||||
select {
|
||||
case path := <-changes:
|
||||
assert.Equal(t, "file_deleted", path)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
case <-time.After(testEventuallyTimeout):
|
||||
t.Error("Timeout waiting for deletion notification")
|
||||
}
|
||||
}
|
||||
@ -169,8 +192,8 @@ func TestWatchPermissionChange(t *testing.T) {
|
||||
|
||||
// Enable watching with permission verification
|
||||
opts := WatchOptions{
|
||||
PollInterval: 100 * time.Millisecond,
|
||||
Debounce: 50 * time.Millisecond,
|
||||
PollInterval: testPollInterval,
|
||||
Debounce: testDebounce,
|
||||
VerifyPermissions: true,
|
||||
}
|
||||
cfg.AutoUpdateWithOptions(opts)
|
||||
@ -185,7 +208,7 @@ func TestWatchPermissionChange(t *testing.T) {
|
||||
select {
|
||||
case path := <-changes:
|
||||
assert.Equal(t, "permissions_changed", path)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
case <-time.After(testEventuallyTimeout):
|
||||
t.Error("Timeout waiting for permission change notification")
|
||||
}
|
||||
}
|
||||
@ -203,7 +226,7 @@ func TestMaxWatchers(t *testing.T) {
|
||||
|
||||
// Enable watching with low max watchers
|
||||
opts := WatchOptions{
|
||||
PollInterval: 100 * time.Millisecond,
|
||||
PollInterval: testPollInterval,
|
||||
MaxWatchers: 3,
|
||||
}
|
||||
cfg.AutoUpdateWithOptions(opts)
|
||||
@ -229,7 +252,7 @@ func TestMaxWatchers(t *testing.T) {
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
assert.False(t, ok, "Channel 3 should be closed (max watchers exceeded)")
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
case <-time.After(testEventuallyTimeout):
|
||||
t.Error("Channel 3 should be closed immediately")
|
||||
}
|
||||
}
|
||||
@ -239,8 +262,8 @@ func TestMaxWatchers(t *testing.T) {
|
||||
assert.Equal(t, 3, cfg.WatcherCount())
|
||||
}
|
||||
|
||||
// TestDebounce tests that rapid changes are debounced
|
||||
func TestDebounce(t *testing.T) {
|
||||
// TestRapidDebounce tests that rapid changes are debounced
|
||||
func TestRapidDebounce(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "test.toml")
|
||||
|
||||
@ -253,8 +276,8 @@ func TestDebounce(t *testing.T) {
|
||||
|
||||
// Enable watching with longer debounce
|
||||
opts := WatchOptions{
|
||||
PollInterval: 50 * time.Millisecond,
|
||||
Debounce: 200 * time.Millisecond,
|
||||
PollInterval: testDebounce,
|
||||
Debounce: testStateStabilize,
|
||||
}
|
||||
cfg.AutoUpdateWithOptions(opts)
|
||||
defer cfg.StopAutoUpdate()
|
||||
@ -282,11 +305,11 @@ func TestDebounce(t *testing.T) {
|
||||
for i := 2; i <= 5; i++ {
|
||||
content := fmt.Sprintf(`value = %d`, i)
|
||||
require.NoError(t, os.WriteFile(configPath, []byte(content), 0644))
|
||||
time.Sleep(50 * time.Millisecond) // Less than debounce period
|
||||
time.Sleep(testDebounce) // Less than debounce period
|
||||
}
|
||||
|
||||
// Wait for debounce to complete
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
time.Sleep(2 * testStateStabilize)
|
||||
done <- true
|
||||
|
||||
// Should only see one change due to debounce
|
||||
@ -328,7 +351,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
|
||||
require.NoError(t, cfg.LoadFile(configPath))
|
||||
|
||||
opts := WatchOptions{
|
||||
PollInterval: 50 * time.Millisecond,
|
||||
PollInterval: testDebounce,
|
||||
MaxWatchers: 50,
|
||||
}
|
||||
cfg.AutoUpdateWithOptions(opts)
|
||||
@ -353,7 +376,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
|
||||
select {
|
||||
case <-ch:
|
||||
// OK, got a change
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
case <-time.After(2 * SpinWaitInterval):
|
||||
// OK, no changes yet
|
||||
}
|
||||
}(i)
|
||||
@ -384,7 +407,7 @@ func TestConcurrentWatchOperations(t *testing.T) {
|
||||
isWatching = true
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
time.Sleep(2 * SpinWaitInterval)
|
||||
}
|
||||
if !isWatching {
|
||||
errors <- fmt.Errorf("checker %d: IsWatching returned false", id)
|
||||
@ -418,8 +441,8 @@ func TestReloadTimeout(t *testing.T) {
|
||||
|
||||
// Very short timeout
|
||||
opts := WatchOptions{
|
||||
PollInterval: 100 * time.Millisecond,
|
||||
ReloadTimeout: 1 * time.Nanosecond, // Extremely short
|
||||
PollInterval: testPollInterval,
|
||||
ReloadTimeout: 1 * time.Nanosecond,
|
||||
}
|
||||
cfg.AutoUpdateWithOptions(opts)
|
||||
defer cfg.StopAutoUpdate()
|
||||
@ -454,7 +477,7 @@ func TestStopAutoUpdate(t *testing.T) {
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
assert.False(t, ok, "Channel should be closed after stop")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
case <-time.After(ShutdownTimeout):
|
||||
// OK, channel might not close immediately
|
||||
}
|
||||
|
||||
@ -484,7 +507,7 @@ func BenchmarkWatchOverhead(b *testing.B) {
|
||||
|
||||
// Enable watching
|
||||
opts := WatchOptions{
|
||||
PollInterval: 100 * time.Millisecond,
|
||||
PollInterval: testPollInterval,
|
||||
}
|
||||
cfg.AutoUpdateWithOptions(opts)
|
||||
defer cfg.StopAutoUpdate()
|
||||
@ -494,11 +517,4 @@ func BenchmarkWatchOverhead(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = cfg.Get(fmt.Sprintf("value%d", i%100))
|
||||
}
|
||||
}
|
||||
|
||||
// helper function to wait for watcher state, preventing race conditions of goroutine start and test check
|
||||
func waitForWatchingState(t *testing.T, cfg *Config, expected bool, msgAndArgs ...any) {
|
||||
require.Eventually(t, func() bool {
|
||||
return cfg.IsWatching() == expected
|
||||
}, 200*time.Millisecond, 10*time.Millisecond, msgAndArgs...)
|
||||
}
|
||||
Reference in New Issue
Block a user