// FILE: lixenwraith/config/decode.go package config import ( "encoding/json" "fmt" "net" "net/url" "reflect" "strings" "time" "github.com/go-viper/mapstructure/v2" ) // unmarshal is the single authoritative function for decoding configuration // into target structures. All public decoding methods delegate to this func (c *Config) unmarshal(source Source, target any, basePath ...string) error { // Parse variadic basePath path := "" switch len(basePath) { case 0: // Use default empty path case 1: path = basePath[0] default: return wrapError(ErrInvalidPath, fmt.Errorf("too many basePath arguments: expected 0 or 1, got %d", len(basePath))) } // Validate target rv := reflect.ValueOf(target) if rv.Kind() != reflect.Ptr || rv.IsNil() { return wrapError(ErrTypeMismatch, fmt.Errorf("unmarshal target must be non-nil pointer, got %T", target)) } c.mutex.RLock() defer c.mutex.RUnlock() // Build nested map based on source selection nestedMap := make(map[string]any) if source == "" { // Use current merged state for path, item := range c.items { setNestedValue(nestedMap, path, item.currentValue) } } else { // Use specific source for path, item := range c.items { if val, exists := item.values[source]; exists { setNestedValue(nestedMap, path, val) } } } // Navigate to basePath section sectionData := navigateToPath(nestedMap, path) // Ensure we have a map to decode, normalizing if necessary sectionMap, err := normalizeMap(sectionData) if err != nil { if sectionData == nil { sectionMap = make(map[string]any) // Empty section is valid } else { // Path points to a non-map value, which is an error for Scan return wrapError(ErrTypeMismatch, fmt.Errorf("path %q refers to non-map value (type %T)", path, sectionData)) } } // Create decoder with comprehensive hooks decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ Result: target, TagName: c.tagName, WeaklyTypedInput: true, DecodeHook: c.getDecodeHook(), ZeroFields: true, Metadata: nil, }) if err != nil { return wrapError(ErrDecode, fmt.Errorf("decoder creation failed: %w", err)) } if err := decoder.Decode(sectionMap); err != nil { return wrapError(ErrDecode, fmt.Errorf("decode failed for path %q: %w", path, err)) } return nil } // normalizeMap ensures that the input data is a map[string]any for the decoder func normalizeMap(data any) (map[string]any, error) { if data == nil { return make(map[string]any), nil } // If it's already the correct type, return it. if m, ok := data.(map[string]any); ok { return m, nil } // Use reflection to handle other map types (e.g., map[string]bool) v := reflect.ValueOf(data) if v.Kind() == reflect.Map { if v.Type().Key().Kind() != reflect.String { return nil, wrapError(ErrTypeMismatch, fmt.Errorf("map keys must be strings, but got %v", v.Type().Key())) } // Create a new map[string]any and copy the values normalized := make(map[string]any, v.Len()) iter := v.MapRange() for iter.Next() { normalized[iter.Key().String()] = iter.Value().Interface() } return normalized, nil } return nil, wrapError(ErrTypeMismatch, fmt.Errorf("expected a map but got %T", data)) } // getDecodeHook returns the composite decode hook for all type conversions func (c *Config) getDecodeHook() mapstructure.DecodeHookFunc { return mapstructure.ComposeDecodeHookFunc( // JSON Number handling jsonNumberHookFunc(), // Network types stringToNetIPHookFunc(), stringToNetIPNetHookFunc(), stringToURLHookFunc(), // Standard hooks mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToTimeHookFunc(time.RFC3339), mapstructure.StringToSliceHookFunc(","), // Custom application hooks c.customDecodeHook(), ) } // jsonNumberHookFunc handles json.Number conversion to appropriate numeric types func jsonNumberHookFunc() mapstructure.DecodeHookFunc { return func(f reflect.Type, t reflect.Type, data any) (any, error) { // Check if source is json.Number if f != reflect.TypeOf(json.Number("")) { return data, nil } num := data.(json.Number) // Convert based on target type switch t.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: val, err := num.Int64() if err != nil { return nil, wrapError(ErrDecode, err) } return val, nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: // Parse as int64 first, then convert i, err := num.Int64() if err != nil { return nil, wrapError(ErrDecode, err) } if i < 0 { return nil, wrapError(ErrDecode, fmt.Errorf("cannot convert negative number to unsigned type")) } return uint64(i), nil case reflect.Float32, reflect.Float64: val, err := num.Float64() if err != nil { return nil, wrapError(ErrDecode, err) } return val, nil case reflect.String: return num.String(), nil default: // Return as-is for other types return data, nil } } } // stringToNetIPHookFunc handles net.IP conversion func stringToNetIPHookFunc() mapstructure.DecodeHookFunc { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String { return data, nil } if t != reflect.TypeOf(net.IP{}) { return data, nil } // SECURITY: Validate IP string format to prevent injection str := data.(string) if len(str) > MaxIPv6Length { return nil, fmt.Errorf("invalid IP length: %d", len(str)) } ip := net.ParseIP(str) if ip == nil { return nil, fmt.Errorf("invalid IP address: %s", str) } return ip, nil } } // stringToNetIPNetHookFunc handles net.IPNet conversion func stringToNetIPNetHookFunc() mapstructure.DecodeHookFunc { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String { return data, nil } isPtr := t.Kind() == reflect.Ptr targetType := t if isPtr { targetType = t.Elem() } if targetType != reflect.TypeOf(net.IPNet{}) { return data, nil } str := data.(string) if len(str) > MaxCIDRLength { return nil, wrapError(ErrDecode, fmt.Errorf("invalid CIDR length: %d", len(str))) } _, ipnet, err := net.ParseCIDR(str) if err != nil { return nil, wrapError(ErrDecode, fmt.Errorf("invalid CIDR: %w", err)) } if isPtr { return ipnet, nil } return *ipnet, nil } } // stringToURLHookFunc handles url.URL conversion func stringToURLHookFunc() mapstructure.DecodeHookFunc { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String { return data, nil } isPtr := t.Kind() == reflect.Ptr targetType := t if isPtr { targetType = t.Elem() } if targetType != reflect.TypeOf(url.URL{}) { return data, nil } str := data.(string) if len(str) > MaxURLLength { return nil, wrapError(ErrDecode, fmt.Errorf("URL too long: %d bytes", len(str))) } u, err := url.Parse(str) if err != nil { return nil, wrapError(ErrDecode, fmt.Errorf("invalid URL: %w", err)) } if isPtr { return u, nil } return *u, nil } } // customDecodeHook allows for application-specific type conversions func (c *Config) customDecodeHook() mapstructure.DecodeHookFunc { return func(f reflect.Type, t reflect.Type, data any) (any, error) { // TODO: Add support of custom validation for application types here // Example: Rate limit parsing, permission validation, etc. // Pass through by default return data, nil } } // navigateToPath traverses nested map to reach the specified path func navigateToPath(nested map[string]any, path string) any { if path == "" { return nested } path = strings.TrimSuffix(path, ".") if path == "" { return nested } segments := strings.Split(path, ".") current := any(nested) for _, segment := range segments { currentMap, ok := current.(map[string]any) if !ok { return nil } value, exists := currentMap[segment] if !exists { return nil } current = value } return current }