diff --git a/src/internal/config/pipeline.go b/src/internal/config/pipeline.go index 6b26373..e83d145 100644 --- a/src/internal/config/pipeline.go +++ b/src/internal/config/pipeline.go @@ -123,9 +123,9 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err } } - // CHANGED: Validate SSL if present - if ssl, ok := cfg.Options["ssl"].(map[string]any); ok { - if err := validateSSLOptions("HTTP source", pipelineName, sourceIndex, ssl); err != nil { + // Validate TLS if present + if tls, ok := cfg.Options["tls"].(map[string]any); ok { + if err := validateTLSOptions("HTTP source", pipelineName, sourceIndex, tls); err != nil { return err } } @@ -145,9 +145,9 @@ func validateSource(pipelineName string, sourceIndex int, cfg *SourceConfig) err } } - // CHANGED: Validate SSL if present - if ssl, ok := cfg.Options["ssl"].(map[string]any); ok { - if err := validateSSLOptions("TCP source", pipelineName, sourceIndex, ssl); err != nil { + // Validate TLS if present + if tls, ok := cfg.Options["tls"].(map[string]any); ok { + if err := validateTLSOptions("TCP source", pipelineName, sourceIndex, tls); err != nil { return err } } @@ -211,9 +211,9 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts } } - // Validate SSL if present - if ssl, ok := cfg.Options["ssl"].(map[string]any); ok { - if err := validateSSLOptions("HTTP", pipelineName, sinkIndex, ssl); err != nil { + // Validate TLS if present + if tls, ok := cfg.Options["tls"].(map[string]any); ok { + if err := validateTLSOptions("HTTP", pipelineName, sinkIndex, tls); err != nil { return err } } @@ -255,9 +255,9 @@ func validateSink(pipelineName string, sinkIndex int, cfg *SinkConfig, allPorts } } - // Validate SSL if present - if ssl, ok := cfg.Options["ssl"].(map[string]any); ok { - if err := validateSSLOptions("TCP", pipelineName, sinkIndex, ssl); err != nil { + // Validate TLS if present + if tls, ok := cfg.Options["tls"].(map[string]any); ok { + if err := validateTLSOptions("TCP", pipelineName, sinkIndex, tls); err != nil { return err } } diff --git a/src/internal/config/server.go b/src/internal/config/server.go index fc9e9d0..47f98cd 100644 --- a/src/internal/config/server.go +++ b/src/internal/config/server.go @@ -12,8 +12,8 @@ type TCPConfig struct { Port int64 `toml:"port"` BufferSize int64 `toml:"buffer_size"` - // SSL/TLS Configuration - SSL *SSLConfig `toml:"ssl"` + // TLS Configuration + TLS *TLSConfig `toml:"tls"` // Net limiting NetLimit *NetLimitConfig `toml:"net_limit"` @@ -31,8 +31,8 @@ type HTTPConfig struct { StreamPath string `toml:"stream_path"` StatusPath string `toml:"status_path"` - // SSL/TLS Configuration - SSL *SSLConfig `toml:"ssl"` + // TLS Configuration + TLS *TLSConfig `toml:"tls"` // Nate limiting NetLimit *NetLimitConfig `toml:"net_limit"` diff --git a/src/internal/config/ssl.go b/src/internal/config/tls.go similarity index 78% rename from src/internal/config/ssl.go rename to src/internal/config/tls.go index 9680420..9252b45 100644 --- a/src/internal/config/ssl.go +++ b/src/internal/config/tls.go @@ -1,4 +1,4 @@ -// FILE: logwisp/src/internal/config/ssl.go +// FILE: logwisp/src/internal/config/tls.go package config import ( @@ -6,7 +6,7 @@ import ( "os" ) -type SSLConfig struct { +type TLSConfig struct { Enabled bool `toml:"enabled"` CertFile string `toml:"cert_file"` KeyFile string `toml:"key_file"` @@ -30,13 +30,13 @@ type SSLConfig struct { CipherSuites string `toml:"cipher_suites"` } -func validateSSLOptions(serverType, pipelineName string, sinkIndex int, ssl map[string]any) error { - if enabled, ok := ssl["enabled"].(bool); ok && enabled { - certFile, certOk := ssl["cert_file"].(string) - keyFile, keyOk := ssl["key_file"].(string) +func validateTLSOptions(serverType, pipelineName string, sinkIndex int, tls map[string]any) error { + if enabled, ok := tls["enabled"].(bool); ok && enabled { + certFile, certOk := tls["cert_file"].(string) + keyFile, keyOk := tls["key_file"].(string) if !certOk || certFile == "" || !keyOk || keyFile == "" { - return fmt.Errorf("pipeline '%s' sink[%d] %s: SSL enabled but cert/key files not specified", + return fmt.Errorf("pipeline '%s' sink[%d] %s: TLS enabled but cert/key files not specified", pipelineName, sinkIndex, serverType) } @@ -50,8 +50,8 @@ func validateSSLOptions(serverType, pipelineName string, sinkIndex int, ssl map[ pipelineName, sinkIndex, serverType, err) } - if clientAuth, ok := ssl["client_auth"].(bool); ok && clientAuth { - caFile, caOk := ssl["client_ca_file"].(string) + if clientAuth, ok := tls["client_auth"].(bool); ok && clientAuth { + caFile, caOk := tls["client_ca_file"].(string) if !caOk || caFile == "" { return fmt.Errorf("pipeline '%s' sink[%d] %s: client auth enabled but CA file not specified", pipelineName, sinkIndex, serverType) @@ -65,13 +65,13 @@ func validateSSLOptions(serverType, pipelineName string, sinkIndex int, ssl map[ // Validate TLS versions validVersions := map[string]bool{"TLS1.0": true, "TLS1.1": true, "TLS1.2": true, "TLS1.3": true} - if minVer, ok := ssl["min_version"].(string); ok && minVer != "" { + if minVer, ok := tls["min_version"].(string); ok && minVer != "" { if !validVersions[minVer] { return fmt.Errorf("pipeline '%s' sink[%d] %s: invalid min TLS version: %s", pipelineName, sinkIndex, serverType, minVer) } } - if maxVer, ok := ssl["max_version"].(string); ok && maxVer != "" { + if maxVer, ok := tls["max_version"].(string); ok && maxVer != "" { if !validVersions[maxVer] { return fmt.Errorf("pipeline '%s' sink[%d] %s: invalid max TLS version: %s", pipelineName, sinkIndex, serverType, maxVer) diff --git a/src/internal/sink/http.go b/src/internal/sink/http.go index 8e4c11a..4aa8c13 100644 --- a/src/internal/sink/http.go +++ b/src/internal/sink/http.go @@ -63,7 +63,7 @@ type HTTPConfig struct { StreamPath string StatusPath string Heartbeat *config.HeartbeatConfig - SSL *config.SSLConfig + TLS *config.TLSConfig NetLimit *config.NetLimitConfig } @@ -104,29 +104,29 @@ func NewHTTPSink(options map[string]any, logger *log.Logger, formatter format.Fo } } - // Extract SSL config - if ssl, ok := options["ssl"].(map[string]any); ok { - cfg.SSL = &config.SSLConfig{} - cfg.SSL.Enabled, _ = ssl["enabled"].(bool) - if certFile, ok := ssl["cert_file"].(string); ok { - cfg.SSL.CertFile = certFile + // Extract TLS config + if tc, ok := options["tls"].(map[string]any); ok { + cfg.TLS = &config.TLSConfig{} + cfg.TLS.Enabled, _ = tc["enabled"].(bool) + if certFile, ok := tc["cert_file"].(string); ok { + cfg.TLS.CertFile = certFile } - if keyFile, ok := ssl["key_file"].(string); ok { - cfg.SSL.KeyFile = keyFile + if keyFile, ok := tc["key_file"].(string); ok { + cfg.TLS.KeyFile = keyFile } - cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool) - if caFile, ok := ssl["client_ca_file"].(string); ok { - cfg.SSL.ClientCAFile = caFile + cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool) + if caFile, ok := tc["client_ca_file"].(string); ok { + cfg.TLS.ClientCAFile = caFile } - cfg.SSL.VerifyClientCert, _ = ssl["verify_client_cert"].(bool) - if minVer, ok := ssl["min_version"].(string); ok { - cfg.SSL.MinVersion = minVer + cfg.TLS.VerifyClientCert, _ = tc["verify_client_cert"].(bool) + if minVer, ok := tc["min_version"].(string); ok { + cfg.TLS.MinVersion = minVer } - if maxVer, ok := ssl["max_version"].(string); ok { - cfg.SSL.MaxVersion = maxVer + if maxVer, ok := tc["max_version"].(string); ok { + cfg.TLS.MaxVersion = maxVer } - if ciphers, ok := ssl["cipher_suites"].(string); ok { - cfg.SSL.CipherSuites = ciphers + if ciphers, ok := tc["cipher_suites"].(string); ok { + cfg.TLS.CipherSuites = ciphers } } @@ -231,7 +231,7 @@ func (h *HTTPSink) Start(ctx context.Context) error { var err error if h.tlsManager != nil { // HTTPS server - err = h.server.ListenAndServeTLS(addr, h.config.SSL.CertFile, h.config.SSL.KeyFile) + err = h.server.ListenAndServeTLS(addr, h.config.TLS.CertFile, h.config.TLS.KeyFile) } else { // HTTP server err = h.server.ListenAndServe(addr) diff --git a/src/internal/sink/http_client.go b/src/internal/sink/http_client.go index c6f45c3..73270b9 100644 --- a/src/internal/sink/http_client.go +++ b/src/internal/sink/http_client.go @@ -138,25 +138,25 @@ func NewHTTPClientSink(options map[string]any, logger *log.Logger, formatter for cfg.CAFile = caFile } - // Extract client certificate options from SSL config - if ssl, ok := options["ssl"].(map[string]any); ok { - if enabled, _ := ssl["enabled"].(bool); enabled { + // Extract client certificate options from TLS config + if tc, ok := options["tls"].(map[string]any); ok { + if enabled, _ := tc["enabled"].(bool); enabled { // Extract client certificate files for mTLS - if certFile, ok := ssl["cert_file"].(string); ok && certFile != "" { - if keyFile, ok := ssl["key_file"].(string); ok && keyFile != "" { + if certFile, ok := tc["cert_file"].(string); ok && certFile != "" { + if keyFile, ok := tc["key_file"].(string); ok && keyFile != "" { // These will be used below when configuring TLS cfg.CertFile = certFile // Need to add these fields to HTTPClientConfig cfg.KeyFile = keyFile } } - // Extract CA file from ssl config if not already set + // Extract CA file from TLS config if not already set if cfg.CAFile == "" { - if caFile, ok := ssl["ca_file"].(string); ok { + if caFile, ok := tc["ca_file"].(string); ok { cfg.CAFile = caFile } } - // Extract insecure skip verify from ssl config - if insecure, ok := ssl["insecure_skip_verify"].(bool); ok { + // Extract insecure skip verify from TLS config + if insecure, ok := tc["insecure_skip_verify"].(bool); ok { cfg.InsecureSkipVerify = insecure } } diff --git a/src/internal/sink/tcp.go b/src/internal/sink/tcp.go index a4a3996..4ab80e2 100644 --- a/src/internal/sink/tcp.go +++ b/src/internal/sink/tcp.go @@ -61,7 +61,7 @@ type TCPConfig struct { Port int64 BufferSize int64 Heartbeat *config.HeartbeatConfig - SSL *config.SSLConfig + TLS *config.TLSConfig NetLimit *config.NetLimitConfig } @@ -94,29 +94,29 @@ func NewTCPSink(options map[string]any, logger *log.Logger, formatter format.For } } - // Extract SSL config - if ssl, ok := options["ssl"].(map[string]any); ok { - cfg.SSL = &config.SSLConfig{} - cfg.SSL.Enabled, _ = ssl["enabled"].(bool) - if certFile, ok := ssl["cert_file"].(string); ok { - cfg.SSL.CertFile = certFile + // Extract TLS config + if tc, ok := options["tls"].(map[string]any); ok { + cfg.TLS = &config.TLSConfig{} + cfg.TLS.Enabled, _ = tc["enabled"].(bool) + if certFile, ok := tc["cert_file"].(string); ok { + cfg.TLS.CertFile = certFile } - if keyFile, ok := ssl["key_file"].(string); ok { - cfg.SSL.KeyFile = keyFile + if keyFile, ok := tc["key_file"].(string); ok { + cfg.TLS.KeyFile = keyFile } - cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool) - if caFile, ok := ssl["client_ca_file"].(string); ok { - cfg.SSL.ClientCAFile = caFile + cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool) + if caFile, ok := tc["client_ca_file"].(string); ok { + cfg.TLS.ClientCAFile = caFile } - cfg.SSL.VerifyClientCert, _ = ssl["verify_client_cert"].(bool) - if minVer, ok := ssl["min_version"].(string); ok { - cfg.SSL.MinVersion = minVer + cfg.TLS.VerifyClientCert, _ = tc["verify_client_cert"].(bool) + if minVer, ok := tc["min_version"].(string); ok { + cfg.TLS.MinVersion = minVer } - if maxVer, ok := ssl["max_version"].(string); ok { - cfg.SSL.MaxVersion = maxVer + if maxVer, ok := tc["max_version"].(string); ok { + cfg.TLS.MaxVersion = maxVer } - if ciphers, ok := ssl["cipher_suites"].(string); ok { - cfg.SSL.CipherSuites = ciphers + if ciphers, ok := tc["cipher_suites"].(string); ok { + cfg.TLS.CipherSuites = ciphers } } @@ -627,19 +627,6 @@ func (s *tcpServer) OnTraffic(c gnet.Conn) gnet.Action { return gnet.Close } - // // Check auth timeout - // if !client.authenticated && time.Now().After(client.authTimeout) { - // s.sink.logger.Warn("msg", "Authentication timeout", - // "component", "tcp_sink", - // "remote_addr", c.RemoteAddr().String()) - // if client.tlsBridge != nil && client.tlsBridge.IsHandshakeDone() { - // client.tlsBridge.Write([]byte("AUTH TIMEOUT\n")) - // } else if client.tlsBridge == nil { - // c.AsyncWrite([]byte("AUTH TIMEOUT\n"), nil) - // } - // return gnet.Close - // } - // Read all available data data, err := c.Next(-1) if err != nil { @@ -801,9 +788,9 @@ func (t *TCPSink) SetAuthConfig(authCfg *config.AuthConfig) { } t.authenticator = authenticator - // Initialize TLS manager if SSL is configured - if t.config.SSL != nil && t.config.SSL.Enabled { - tlsManager, err := tls.NewManager(t.config.SSL, t.logger) + // Initialize TLS manager if TLS is configured + if t.config.TLS != nil && t.config.TLS.Enabled { + tlsManager, err := tls.NewManager(t.config.TLS, t.logger) if err != nil { t.logger.Error("msg", "Failed to create TLS manager", "component", "tcp_sink", diff --git a/src/internal/sink/tcp_client.go b/src/internal/sink/tcp_client.go index 19b9359..fc99a9f 100644 --- a/src/internal/sink/tcp_client.go +++ b/src/internal/sink/tcp_client.go @@ -66,7 +66,7 @@ type TCPClientConfig struct { ReconnectBackoff float64 // TLS config - SSL *config.SSLConfig + TLS *config.TLSConfig } // NewTCPClientSink creates a new TCP client sink @@ -121,25 +121,25 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form cfg.ReconnectBackoff = backoff } - // Extract SSL config - if ssl, ok := options["ssl"].(map[string]any); ok { - cfg.SSL = &config.SSLConfig{} - cfg.SSL.Enabled, _ = ssl["enabled"].(bool) - if certFile, ok := ssl["cert_file"].(string); ok { - cfg.SSL.CertFile = certFile + // Extract TLS config + if tc, ok := options["tls"].(map[string]any); ok { + cfg.TLS = &config.TLSConfig{} + cfg.TLS.Enabled, _ = tc["enabled"].(bool) + if certFile, ok := tc["cert_file"].(string); ok { + cfg.TLS.CertFile = certFile } - if keyFile, ok := ssl["key_file"].(string); ok { - cfg.SSL.KeyFile = keyFile + if keyFile, ok := tc["key_file"].(string); ok { + cfg.TLS.KeyFile = keyFile } - cfg.SSL.ClientAuth, _ = ssl["client_auth"].(bool) - if caFile, ok := ssl["client_ca_file"].(string); ok { - cfg.SSL.ClientCAFile = caFile + cfg.TLS.ClientAuth, _ = tc["client_auth"].(bool) + if caFile, ok := tc["client_ca_file"].(string); ok { + cfg.TLS.ClientCAFile = caFile } - if insecure, ok := ssl["insecure_skip_verify"].(bool); ok { - cfg.SSL.InsecureSkipVerify = insecure + if insecure, ok := tc["insecure_skip_verify"].(bool); ok { + cfg.TLS.InsecureSkipVerify = insecure } - if caFile, ok := ssl["ca_file"].(string); ok { - cfg.SSL.CAFile = caFile + if caFile, ok := tc["ca_file"].(string); ok { + cfg.TLS.CAFile = caFile } } @@ -154,11 +154,11 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form t.lastProcessed.Store(time.Time{}) t.connectionUptime.Store(time.Duration(0)) - // Initialize TLS manager if SSL is configured - if cfg.SSL != nil && cfg.SSL.Enabled { + // Initialize TLS manager if TLS is configured + if cfg.TLS != nil && cfg.TLS.Enabled { // Build custom TLS config for client t.tlsConfig = &tls.Config{ - InsecureSkipVerify: cfg.SSL.InsecureSkipVerify, + InsecureSkipVerify: cfg.TLS.InsecureSkipVerify, } // Extract server name from address for SNI @@ -169,36 +169,36 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form t.tlsConfig.ServerName = host // Load custom CA for server verification - if cfg.SSL.CAFile != "" { - caCert, err := os.ReadFile(cfg.SSL.CAFile) + if cfg.TLS.CAFile != "" { + caCert, err := os.ReadFile(cfg.TLS.CAFile) if err != nil { - return nil, fmt.Errorf("failed to read CA file '%s': %w", cfg.SSL.CAFile, err) + return nil, fmt.Errorf("failed to read CA file '%s': %w", cfg.TLS.CAFile, err) } caCertPool := x509.NewCertPool() if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, fmt.Errorf("failed to parse CA certificate from '%s'", cfg.SSL.CAFile) + return nil, fmt.Errorf("failed to parse CA certificate from '%s'", cfg.TLS.CAFile) } t.tlsConfig.RootCAs = caCertPool logger.Debug("msg", "Custom CA loaded for server verification", "component", "tcp_client_sink", - "ca_file", cfg.SSL.CAFile) + "ca_file", cfg.TLS.CAFile) } // Load client certificate for mTLS - if cfg.SSL.CertFile != "" && cfg.SSL.KeyFile != "" { - cert, err := tls.LoadX509KeyPair(cfg.SSL.CertFile, cfg.SSL.KeyFile) + if cfg.TLS.CertFile != "" && cfg.TLS.KeyFile != "" { + cert, err := tls.LoadX509KeyPair(cfg.TLS.CertFile, cfg.TLS.KeyFile) if err != nil { return nil, fmt.Errorf("failed to load client certificate: %w", err) } t.tlsConfig.Certificates = []tls.Certificate{cert} logger.Info("msg", "Client certificate loaded for mTLS", "component", "tcp_client_sink", - "cert_file", cfg.SSL.CertFile) + "cert_file", cfg.TLS.CertFile) } // Set minimum TLS version if configured - if cfg.SSL.MinVersion != "" { - t.tlsConfig.MinVersion = parseTLSVersion(cfg.SSL.MinVersion, tls.VersionTLS12) + if cfg.TLS.MinVersion != "" { + t.tlsConfig.MinVersion = parseTLSVersion(cfg.TLS.MinVersion, tls.VersionTLS12) } else { t.tlsConfig.MinVersion = tls.VersionTLS12 // Default minimum } @@ -207,8 +207,8 @@ func NewTCPClientSink(options map[string]any, logger *log.Logger, formatter form "component", "tcp_client_sink", "address", cfg.Address, "server_name", host, - "insecure", cfg.SSL.InsecureSkipVerify, - "mtls", cfg.SSL.CertFile != "") + "insecure", cfg.TLS.InsecureSkipVerify, + "mtls", cfg.TLS.CertFile != "") } return t, nil } diff --git a/src/internal/source/http.go b/src/internal/source/http.go index 7f16405..d3fed30 100644 --- a/src/internal/source/http.go +++ b/src/internal/source/http.go @@ -4,7 +4,6 @@ package source import ( "encoding/json" "fmt" - "logwisp/src/internal/tls" "net" "sync" "sync/atomic" @@ -13,6 +12,7 @@ import ( "logwisp/src/internal/config" "logwisp/src/internal/core" "logwisp/src/internal/limit" + "logwisp/src/internal/tls" "github.com/lixenwraith/log" "github.com/valyala/fasthttp" @@ -33,7 +33,7 @@ type HTTPSource struct { // CHANGED: Add TLS support tlsManager *tls.Manager - sslConfig *config.SSLConfig + tlsConfig *config.TLSConfig // Statistics totalEntries atomic.Uint64 @@ -77,7 +77,7 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err Enabled: true, } - if rps, ok := toFloat(rl["requests_per_second"]); ok { + if rps, ok := rl["requests_per_second"].(float64); ok { cfg.RequestsPerSecond = rps } if burst, ok := rl["burst_size"].(int64); ok { @@ -100,35 +100,35 @@ func NewHTTPSource(options map[string]any, logger *log.Logger) (*HTTPSource, err } } - // Extract SSL config after existing options - if ssl, ok := options["ssl"].(map[string]any); ok { - h.sslConfig = &config.SSLConfig{} - h.sslConfig.Enabled, _ = ssl["enabled"].(bool) - if certFile, ok := ssl["cert_file"].(string); ok { - h.sslConfig.CertFile = certFile + // Extract TLS config after existing options + if tc, ok := options["tls"].(map[string]any); ok { + h.tlsConfig = &config.TLSConfig{} + h.tlsConfig.Enabled, _ = tc["enabled"].(bool) + if certFile, ok := tc["cert_file"].(string); ok { + h.tlsConfig.CertFile = certFile } - if keyFile, ok := ssl["key_file"].(string); ok { - h.sslConfig.KeyFile = keyFile + if keyFile, ok := tc["key_file"].(string); ok { + h.tlsConfig.KeyFile = keyFile } - h.sslConfig.ClientAuth, _ = ssl["client_auth"].(bool) - if caFile, ok := ssl["client_ca_file"].(string); ok { - h.sslConfig.ClientCAFile = caFile + h.tlsConfig.ClientAuth, _ = tc["client_auth"].(bool) + if caFile, ok := tc["client_ca_file"].(string); ok { + h.tlsConfig.ClientCAFile = caFile } - h.sslConfig.VerifyClientCert, _ = ssl["verify_client_cert"].(bool) - h.sslConfig.InsecureSkipVerify, _ = ssl["insecure_skip_verify"].(bool) - if minVer, ok := ssl["min_version"].(string); ok { - h.sslConfig.MinVersion = minVer + h.tlsConfig.VerifyClientCert, _ = tc["verify_client_cert"].(bool) + h.tlsConfig.InsecureSkipVerify, _ = tc["insecure_skip_verify"].(bool) + if minVer, ok := tc["min_version"].(string); ok { + h.tlsConfig.MinVersion = minVer } - if maxVer, ok := ssl["max_version"].(string); ok { - h.sslConfig.MaxVersion = maxVer + if maxVer, ok := tc["max_version"].(string); ok { + h.tlsConfig.MaxVersion = maxVer } - if ciphers, ok := ssl["cipher_suites"].(string); ok { - h.sslConfig.CipherSuites = ciphers + if ciphers, ok := tc["cipher_suites"].(string); ok { + h.tlsConfig.CipherSuites = ciphers } // Create TLS manager - if h.sslConfig.Enabled { - tlsManager, err := tls.NewManager(h.sslConfig, logger) + if h.tlsConfig.Enabled { + tlsManager, err := tls.NewManager(h.tlsConfig, logger) if err != nil { return nil, fmt.Errorf("failed to create TLS manager: %w", err) } @@ -173,7 +173,7 @@ func (h *HTTPSource) Start() error { // Check for TLS manager and start the appropriate server type if h.tlsManager != nil { h.server.TLSConfig = h.tlsManager.GetHTTPConfig() - err = h.server.ListenAndServeTLS(addr, h.sslConfig.CertFile, h.sslConfig.KeyFile) + err = h.server.ListenAndServeTLS(addr, h.tlsConfig.CertFile, h.tlsConfig.KeyFile) } else { err = h.server.ListenAndServe(addr) } @@ -452,18 +452,4 @@ func splitLines(data []byte) [][]byte { } return lines -} - -// Helper function for type conversion -func toFloat(v any) (float64, bool) { - switch val := v.(type) { - case float64: - return val, true - case int: - return float64(val), true - case int64: - return float64(val), true - default: - return 0, false - } } \ No newline at end of file diff --git a/src/internal/source/tcp.go b/src/internal/source/tcp.go index 173094d..e4e48ff 100644 --- a/src/internal/source/tcp.go +++ b/src/internal/source/tcp.go @@ -43,7 +43,7 @@ type TCPSource struct { wg sync.WaitGroup netLimiter *limit.NetLimiter tlsManager *tls.Manager - sslConfig *config.SSLConfig + tlsConfig *config.TLSConfig logger *log.Logger // Statistics @@ -83,7 +83,7 @@ func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error Enabled: true, } - if rps, ok := toFloat(rl["requests_per_second"]); ok { + if rps, ok := rl["requests_per_second"].(float64); ok { cfg.RequestsPerSecond = rps } if burst, ok := rl["burst_size"].(int64); ok { @@ -103,25 +103,25 @@ func NewTCPSource(options map[string]any, logger *log.Logger) (*TCPSource, error } } - // Extract SSL config and initialize TLS manager - if ssl, ok := options["ssl"].(map[string]any); ok { - t.sslConfig = &config.SSLConfig{} - t.sslConfig.Enabled, _ = ssl["enabled"].(bool) - if certFile, ok := ssl["cert_file"].(string); ok { - t.sslConfig.CertFile = certFile + // Extract TLS config and initialize TLS manager + if tc, ok := options["tls"].(map[string]any); ok { + t.tlsConfig = &config.TLSConfig{} + t.tlsConfig.Enabled, _ = tc["enabled"].(bool) + if certFile, ok := tc["cert_file"].(string); ok { + t.tlsConfig.CertFile = certFile } - if keyFile, ok := ssl["key_file"].(string); ok { - t.sslConfig.KeyFile = keyFile + if keyFile, ok := tc["key_file"].(string); ok { + t.tlsConfig.KeyFile = keyFile } - t.sslConfig.ClientAuth, _ = ssl["client_auth"].(bool) - if caFile, ok := ssl["client_ca_file"].(string); ok { - t.sslConfig.ClientCAFile = caFile + t.tlsConfig.ClientAuth, _ = tc["client_auth"].(bool) + if caFile, ok := tc["client_ca_file"].(string); ok { + t.tlsConfig.ClientCAFile = caFile } - t.sslConfig.VerifyClientCert, _ = ssl["verify_client_cert"].(bool) + t.tlsConfig.VerifyClientCert, _ = tc["verify_client_cert"].(bool) // Create TLS manager if enabled - if t.sslConfig.Enabled { - tlsManager, err := tls.NewManager(t.sslConfig, logger) + if t.tlsConfig.Enabled { + tlsManager, err := tls.NewManager(t.tlsConfig, logger) if err != nil { return nil, fmt.Errorf("failed to create TLS manager: %w", err) } diff --git a/src/internal/tls/generator.go b/src/internal/tls/generator.go index 7ca1fc8..e6e0581 100644 --- a/src/internal/tls/generator.go +++ b/src/internal/tls/generator.go @@ -92,8 +92,7 @@ func (c *CertGeneratorCommand) Execute(args []string) error { } } -// Crate and manage private CA -// TODO: Future implementation, not useful without implementation of generateServerCert, generateClientCert +// Create and manage private CA func (c *CertGeneratorCommand) generateCA(cn, org, country string, days, bits int, certFile, keyFile string) error { // Generate RSA key priv, err := rsa.GenerateKey(rand.Reader, bits) @@ -102,8 +101,9 @@ func (c *CertGeneratorCommand) generateCA(cn, org, country string, days, bits in } // Create certificate template + serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) template := x509.Certificate{ - SerialNumber: big.NewInt(1), + SerialNumber: serialNumber, Subject: pkix.Name{ Organization: []string{org}, Country: []string{country}, @@ -111,11 +111,9 @@ func (c *CertGeneratorCommand) generateCA(cn, org, country string, days, bits in }, NotBefore: time.Now(), NotAfter: time.Now().AddDate(0, 0, days), - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, BasicConstraintsValid: true, IsCA: true, - MaxPathLen: 1, } // Generate certificate @@ -133,25 +131,12 @@ func (c *CertGeneratorCommand) generateCA(cn, org, country string, days, bits in } // Save certificate - certOut, err := os.Create(certFile) - if err != nil { - return fmt.Errorf("failed to create cert file: %w", err) + if err := saveCert(certFile, certDER); err != nil { + return err } - defer certOut.Close() - - pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - - // Save private key - keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return fmt.Errorf("failed to create key file: %w", err) + if err := saveKey(keyFile, priv); err != nil { + return err } - defer keyOut.Close() - - pem.Encode(keyOut, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(priv), - }) fmt.Printf("āœ“ CA certificate generated:\n") fmt.Printf(" Certificate: %s\n", certFile) @@ -162,7 +147,6 @@ func (c *CertGeneratorCommand) generateCA(cn, org, country string, days, bits in return nil } -// Added parseHosts helper for IP/hostname parsing func parseHosts(hostList string) ([]string, []net.IP) { var dnsNames []string var ipAddrs []net.IP @@ -196,11 +180,7 @@ func (c *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string dnsNames, ipAddrs := parseHosts(hosts) // 3. Create the certificate template - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - return fmt.Errorf("failed to generate serial number: %w", err) - } + serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) template := x509.Certificate{ SerialNumber: serialNumber, @@ -235,25 +215,14 @@ func (c *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string } // 6. Save the certificate with 0644 permissions - certOut, err := os.Create(certFile) - if err != nil { - return fmt.Errorf("failed to create certificate file: %w", err) + if err := saveCert(certFile, certDER); err != nil { + return err } - pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - certOut.Close() - - // 7. Save the private key with 0600 permissions - keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return fmt.Errorf("failed to create key file: %w", err) + if err := saveKey(keyFile, priv); err != nil { + return err } - pem.Encode(keyOut, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(priv), - }) - keyOut.Close() - // 8. Print summary + // 7. Print summary fmt.Printf("\nāœ“ Self-signed certificate generated:\n") fmt.Printf(" Certificate: %s\n", certFile) fmt.Printf(" Private Key: %s (mode 0600)\n", keyFile) @@ -266,10 +235,187 @@ func (c *CertGeneratorCommand) generateSelfSigned(cn, org, country, hosts string return nil } +// Generate server cert with CA func (c *CertGeneratorCommand) generateServerCert(cn, org, country, hosts, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { - return fmt.Errorf("server certificate generation with CA is not implemented; use --self-signed instead") + caCert, caKey, err := loadCA(caFile, caKeyFile) + if err != nil { + return err + } + + priv, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return fmt.Errorf("failed to generate server private key: %w", err) + } + + dnsNames, ipAddrs := parseHosts(hosts) + serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + certExpiry := time.Now().AddDate(0, 0, days) + if certExpiry.After(caCert.NotAfter) { + return fmt.Errorf("certificate validity period (%d days) exceeds CA expiry (%s)", days, caCert.NotAfter.Format(time.RFC3339)) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: cn, + Organization: []string{org}, + Country: []string{country}, + }, + NotBefore: time.Now(), + NotAfter: certExpiry, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: dnsNames, + IPAddresses: ipAddrs, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, caCert, &priv.PublicKey, caKey) + if err != nil { + return fmt.Errorf("failed to sign server certificate: %w", err) + } + + if certFile == "" { + certFile = "server.crt" + } + if keyFile == "" { + keyFile = "server.key" + } + + if err := saveCert(certFile, certDER); err != nil { + return err + } + if err := saveKey(keyFile, priv); err != nil { + return err + } + + fmt.Printf("\nāœ“ Server certificate generated:\n") + fmt.Printf(" Certificate: %s\n", certFile) + fmt.Printf(" Private Key: %s (mode 0600)\n", keyFile) + fmt.Printf(" Signed by: CN=%s\n", caCert.Subject.CommonName) + if len(hosts) > 0 { + fmt.Printf(" Hosts (SANs): %s\n", hosts) + } + return nil } +// Generate client cert with CA func (c *CertGeneratorCommand) generateClientCert(cn, org, country, caFile, caKeyFile string, days, bits int, certFile, keyFile string) error { - return fmt.Errorf("client certificate generation with CA is not implemented; use --self-signed instead") + caCert, caKey, err := loadCA(caFile, caKeyFile) + if err != nil { + return err + } + + priv, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return fmt.Errorf("failed to generate client private key: %w", err) + } + + serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + certExpiry := time.Now().AddDate(0, 0, days) + if certExpiry.After(caCert.NotAfter) { + return fmt.Errorf("certificate validity period (%d days) exceeds CA expiry (%s)", days, caCert.NotAfter.Format(time.RFC3339)) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: cn, + Organization: []string{org}, + Country: []string{country}, + }, + NotBefore: time.Now(), + NotAfter: certExpiry, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, caCert, &priv.PublicKey, caKey) + if err != nil { + return fmt.Errorf("failed to sign client certificate: %w", err) + } + + if certFile == "" { + certFile = "client.crt" + } + if keyFile == "" { + keyFile = "client.key" + } + + if err := saveCert(certFile, certDER); err != nil { + return err + } + if err := saveKey(keyFile, priv); err != nil { + return err + } + + fmt.Printf("\nāœ“ Client certificate generated:\n") + fmt.Printf(" Certificate: %s\n", certFile) + fmt.Printf(" Private Key: %s (mode 0600)\n", keyFile) + fmt.Printf(" Signed by: CN=%s\n", caCert.Subject.CommonName) + return nil +} + +// Load cert with CA +func loadCA(caFile, caKeyFile string) (*x509.Certificate, *rsa.PrivateKey, error) { + if caFile == "" || caKeyFile == "" { + return nil, nil, fmt.Errorf("--ca-cert and --ca-key are required for signing") + } + + caCertPEM, err := os.ReadFile(caFile) + if err != nil { + return nil, nil, fmt.Errorf("failed to read CA certificate: %w", err) + } + caCertBlock, _ := pem.Decode(caCertPEM) + if caCertBlock == nil { + return nil, nil, fmt.Errorf("failed to decode CA certificate PEM") + } + caCert, err := x509.ParseCertificate(caCertBlock.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse CA certificate: %w", err) + } + + if !caCert.IsCA { + return nil, nil, fmt.Errorf("provided certificate is not a valid CA") + } + + caKeyPEM, err := os.ReadFile(caKeyFile) + if err != nil { + return nil, nil, fmt.Errorf("failed to read CA key: %w", err) + } + caKeyBlock, _ := pem.Decode(caKeyPEM) + if caKeyBlock == nil { + return nil, nil, fmt.Errorf("failed to decode CA key PEM") + } + caKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse CA private key: %w", err) + } + + // Verify key matches certificate + if caCert.PublicKey.(*rsa.PublicKey).N.Cmp(caKey.N) != 0 { + return nil, nil, fmt.Errorf("CA private key does not match CA certificate") + } + + return caCert, caKey, nil +} + +func saveCert(filename string, derBytes []byte) error { + certOut, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to create cert file %s: %w", filename, err) + } + defer certOut.Close() + return pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) +} + +func saveKey(filename string, key *rsa.PrivateKey) error { + keyOut, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return fmt.Errorf("failed to create key file %s: %w", filename, err) + } + defer keyOut.Close() + return pem.Encode(keyOut, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) } \ No newline at end of file diff --git a/src/internal/tls/manager.go b/src/internal/tls/manager.go index 6c6c418..198662f 100644 --- a/src/internal/tls/manager.go +++ b/src/internal/tls/manager.go @@ -15,13 +15,13 @@ import ( // Manager handles TLS configuration for servers type Manager struct { - config *config.SSLConfig + config *config.TLSConfig tlsConfig *tls.Config logger *log.Logger } -// NewManager creates a TLS configuration from SSL config -func NewManager(cfg *config.SSLConfig, logger *log.Logger) (*Manager, error) { +// NewManager creates a TLS configuration from TLS config +func NewManager(cfg *config.TLSConfig, logger *log.Logger) (*Manager, error) { if cfg == nil || !cfg.Enabled { return nil, nil } @@ -83,7 +83,6 @@ func NewManager(cfg *config.SSLConfig, logger *log.Logger) (*Manager, error) { } // Set secure defaults - m.tlsConfig.PreferServerCipherSuites = true m.tlsConfig.SessionTicketsDisabled = false m.tlsConfig.Renegotiation = tls.RenegotiateNever