From e80b18370197c358e6792fab66c816ad27f408ce Mon Sep 17 00:00:00 2001 From: Vinayak Mishra Date: Fri, 6 Mar 2026 01:57:33 +0545 Subject: [PATCH] Merge commit from fork * fix: apply SSRF protection to LFS HTTP client The LFS HTTP client uses http.DefaultClient which has no SSRF protection. This allows server-side requests from LFS operations to reach private/internal networks. The webhook subsystem already has SSRF protection via secureHTTPClient with IP validation and redirect blocking, but the LFS code path was missed. Add a shared pkg/ssrf package with a secure HTTP client constructor that validates resolved IPs before dialing (blocking private, link- local, loopback, CGNAT, and reserved ranges) and blocks redirects. Replace http.DefaultClient in newHTTPClient() with ssrf.NewSecureClient() at both locations (batch API client and BasicTransferAdapter). * refactor: consolidate webhook SSRF protection into pkg/ssrf Pull shared IP validation into pkg/ssrf so both the LFS client and webhook client use the same SSRF protection. The webhook validator becomes a thin wrapper and the inline secureHTTPClient is replaced with ssrf.NewSecureClient(). Two latent issues in the webhook path fixed in the process: - nil ParseIP result was silently allowed through (now fail-closed) - IPv6-mapped IPv4 bypassed manual range checks (now normalized) Error aliases kept in pkg/webhook for backward compatibility. --- pkg/lfs/http_client.go | 6 +- pkg/ssrf/ssrf.go | 178 ++++++++++++++++++++ pkg/ssrf/ssrf_test.go | 208 +++++++++++++++++++++++ pkg/webhook/ssrf_test.go | 196 ++-------------------- pkg/webhook/validator.go | 163 +----------------- pkg/webhook/validator_test.go | 303 +++------------------------------- pkg/webhook/webhook.go | 41 +---- 7 files changed, 438 insertions(+), 657 deletions(-) create mode 100644 pkg/ssrf/ssrf.go create mode 100644 pkg/ssrf/ssrf_test.go diff --git a/pkg/lfs/http_client.go b/pkg/lfs/http_client.go index 068c311355f63ba713b8e98335ecb81a8dffc303..22bfa7863176620f0de6885ae64647a7b3542e05 100644 --- a/pkg/lfs/http_client.go +++ b/pkg/lfs/http_client.go @@ -9,6 +9,7 @@ import ( "net/http" "charm.land/log/v2" + "github.com/charmbracelet/soft-serve/pkg/ssrf" ) // httpClient is a Git LFS client to communicate with a LFS source API. @@ -22,11 +23,12 @@ var _ Client = (*httpClient)(nil) // newHTTPClient returns a new Git LFS client. func newHTTPClient(endpoint Endpoint) *httpClient { + client := ssrf.NewSecureClient() return &httpClient{ - client: http.DefaultClient, + client: client, endpoint: endpoint, transfers: map[string]TransferAdapter{ - TransferBasic: &BasicTransferAdapter{http.DefaultClient}, + TransferBasic: &BasicTransferAdapter{client}, }, } } diff --git a/pkg/ssrf/ssrf.go b/pkg/ssrf/ssrf.go new file mode 100644 index 0000000000000000000000000000000000000000..1ed96bd8f88c7005276e2d772161f3dbddf7c4a1 --- /dev/null +++ b/pkg/ssrf/ssrf.go @@ -0,0 +1,178 @@ +package ssrf + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "slices" + "strings" + "time" +) + +var ( + // ErrPrivateIP is returned when a connection to a private or internal IP is blocked. + ErrPrivateIP = errors.New("connection to private or internal IP address is not allowed") + // ErrInvalidScheme is returned when a URL scheme is not http or https. + ErrInvalidScheme = errors.New("URL must use http or https scheme") + // ErrInvalidURL is returned when a URL is invalid. + ErrInvalidURL = errors.New("invalid URL") +) + +// NewSecureClient returns an HTTP client with SSRF protection. +// It validates resolved IPs at dial time to block connections to private +// and internal networks. Since validation uses the already-resolved IP +// from the Transport's DNS lookup, there is no TOCTOU gap between +// resolution and connection. Redirects are disabled to match the +// webhook client convention and prevent redirect-based SSRF. +func NewSecureClient() *http.Client { + return &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err //nolint:wrapcheck + } + + ip := net.ParseIP(host) + if ip == nil { + return nil, fmt.Errorf("unexpected non-IP address in dial: %s", host) + } + if isPrivateOrInternal(ip) { + return nil, fmt.Errorf("%w", ErrPrivateIP) + } + + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } + return dialer.DialContext(ctx, network, addr) + }, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + } +} + +// isPrivateOrInternal checks if an IP address is private, internal, or reserved. +func isPrivateOrInternal(ip net.IP) bool { + // Normalize IPv6-mapped IPv4 (e.g. ::ffff:127.0.0.1) to IPv4 form + // so all checks apply consistently. + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || + ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() { + return true + } + + if ip4 := ip.To4(); ip4 != nil { + // 0.0.0.0/8 + if ip4[0] == 0 { + return true + } + // 100.64.0.0/10 (Shared Address Space / CGNAT) + if ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127 { + return true + } + // 192.0.0.0/24 (IETF Protocol Assignments) + if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0 { + return true + } + // 192.0.2.0/24 (TEST-NET-1) + if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 2 { + return true + } + // 198.18.0.0/15 (benchmarking) + if ip4[0] == 198 && (ip4[1] == 18 || ip4[1] == 19) { + return true + } + // 198.51.100.0/24 (TEST-NET-2) + if ip4[0] == 198 && ip4[1] == 51 && ip4[2] == 100 { + return true + } + // 203.0.113.0/24 (TEST-NET-3) + if ip4[0] == 203 && ip4[1] == 0 && ip4[2] == 113 { + return true + } + // 240.0.0.0/4 (Reserved, includes 255.255.255.255 broadcast) + if ip4[0] >= 240 { + return true + } + } + + return false +} + +// ValidateURL validates that a URL is safe to make requests to. +// It checks that the scheme is http/https, the hostname is not localhost, +// and all resolved IPs are public. +func ValidateURL(rawURL string) error { + if rawURL == "" { + return ErrInvalidURL + } + + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("%w: %v", ErrInvalidURL, err) + } + + if u.Scheme != "http" && u.Scheme != "https" { + return ErrInvalidScheme + } + + hostname := u.Hostname() + if hostname == "" { + return fmt.Errorf("%w: missing hostname", ErrInvalidURL) + } + + if isLocalhost(hostname) { + return ErrPrivateIP + } + + if ip := net.ParseIP(hostname); ip != nil { + if isPrivateOrInternal(ip) { + return ErrPrivateIP + } + return nil + } + + ips, err := net.DefaultResolver.LookupIPAddr(context.Background(), hostname) + if err != nil { + return fmt.Errorf("%w: cannot resolve hostname: %v", ErrInvalidURL, err) + } + + if slices.ContainsFunc(ips, func(addr net.IPAddr) bool { + return isPrivateOrInternal(addr.IP) + }) { + return ErrPrivateIP + } + + return nil +} + +// ValidateIPBeforeDial validates an IP address before establishing a connection. +// This prevents DNS rebinding attacks by checking the resolved IP at dial time. +func ValidateIPBeforeDial(ip net.IP) error { + if isPrivateOrInternal(ip) { + return ErrPrivateIP + } + return nil +} + +// isLocalhost checks if the hostname is localhost or similar. +func isLocalhost(hostname string) bool { + hostname = strings.ToLower(hostname) + return hostname == "localhost" || + hostname == "localhost.localdomain" || + strings.HasSuffix(hostname, ".localhost") +} diff --git a/pkg/ssrf/ssrf_test.go b/pkg/ssrf/ssrf_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a3c684dcf1babf37855c20a734aceeda6ceb1107 --- /dev/null +++ b/pkg/ssrf/ssrf_test.go @@ -0,0 +1,208 @@ +package ssrf + +import ( + "context" + "errors" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestNewSecureClientBlocksPrivateIPs(t *testing.T) { + client := NewSecureClient() + transport := client.Transport.(*http.Transport) + + tests := []struct { + name string + addr string + wantErr bool + }{ + {"block loopback", "127.0.0.1:80", true}, + {"block private 10.x", "10.0.0.1:80", true}, + {"block link-local", "169.254.169.254:80", true}, + {"block CGNAT", "100.64.0.1:80", true}, + {"allow public IP", "8.8.8.8:80", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + conn, err := transport.DialContext(ctx, "tcp", tt.addr) + if conn != nil { + conn.Close() + } + + if tt.wantErr { + if err == nil { + t.Errorf("expected error for %s, got none", tt.addr) + } + } else { + if err != nil && errors.Is(err, ErrPrivateIP) { + t.Errorf("should not block %s with SSRF error, got: %v", tt.addr, err) + } + } + }) + } +} + +func TestNewSecureClientNilIPNotErrPrivateIP(t *testing.T) { + client := NewSecureClient() + transport := client.Transport.(*http.Transport) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + conn, err := transport.DialContext(ctx, "tcp", "not-an-ip:80") + if conn != nil { + conn.Close() + } + if err == nil { + t.Fatal("expected error for non-IP address, got none") + } + if errors.Is(err, ErrPrivateIP) { + t.Errorf("nil-IP path should not wrap ErrPrivateIP, got: %v", err) + } +} + +func TestNewSecureClientBlocksRedirects(t *testing.T) { + redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "http://8.8.8.8:8080/safe", http.StatusFound) + })) + defer redirectServer.Close() + + client := NewSecureClient() + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, redirectServer.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + // httptest uses 127.0.0.1, blocked by SSRF protection + if !errors.Is(err, ErrPrivateIP) { + t.Fatalf("Request failed with non-SSRF error: %v", err) + } + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusFound { + t.Errorf("Expected redirect response (302), got %d", resp.StatusCode) + } +} + +func TestIsPrivateOrInternal(t *testing.T) { + tests := []struct { + ip string + want bool + }{ + // Public + {"8.8.8.8", false}, + {"2001:4860:4860::8888", false}, + + // Loopback + {"127.0.0.1", true}, + {"::1", true}, + + // Private ranges + {"10.0.0.1", true}, + {"192.168.1.1", true}, + {"172.16.0.1", true}, + + // Link-local (cloud metadata) + {"169.254.169.254", true}, + + // CGNAT boundaries + {"100.64.0.1", true}, + {"100.127.255.255", true}, + + // IPv6-mapped IPv4 (bypass vector the old webhook code missed) + {"::ffff:127.0.0.1", true}, + {"::ffff:169.254.169.254", true}, + {"::ffff:8.8.8.8", false}, + + // Reserved + {"0.0.0.0", true}, + {"240.0.0.1", true}, + } + + for _, tt := range tests { + t.Run(tt.ip, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("failed to parse IP: %s", tt.ip) + } + if got := isPrivateOrInternal(ip); got != tt.want { + t.Errorf("isPrivateOrInternal(%s) = %v, want %v", tt.ip, got, tt.want) + } + }) + } +} + +func TestValidateURL(t *testing.T) { + tests := []struct { + name string + url string + wantErr bool + errType error + }{ + // Valid + {"valid https", "https://1.1.1.1/webhook", false, nil}, + + // Scheme validation + {"ftp scheme", "ftp://example.com/webhook", true, ErrInvalidScheme}, + {"no scheme", "example.com/webhook", true, ErrInvalidScheme}, + + // Localhost + {"localhost", "http://localhost/webhook", true, ErrPrivateIP}, + {"subdomain.localhost", "http://test.localhost/webhook", true, ErrPrivateIP}, + + // IP-based blocking (one per category -- range coverage is in TestIsPrivateOrInternal) + {"loopback IP", "http://127.0.0.1/webhook", true, ErrPrivateIP}, + {"metadata IP", "http://169.254.169.254/latest/meta-data/", true, ErrPrivateIP}, + + // Invalid URLs + {"empty", "", true, ErrInvalidURL}, + {"missing hostname", "http:///webhook", true, ErrInvalidURL}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateURL(tt.url) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr) + return + } + if tt.wantErr && tt.errType != nil { + if !errors.Is(err, tt.errType) { + t.Errorf("ValidateURL(%q) error = %v, want error type %v", tt.url, err, tt.errType) + } + } + }) + } +} + +func TestIsLocalhost(t *testing.T) { + tests := []struct { + hostname string + want bool + }{ + {"localhost", true}, + {"LOCALHOST", true}, + {"test.localhost", true}, + {"example.com", false}, + {"localhost.com", false}, + } + + for _, tt := range tests { + t.Run(tt.hostname, func(t *testing.T) { + if got := isLocalhost(tt.hostname); got != tt.want { + t.Errorf("isLocalhost(%s) = %v, want %v", tt.hostname, got, tt.want) + } + }) + } +} diff --git a/pkg/webhook/ssrf_test.go b/pkg/webhook/ssrf_test.go index 3f7fc4948aad217d65a6e705d02fffeeda465f4f..24251613638115b1d2c513a00b6ea6f0a650b465 100644 --- a/pkg/webhook/ssrf_test.go +++ b/pkg/webhook/ssrf_test.go @@ -2,217 +2,57 @@ package webhook import ( "context" + "errors" "net/http" - "net/http/httptest" "testing" "time" "github.com/charmbracelet/soft-serve/pkg/db/models" + "github.com/charmbracelet/soft-serve/pkg/ssrf" ) -// TestSSRFProtection tests that the webhook system blocks SSRF attempts. +// TestSSRFProtection is an integration test verifying the webhook send path +// blocks private IPs end-to-end (models.Webhook -> secureHTTPClient -> ssrf). func TestSSRFProtection(t *testing.T) { tests := []struct { name string webhookURL string shouldBlock bool - description string }{ - { - name: "block localhost", - webhookURL: "http://localhost:8080/webhook", - shouldBlock: true, - description: "should block localhost addresses", - }, - { - name: "block 127.0.0.1", - webhookURL: "http://127.0.0.1:8080/webhook", - shouldBlock: true, - description: "should block loopback addresses", - }, - { - name: "block 169.254.169.254", - webhookURL: "http://169.254.169.254/latest/meta-data/", - shouldBlock: true, - description: "should block cloud metadata service", - }, - { - name: "block private network", - webhookURL: "http://192.168.1.1/webhook", - shouldBlock: true, - description: "should block private networks", - }, - { - name: "allow public IP", - webhookURL: "http://8.8.8.8/webhook", - shouldBlock: false, - description: "should allow public IP addresses", - }, + {"block loopback", "http://127.0.0.1:8080/webhook", true}, + {"block metadata", "http://169.254.169.254/latest/meta-data/", true}, + {"allow public IP", "http://8.8.8.8/webhook", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Create a test webhook - webhook := models.Webhook{ + w := models.Webhook{ URL: tt.webhookURL, ContentType: int(ContentTypeJSON), - Secret: "", } - // Try to send a webhook ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - // Create a simple payload - payload := map[string]string{"test": "data"} - - err := sendWebhookWithContext(ctx, webhook, EventPush, payload) - - if tt.shouldBlock { - if err == nil { - t.Errorf("%s: expected error but got none", tt.description) - } - } else { - // For public IPs, we expect a connection error (since 8.8.8.8 won't be listening) - // but NOT an SSRF blocking error - if err != nil && isSSRFError(err) { - t.Errorf("%s: should not block public IPs, got: %v", tt.description, err) - } + req, err := http.NewRequestWithContext(ctx, "POST", w.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) } - }) - } -} - -// TestSecureHTTPClientBlocksRedirects tests that redirects are not followed. -func TestSecureHTTPClientBlocksRedirects(t *testing.T) { - // Create a test server on a public-looking address that redirects - redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "http://8.8.8.8:8080/safe", http.StatusFound) - })) - defer redirectServer.Close() - - // Try to make a request that would redirect - req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, redirectServer.URL, nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - resp, err := secureHTTPClient.Do(req) - if err != nil { - // httptest.NewServer uses 127.0.0.1, which will be blocked by our SSRF protection - // This is actually correct behavior - we're blocking the initial connection - if !isSSRFError(err) { - t.Fatalf("Request failed with non-SSRF error: %v", err) - } - // Test passed - we blocked the loopback connection - return - } - defer resp.Body.Close() - - // If we got here, check that we got the redirect response (not followed) - if resp.StatusCode != http.StatusFound { - t.Errorf("Expected redirect response (302), got %d", resp.StatusCode) - } -} - -// TestDialContextBlocksPrivateIPs tests the DialContext function directly. -func TestDialContextBlocksPrivateIPs(t *testing.T) { - transport := secureHTTPClient.Transport.(*http.Transport) - - tests := []struct { - name string - addr string - wantErr bool - }{ - {"block loopback", "127.0.0.1:80", true}, - {"block private 10.x", "10.0.0.1:80", true}, - {"block private 192.168.x", "192.168.1.1:80", true}, - {"block link-local", "169.254.169.254:80", true}, - {"allow public IP", "8.8.8.8:80", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - conn, err := transport.DialContext(ctx, "tcp", tt.addr) - if conn != nil { - conn.Close() + resp, err := secureHTTPClient.Do(req) + if resp != nil { + resp.Body.Close() } - if tt.wantErr { + if tt.shouldBlock { if err == nil { - t.Errorf("Expected error for %s, got none", tt.addr) + t.Errorf("%s: expected error but got none", tt.name) } } else { - // For public IPs, we expect a connection timeout/refused (not an SSRF block) - if err != nil && isSSRFError(err) { - t.Errorf("Should not block %s with SSRF error, got: %v", tt.addr, err) + if err != nil && errors.Is(err, ssrf.ErrPrivateIP) { + t.Errorf("%s: should not block public IPs, got: %v", tt.name, err) } } }) } } - -// sendWebhookWithContext is a test helper that doesn't require database. -func sendWebhookWithContext(ctx context.Context, w models.Webhook, _ Event, _ any) error { - // This is a simplified version for testing that just attempts the HTTP connection - req, err := http.NewRequestWithContext(ctx, "POST", w.URL, nil) - if err != nil { - return err //nolint:wrapcheck - } - req = req.WithContext(ctx) - - resp, err := secureHTTPClient.Do(req) - if resp != nil { - resp.Body.Close() - } - return err //nolint:wrapcheck -} - -// isSSRFError checks if an error is related to SSRF blocking. -func isSSRFError(err error) bool { - if err == nil { - return false - } - errMsg := err.Error() - return contains(errMsg, "private IP") || - contains(errMsg, "blocked connection") || - err == ErrPrivateIP -} - -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(substr) == 0 || indexOfSubstring(s, substr) >= 0) -} - -func indexOfSubstring(s, substr string) int { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return i - } - } - return -1 -} - -// TestPrivateIPResolution tests that hostnames resolving to private IPs are blocked. -func TestPrivateIPResolution(t *testing.T) { - // This test verifies that even if a hostname looks public, if it resolves to a private IP, it's blocked - webhook := models.Webhook{ - URL: "http://127.0.0.1:9999/webhook", - ContentType: int(ContentTypeJSON), - } - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - err := sendWebhookWithContext(ctx, webhook, EventPush, map[string]string{"test": "data"}) - if err == nil { - t.Error("Expected error when connecting to loopback address") - return - } - - if !isSSRFError(err) { - t.Errorf("Expected SSRF blocking error, got: %v", err) - } -} diff --git a/pkg/webhook/validator.go b/pkg/webhook/validator.go index 0eecd0951173b7cb518b57a0c0826383f18e3e82..2cc13a7f15445fc56c4e576f897ceb381e3f8eba 100644 --- a/pkg/webhook/validator.go +++ b/pkg/webhook/validator.go @@ -1,169 +1,20 @@ package webhook import ( - "context" - "errors" - "fmt" - "net" - "net/url" - "slices" - "strings" + "github.com/charmbracelet/soft-serve/pkg/ssrf" ) +// Error aliases for backward compatibility. var ( - // ErrInvalidScheme is returned when the webhook URL scheme is not http or https. - ErrInvalidScheme = errors.New("webhook URL must use http or https scheme") - // ErrPrivateIP is returned when the webhook URL resolves to a private IP address. - ErrPrivateIP = errors.New("webhook URL cannot resolve to private or internal IP addresses") - // ErrInvalidURL is returned when the webhook URL is invalid. - ErrInvalidURL = errors.New("invalid webhook URL") + ErrInvalidScheme = ssrf.ErrInvalidScheme + ErrPrivateIP = ssrf.ErrPrivateIP + ErrInvalidURL = ssrf.ErrInvalidURL ) // ValidateWebhookURL validates that a webhook URL is safe to use. -// It checks: -// - URL is properly formatted -// - Scheme is http or https -// - Hostname does not resolve to private/internal IP addresses -// - Hostname is not localhost or similar. func ValidateWebhookURL(rawURL string) error { - if rawURL == "" { - return ErrInvalidURL - } - - // Parse the URL - u, err := url.Parse(rawURL) - if err != nil { - return fmt.Errorf("%w: %v", ErrInvalidURL, err) - } - - // Check scheme - if u.Scheme != "http" && u.Scheme != "https" { - return ErrInvalidScheme - } - - // Extract hostname (without port) - hostname := u.Hostname() - if hostname == "" { - return fmt.Errorf("%w: missing hostname", ErrInvalidURL) - } - - // Check for localhost variations - if isLocalhost(hostname) { - return ErrPrivateIP - } - - // If it's an IP address, validate it directly - if ip := net.ParseIP(hostname); ip != nil { - if isPrivateOrInternalIP(ip) { - return ErrPrivateIP - } - return nil - } - - // Resolve hostname to IP addresses - ips, err := net.DefaultResolver.LookupIPAddr(context.Background(), hostname) - if err != nil { - return fmt.Errorf("%w: cannot resolve hostname: %v", ErrInvalidURL, err) - } - - // Check all resolved IPs - if slices.ContainsFunc(ips, isPrivateOrInternalIPAddr) { - return ErrPrivateIP - } - - return nil -} - -// isLocalhost checks if the hostname is localhost or similar. -func isLocalhost(hostname string) bool { - hostname = strings.ToLower(hostname) - return hostname == "localhost" || - hostname == "localhost.localdomain" || - strings.HasSuffix(hostname, ".localhost") -} - -// isPrivateOrInternalIPAddr is a helper function that users net.IPAddr instead of net.IP. -func isPrivateOrInternalIPAddr(ipAddr net.IPAddr) bool { - return isPrivateOrInternalIP(ipAddr.IP) -} - -// isPrivateOrInternalIP checks if an IP address is private, internal, or reserved. -func isPrivateOrInternalIP(ip net.IP) bool { - // Loopback addresses (127.0.0.0/8, ::1) - if ip.IsLoopback() { - return true - } - - // Link-local addresses (169.254.0.0/16, fe80::/10) - // This blocks AWS/GCP/Azure metadata services - if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { - return true - } - - // Private addresses (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7) - if ip.IsPrivate() { - return true - } - - // Unspecified addresses (0.0.0.0, ::) - if ip.IsUnspecified() { - return true - } - - // Multicast addresses - if ip.IsMulticast() { - return true - } - - // Additional checks for IPv4 - if ip4 := ip.To4(); ip4 != nil { - // 0.0.0.0/8 (current network) - if ip4[0] == 0 { - return true - } - // 100.64.0.0/10 (Shared Address Space) - if ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127 { - return true - } - // 192.0.0.0/24 (IETF Protocol Assignments) - if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0 { - return true - } - // 192.0.2.0/24 (TEST-NET-1) - if ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 2 { - return true - } - // 198.18.0.0/15 (benchmarking) - if ip4[0] == 198 && (ip4[1] == 18 || ip4[1] == 19) { - return true - } - // 198.51.100.0/24 (TEST-NET-2) - if ip4[0] == 198 && ip4[1] == 51 && ip4[2] == 100 { - return true - } - // 203.0.113.0/24 (TEST-NET-3) - if ip4[0] == 203 && ip4[1] == 0 && ip4[2] == 113 { - return true - } - // 224.0.0.0/4 (Multicast - already handled by IsMulticast) - // 240.0.0.0/4 (Reserved for future use) - if ip4[0] >= 240 { - return true - } - // 255.255.255.255/32 (Broadcast) - if ip4[0] == 255 && ip4[1] == 255 && ip4[2] == 255 && ip4[3] == 255 { - return true - } - } - - return false + return ssrf.ValidateURL(rawURL) //nolint:wrapcheck } // ValidateIPBeforeDial validates an IP address before establishing a connection. -// This is used to prevent DNS rebinding attacks. -func ValidateIPBeforeDial(ip net.IP) error { - if isPrivateOrInternalIP(ip) { - return ErrPrivateIP - } - return nil -} +var ValidateIPBeforeDial = ssrf.ValidateIPBeforeDial diff --git a/pkg/webhook/validator_test.go b/pkg/webhook/validator_test.go index 9d4d4f2ac542dbe3d433e16e18c90d78e6236708..901d44a389bd08c2a36378976c139aa4855faca4 100644 --- a/pkg/webhook/validator_test.go +++ b/pkg/webhook/validator_test.go @@ -1,315 +1,52 @@ package webhook import ( - "net" + "errors" "testing" + + "github.com/charmbracelet/soft-serve/pkg/ssrf" ) +// TestValidateWebhookURL verifies the wrapper delegates correctly and +// error aliases work across the package boundary. IP range coverage +// is in pkg/ssrf/ssrf_test.go -- here we just confirm the plumbing. func TestValidateWebhookURL(t *testing.T) { tests := []struct { name string url string wantErr bool errType error - skip string }{ - // Valid URLs (these will perform DNS lookups, so may fail in some environments) - { - name: "valid https URL", - url: "https://1.1.1.1/webhook", - wantErr: false, - }, - { - name: "valid http URL", - url: "http://8.8.8.8/webhook", - wantErr: false, - }, - { - name: "valid URL with port", - url: "https://1.1.1.1:8080/webhook", - wantErr: false, - }, - { - name: "valid URL with path and query", - url: "https://8.8.8.8/webhook?token=abc123", - wantErr: false, - }, - - // Invalid schemes - { - name: "ftp scheme", - url: "ftp://example.com/webhook", - wantErr: true, - errType: ErrInvalidScheme, - }, - { - name: "file scheme", - url: "file:///etc/passwd", - wantErr: true, - errType: ErrInvalidScheme, - }, - { - name: "gopher scheme", - url: "gopher://example.com", - wantErr: true, - errType: ErrInvalidScheme, - }, - { - name: "no scheme", - url: "example.com/webhook", - wantErr: true, - errType: ErrInvalidScheme, - }, - - // Localhost variations - { - name: "localhost", - url: "http://localhost/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - { - name: "localhost with port", - url: "http://localhost:8080/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - { - name: "localhost.localdomain", - url: "http://localhost.localdomain/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - - // Loopback IPs - { - name: "127.0.0.1", - url: "http://127.0.0.1/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - { - name: "127.0.0.1 with port", - url: "http://127.0.0.1:8080/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - { - name: "127.1.2.3", - url: "http://127.1.2.3/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - { - name: "IPv6 loopback", - url: "http://[::1]/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - - // Private IPv4 ranges - { - name: "10.0.0.0", - url: "http://10.0.0.1/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - { - name: "192.168.0.0", - url: "http://192.168.1.1/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - { - name: "172.16.0.0", - url: "http://172.16.0.1/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - { - name: "172.31.255.255", - url: "http://172.31.255.255/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - - // Link-local (AWS/GCP/Azure metadata) - { - name: "AWS metadata service", - url: "http://169.254.169.254/latest/meta-data/", - wantErr: true, - errType: ErrPrivateIP, - }, - { - name: "link-local", - url: "http://169.254.1.1/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - - // Other reserved ranges - { - name: "0.0.0.0", - url: "http://0.0.0.0/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - { - name: "broadcast", - url: "http://255.255.255.255/webhook", - wantErr: true, - errType: ErrPrivateIP, - }, - - // Invalid URLs - { - name: "empty URL", - url: "", - wantErr: true, - errType: ErrInvalidURL, - }, - { - name: "missing hostname", - url: "http:///webhook", - wantErr: true, - errType: ErrInvalidURL, - }, + {"valid", "https://1.1.1.1/webhook", false, nil}, + {"bad scheme", "ftp://example.com", true, ErrInvalidScheme}, + {"private IP", "http://127.0.0.1/webhook", true, ErrPrivateIP}, + {"empty", "", true, ErrInvalidURL}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.skip != "" { - t.Skip(tt.skip) - } err := ValidateWebhookURL(tt.url) if (err != nil) != tt.wantErr { - t.Errorf("ValidateWebhookURL() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("ValidateWebhookURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr) return } if tt.wantErr && tt.errType != nil { - if !isErrorType(err, tt.errType) { - t.Errorf("ValidateWebhookURL() error = %v, want error type %v", err, tt.errType) + if !errors.Is(err, tt.errType) { + t.Errorf("ValidateWebhookURL(%q) error = %v, want %v", tt.url, err, tt.errType) } } }) } } -func TestIsPrivateOrInternalIP(t *testing.T) { - tests := []struct { - name string - ip string - isPriv bool - }{ - // Public IPs - {"Google DNS", "8.8.8.8", false}, - {"Cloudflare DNS", "1.1.1.1", false}, - {"Public IPv6", "2001:4860:4860::8888", false}, - - // Loopback - {"127.0.0.1", "127.0.0.1", true}, - {"127.1.2.3", "127.1.2.3", true}, - {"::1", "::1", true}, - - // Private ranges - {"10.0.0.1", "10.0.0.1", true}, - {"192.168.1.1", "192.168.1.1", true}, - {"172.16.0.1", "172.16.0.1", true}, - {"172.31.255.255", "172.31.255.255", true}, - - // Link-local - {"169.254.169.254", "169.254.169.254", true}, - {"169.254.1.1", "169.254.1.1", true}, - {"fe80::1", "fe80::1", true}, - - // Other reserved - {"0.0.0.0", "0.0.0.0", true}, - {"255.255.255.255", "255.255.255.255", true}, - {"240.0.0.1", "240.0.0.1", true}, - - // Shared address space - {"100.64.0.1", "100.64.0.1", true}, - {"100.127.255.255", "100.127.255.255", true}, +func TestErrorAliases(t *testing.T) { + if ErrPrivateIP != ssrf.ErrPrivateIP { + t.Error("ErrPrivateIP should alias ssrf.ErrPrivateIP") } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ip := net.ParseIP(tt.ip) - if ip == nil { - t.Fatalf("Failed to parse IP: %s", tt.ip) - } - if got := isPrivateOrInternalIP(ip); got != tt.isPriv { - t.Errorf("isPrivateOrInternalIP(%s) = %v, want %v", tt.ip, got, tt.isPriv) - } - }) - } -} - -func TestIsLocalhost(t *testing.T) { - tests := []struct { - name string - hostname string - want bool - }{ - {"localhost", "localhost", true}, - {"LOCALHOST", "LOCALHOST", true}, - {"localhost.localdomain", "localhost.localdomain", true}, - {"test.localhost", "test.localhost", true}, - {"example.com", "example.com", false}, - {"localhos", "localhos", false}, - {"localhost.com", "localhost.com", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := isLocalhost(tt.hostname); got != tt.want { - t.Errorf("isLocalhost(%s) = %v, want %v", tt.hostname, got, tt.want) - } - }) - } -} - -func TestValidateIPBeforeDial(t *testing.T) { - tests := []struct { - name string - ip string - wantErr bool - }{ - {"public IP", "8.8.8.8", false}, - {"private IP", "192.168.1.1", true}, - {"loopback", "127.0.0.1", true}, - {"link-local", "169.254.169.254", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ip := net.ParseIP(tt.ip) - if ip == nil { - t.Fatalf("Failed to parse IP: %s", tt.ip) - } - err := ValidateIPBeforeDial(ip) - if (err != nil) != tt.wantErr { - t.Errorf("ValidateIPBeforeDial(%s) error = %v, wantErr %v", tt.ip, err, tt.wantErr) - } - }) - } -} - -// isErrorType checks if err is or wraps errType. -func isErrorType(err, errType error) bool { - if err == errType { - return true + if ErrInvalidScheme != ssrf.ErrInvalidScheme { + t.Error("ErrInvalidScheme should alias ssrf.ErrInvalidScheme") } - // Check if err wraps errType - for err != nil { - if err == errType { - return true - } - unwrapped, ok := err.(interface{ Unwrap() error }) - if !ok { - break - } - err = unwrapped.Unwrap() + if ErrInvalidURL != ssrf.ErrInvalidURL { + t.Error("ErrInvalidURL should alias ssrf.ErrInvalidURL") } - return false } diff --git a/pkg/webhook/webhook.go b/pkg/webhook/webhook.go index 176d8e25b56f7146ba39a9912493f0d94ca83d9c..dc3fe7dcb65d493c69932bf01e0fbce648b32a07 100644 --- a/pkg/webhook/webhook.go +++ b/pkg/webhook/webhook.go @@ -10,14 +10,13 @@ import ( "errors" "fmt" "io" - "net" "net/http" - "time" "github.com/charmbracelet/soft-serve/git" "github.com/charmbracelet/soft-serve/pkg/db" "github.com/charmbracelet/soft-serve/pkg/db/models" "github.com/charmbracelet/soft-serve/pkg/proto" + "github.com/charmbracelet/soft-serve/pkg/ssrf" "github.com/charmbracelet/soft-serve/pkg/store" "github.com/charmbracelet/soft-serve/pkg/utils" "github.com/charmbracelet/soft-serve/pkg/version" @@ -38,42 +37,8 @@ type Delivery struct { Event Event } -// secureHTTPClient creates an HTTP client with SSRF protection. -var secureHTTPClient = &http.Client{ - Timeout: 30 * time.Second, - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - // Parse the address to get the IP - host, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, err //nolint:wrapcheck - } - - // Validate the resolved IP before connecting - ip := net.ParseIP(host) - if ip != nil { - if err := ValidateIPBeforeDial(ip); err != nil { - return nil, fmt.Errorf("blocked connection to private IP: %w", err) - } - } - - // Use standard dialer with timeout - dialer := &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, - } - return dialer.DialContext(ctx, network, addr) - }, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - }, - // Don't follow redirects to prevent bypassing IP validation - CheckRedirect: func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - }, -} +// secureHTTPClient is an HTTP client with SSRF protection. +var secureHTTPClient = ssrf.NewSecureClient() // do sends a webhook. // Caller must close the returned body.